├── .env-example ├── .gitignore ├── .vscode └── launch.json ├── readme.md ├── requirements.txt ├── src ├── __init__.py ├── api │ ├── __init__.py │ ├── auth.py │ ├── config.py │ ├── database.py │ ├── main.py │ ├── models.py │ └── tasks.py ├── tests │ ├── __init__.py │ ├── test_auth.py │ ├── test_database.py │ ├── test_file_utils.py │ ├── test_main.py │ ├── test_tasks.py │ └── test_transcription_utils.py └── utils │ ├── __init__.py │ ├── file_utils.py │ └── transcription_utils.py └── start.py /.env-example: -------------------------------------------------------------------------------- 1 | API_PORT=11300 2 | HUGGING_FACE_TOKEN= 3 | SECRET_KEY= 4 | MASTER_KEY= 5 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | #.idea/ 161 | data/ 162 | data-old/ 163 | transcriptions/ 164 | transcriptions-old/ 165 | temp/ 166 | users.db 167 | celery.db 168 | start.sh 169 | -------------------------------------------------------------------------------- /.vscode/launch.json: -------------------------------------------------------------------------------- 1 | { 2 | "version": "0.2.0", 3 | "configurations": [ 4 | { 5 | "name": "Debug pytest", 6 | "type": "python", 7 | "request": "launch", 8 | "module": "pytest", 9 | "cwd": "${workspaceFolder}", 10 | "args": [ 11 | "-v", 12 | "-s" 13 | ], 14 | "env": { 15 | "PYTHONPATH": "${workspaceFolder}" 16 | }, 17 | "justMyCode": false 18 | } 19 | ] 20 | } -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # Whisperx API Wrapper 2 | 3 | An API Wrapper for [Whisperx Library](https://github.com/m-bain/whisperX) 4 | 5 | ## Overview 6 | 7 | This is a FastAPI application that provides an endpoint for video/audio transcription using the `whisperx` command. The application supports multiple audio and video formats. It performs the transcription, alignment, and diarization of the uploaded media files. 
8 | 9 | ## Features 10 | 11 | - User Authentication with JWT 12 | - Support for multiple audio and video formats 13 | - Diarization support 14 | - Customizable language and model settings 15 | 16 | ## Requirements 17 | 18 | - whisperx 19 | - Python 3.8+ 20 | - FastAPI 21 | - ffmpeg 22 | - SQLite 23 | - pyjwt 24 | - dotenv 25 | 26 | Follow the instructions on how to install WhisperX [in the official repository](https://github.com/m-bain/whisperX#3-install-this-repo). 27 | You can install the remaining Python dependencies using the `requirements.txt` file: 28 | 29 | ```bash 30 | pip install -r requirements.txt 31 | ``` 32 | 33 | ## Environment Variables 34 | 35 | Create a `.env` file in your root directory and add the following variables: 36 | 37 | ```env 38 | SECRET_KEY=your_secret_key 39 | MASTER_KEY=your_master_key 40 | HUGGING_FACE_TOKEN=your_hugging_face_token 41 | API_PORT=11300 42 | ``` 43 | 44 | ## Database Setup 45 | 46 | SQLite is used for storing user information. The database is created automatically when the application runs. 47 | 48 | ## Running the Application 49 | 50 | Run the application using: 51 | 52 | ```bash 53 | python api_whisperx.py 54 | ``` 55 | 56 | Replace `api_whisperx.py` with the name of your entry-point script if it is different (this repository ships `start.py` and `src/api/main.py`). 57 | 58 | ## API Endpoints 59 | 60 | ### POST `/auth` 61 | 62 | Authenticate a user and return a JWT token. 63 | 64 | - `username`: The username of the user. 65 | - `password`: The password of the user. 66 | 67 | ### POST `/create_user` 68 | 69 | Create a new user. 70 | 71 | - `username`: Desired username. 72 | - `password`: Desired password. 73 | - `master_key`: Master key for authorized user creation. 74 | 75 | ### POST `/jobs` 76 | 77 | Queue a transcription job for an uploaded audio or video file; the response contains the job's `task_id`. 78 | 79 | - `file`: The audio or video file to transcribe. 80 | - `lang`: Language for transcription (default is "pt"). 81 | - `model`: Model to use for transcription (default is "large-v3"). 
82 | - `min_speakers`: Minimum number of speakers for diarization (default is 1). 83 | - `max_speakers`: Maximum number of speakers for diarization (default is 2). 84 | 85 | ## Logging 86 | 87 | The application has built-in logging that informs about the steps being performed and any errors that occur. 88 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | fastapi 2 | requests 3 | streamlit 4 | python-dotenv 5 | PyJWT 6 | celery 7 | pytest-mock 8 | pytest 9 | pytest-asyncio -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/namastexlabs/whisperx-api/4beba3b0afd0649d810adc74a844e1e362485982/src/__init__.py -------------------------------------------------------------------------------- /src/api/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/namastexlabs/whisperx-api/4beba3b0afd0649d810adc74a844e1e362485982/src/api/__init__.py -------------------------------------------------------------------------------- /src/api/auth.py: -------------------------------------------------------------------------------- 1 | import os 2 | from src.api.config import MASTER_KEY, SECRET_KEY, TOKEN_EXPIRATION_DAYS 3 | from fastapi import Depends, HTTPException, Query 4 | from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials 5 | import jwt 6 | from datetime import datetime, timedelta 7 | from src.api.database import get_db 8 | import sqlite3 9 | 10 | 11 | security = HTTPBearer() 12 | oauth2_scheme = HTTPBearer(scheme_name="JWT") 13 | 14 | 15 | def create_jwt_token(data: dict): 16 | token_expiration = datetime.utcnow() + timedelta(days=TOKEN_EXPIRATION_DAYS) 17 | data["exp"] = token_expiration 18 | 
async def get_current_user(
    credentials: HTTPAuthorizationCredentials = Depends(oauth2_scheme),
):
    """Resolve the bearer token of the current request to a user.

    The token is looked up in the ``users`` table; it is accepted only when
    a row carries it and that row's ``token_expiration`` (an ISO-8601 string
    written by ``auth``) is still in the future.

    Args:
        credentials: Bearer credentials extracted by the ``oauth2_scheme``
            dependency.

    Returns:
        ``{"username": <name>}`` for the user owning the token.

    Raises:
        HTTPException: 401 when the token is unknown, has no expiration,
            has an unparsable expiration, or has expired.
    """
    token = credentials.credentials
    conn, cursor = get_db()
    try:
        cursor.execute(
            "SELECT username, token_expiration FROM users WHERE token = ?", (token,)
        )
        row = cursor.fetchone()
    finally:
        # Always release the connection, even when the query itself fails.
        conn.close()

    if row:
        username, token_expiration_str = row
        if token_expiration_str:
            try:
                token_expiration = datetime.fromisoformat(token_expiration_str)
            except (TypeError, ValueError):
                # A corrupt stored timestamp must read as "not authenticated",
                # not as an internal server error.
                token_expiration = None
            # Stored values are naive UTC timestamps, so compare against
            # a naive UTC "now".
            if token_expiration is not None and token_expiration > datetime.utcnow():
                return {"username": username}
    raise HTTPException(status_code=401, detail="Invalid or expired token")
def get_db(db_path=None):
    """Open the users database and make sure its schema exists.

    Args:
        db_path: Optional path of the SQLite database file (``":memory:"``
            works too). When omitted, ``users.db`` located two directory
            levels above this module's file is used, as before.

    Returns:
        A ``(connection, cursor)`` tuple. The caller owns the connection
        and is responsible for closing it.

    Raises:
        sqlite3.Error: If the database cannot be opened or the ``users``
            table cannot be created. The connection is closed before the
            error propagates.
    """
    import logging  # local import: only os and sqlite3 are imported at file level

    if db_path is None:
        project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        db_path = os.path.join(project_root, "users.db")

    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()
        cursor.execute(
            "CREATE TABLE IF NOT EXISTS users (username TEXT, password TEXT, token TEXT, token_expiration TEXT)"
        )
        try:
            # Without a uniqueness guarantee on username an INSERT of an
            # existing name succeeds silently. A separate index (rather than
            # a column constraint) also upgrades databases created earlier.
            cursor.execute(
                "CREATE UNIQUE INDEX IF NOT EXISTS idx_users_username ON users (username)"
            )
        except sqlite3.IntegrityError:
            # A pre-existing database already holds duplicate usernames;
            # keep serving it instead of failing every request.
            logging.warning(
                "users table contains duplicate usernames; "
                "unique index on username was not created"
            )
    except Exception:
        conn.close()
        raise
    return conn, cursor
@app.get("/jobs")
async def list_jobs(current_user: dict = Depends(get_current_user)):
    """List the tasks that Celery workers are executing right now.

    Args:
        current_user: Authenticated user injected by ``get_current_user``;
            only used to enforce authentication.

    Returns:
        A list of ``{"task_id": ..., "status": ...}`` dicts, one per active
        task. The list is empty when no worker answers the inspection
        request.
    """
    # inspect().active() yields None (not an empty dict) when no worker
    # replies, so normalise before iterating.
    active_by_worker = celery_app.control.inspect().active() or {}

    jobs = []
    for task_list in active_by_worker.values():
        for task in task_list or []:
            # Worker replies are not guaranteed to carry a "state" entry;
            # a task reported as active is being executed, hence STARTED.
            jobs.append(
                {"task_id": task["id"], "status": task.get("state", states.STARTED)}
            )
    return jobs
class LanguageEnum(str, Enum):
    """Language codes accepted by the transcription endpoints."""

    pt = "pt"
    en = "en"
    es = "es"
    fr = "fr"
    it = "it"
    de = "de"


class ModelEnum(str, Enum):
    """Model names accepted by the transcription endpoints."""

    tiny = "tiny"
    small = "small"
    base = "base"
    medium = "medium"
    largeV2 = "large-v2"
    largeV3 = "large-v3"


class ResponseTypeEnum(str, Enum):
    """Response formats: inline JSON or a downloadable file."""

    json = "json"
    file = "file"
output_files["txt_content"], 24 | "json_content": output_files["json_content"], 25 | "srt_content": output_files["srt_content"], 26 | "vtt_path": output_files["vtt_path"], 27 | "txt_path": output_files["txt_path"], 28 | "json_path": output_files["json_path"], 29 | "srt_path": output_files["srt_path"], 30 | } 31 | return result 32 | 33 | except Exception as e: 34 | logging.error(f"An error occurred: {str(e)}") 35 | raise e 36 | -------------------------------------------------------------------------------- /src/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/namastexlabs/whisperx-api/4beba3b0afd0649d810adc74a844e1e362485982/src/tests/__init__.py -------------------------------------------------------------------------------- /src/tests/test_auth.py: -------------------------------------------------------------------------------- 1 | import jwt 2 | import pytest 3 | from fastapi import HTTPException 4 | from datetime import datetime, timedelta 5 | from src.api.auth import ( 6 | create_jwt_token, 7 | decode_jwt_token, 8 | get_current_user, 9 | auth, 10 | create_user, 11 | ) 12 | from src.api.config import MASTER_KEY, SECRET_KEY 13 | import sqlite3 14 | 15 | 16 | def test_create_jwt_token(): 17 | data = {"sub": "testuser"} 18 | token = create_jwt_token(data) 19 | assert isinstance(token, str) 20 | 21 | 22 | def test_decode_jwt_token_valid(): 23 | data = {"sub": "testuser"} 24 | token = create_jwt_token(data) 25 | decoded = decode_jwt_token(token) 26 | assert decoded["sub"] == "testuser" 27 | 28 | 29 | def test_decode_jwt_token_expired(): 30 | data = {"sub": "testuser"} 31 | token = create_jwt_token(data) 32 | 33 | # Manually set the expiration time to a past timestamp 34 | decoded_token = jwt.decode(token, SECRET_KEY, algorithms=["HS256"]) 35 | decoded_token["exp"] = datetime.utcnow() - timedelta(days=1) 36 | expired_token = jwt.encode(decoded_token, SECRET_KEY, algorithm="HS256") 37 
# NOTE: the functions under test are imported from ``src.api.auth``, so
# ``get_db`` has to be patched in that exact module. Patching
# ``api.auth.get_db`` targets a different module object (or fails to import)
# and leaves the real database in use.


@pytest.mark.asyncio
async def test_get_current_user(mocker):
    """A stored, unexpired token resolves to its username."""
    mock_credentials = mocker.Mock()
    mock_credentials.credentials = "valid_token"

    mock_get_db = mocker.patch("src.api.auth.get_db")
    mock_get_db.return_value = (mocker.Mock(), mocker.Mock())
    mock_get_db.return_value[1].fetchone.return_value = (
        "testuser",
        (datetime.utcnow() + timedelta(days=1)).isoformat(),
    )

    user = await get_current_user(mock_credentials)
    assert user == {"username": "testuser"}


def test_auth_valid(mocker):
    """Matching credentials yield a bearer token."""
    mock_get_db = mocker.patch("src.api.auth.get_db")
    mock_get_db.return_value = (mocker.Mock(), mocker.Mock())
    mock_get_db.return_value[1].fetchone.return_value = ("password",)

    result = auth("testuser", "password")
    assert "access_token" in result
    assert result["token_type"] == "bearer"


def test_auth_invalid(mocker):
    """An unknown user is rejected with 401."""
    mock_get_db = mocker.patch("src.api.auth.get_db")
    mock_get_db.return_value = (mocker.Mock(), mocker.Mock())
    mock_get_db.return_value[1].fetchone.return_value = None

    with pytest.raises(HTTPException) as exc_info:
        auth("testuser", "wrongpassword")
    assert exc_info.value.status_code == 401
    assert exc_info.value.detail == "Invalid username or password"
"password", MASTER_KEY) 94 | assert result == {"detail": "User created successfully"} 95 | 96 | 97 | def test_create_user_invalid_master_key(mocker): 98 | mock_get_db = mocker.patch("api.auth.get_db") 99 | mock_get_db.return_value = (mocker.Mock(), mocker.Mock()) 100 | 101 | with pytest.raises(HTTPException) as exc_info: 102 | create_user("newuser", "password", "invalid_master_key") 103 | assert exc_info.value.status_code == 403 104 | assert exc_info.value.detail == "Not authorized" 105 | 106 | 107 | def test_create_user_duplicate_username(mocker): 108 | mock_get_db = mocker.patch("api.auth.get_db") 109 | mock_get_db.return_value = (mocker.Mock(), mocker.Mock()) 110 | mock_get_db.return_value[1].execute.side_effect = sqlite3.IntegrityError 111 | 112 | with pytest.raises(HTTPException) as exc_info: 113 | create_user("existinguser", "password", MASTER_KEY) 114 | assert exc_info.value.status_code == 400 115 | assert exc_info.value.detail == "Username already exists" 116 | -------------------------------------------------------------------------------- /src/tests/test_database.py: -------------------------------------------------------------------------------- 1 | import sqlite3 2 | from src.api.database import get_db 3 | 4 | 5 | def test_get_db(): 6 | conn, cursor = get_db() 7 | assert isinstance(conn, sqlite3.Connection) 8 | assert isinstance(cursor, sqlite3.Cursor) 9 | 10 | cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='users'") 11 | assert cursor.fetchone()[0] == "users" 12 | 13 | conn.close() 14 | -------------------------------------------------------------------------------- /src/tests/test_file_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | from unittest.mock import MagicMock, patch 4 | from src.utils.file_utils import ( 5 | create_directories, 6 | save_uploaded_file, 7 | convert_to_mp3, 8 | read_output_files, 9 | zip_files, 10 | ) 11 | 12 | 13 | def 
test_create_directories(): 14 | # Test that the directories are created if they don't exist 15 | with tempfile.TemporaryDirectory() as temp_dir: 16 | os.chdir(temp_dir) 17 | create_directories() 18 | assert os.path.exists("./temp") 19 | assert os.path.exists("./data") 20 | 21 | 22 | def test_save_uploaded_file(): 23 | # Test that the uploaded file is saved correctly 24 | with tempfile.TemporaryDirectory() as temp_dir: 25 | os.chdir(temp_dir) 26 | os.makedirs("./temp") 27 | file_mock = MagicMock() 28 | file_mock.filename = "test.mp4" 29 | file_mock.file.read.return_value = b"file content" 30 | temp_video_path = save_uploaded_file(file_mock) 31 | assert os.path.exists(temp_video_path) 32 | with open(temp_video_path, "rb") as f: 33 | assert f.read() == b"file content" 34 | 35 | 36 | def test_convert_to_mp3(mocker): 37 | # Test that the video file is converted to MP3 correctly 38 | with tempfile.TemporaryDirectory() as temp_dir: 39 | os.chdir(temp_dir) 40 | # Create a sample video file path 41 | video_file_path = "test.mp4" 42 | 43 | # Create a sample video file 44 | with open(video_file_path, "wb") as f: 45 | f.write(b"sample video content") 46 | 47 | # Mock the subprocess.run function 48 | mock_run = mocker.patch("subprocess.run") 49 | mock_run.return_value.returncode = 0 50 | 51 | temp_mp3_path = convert_to_mp3(video_file_path) 52 | 53 | assert temp_mp3_path.endswith(".mp3") 54 | 55 | # Check if ffmpeg command was called with the correct arguments 56 | mock_run.assert_called_once_with( 57 | ["ffmpeg", "-y", "-i", video_file_path, temp_mp3_path], check=True 58 | ) 59 | 60 | 61 | def test_read_output_files(): 62 | # Test that the output files are read correctly 63 | with tempfile.TemporaryDirectory() as temp_dir: 64 | os.chdir(temp_dir) 65 | os.makedirs("./data") 66 | # Create sample output files 67 | with open("./data/test.vtt", "w") as f: 68 | f.write("vtt content") 69 | with open("./data/test.txt", "w") as f: 70 | f.write("txt content") 71 | with 
open("./data/test.json", "w") as f: 72 | f.write("json content") 73 | with open("./data/test.srt", "w") as f: 74 | f.write("srt content") 75 | output_files = read_output_files("test") 76 | assert output_files["vtt_content"] == "vtt content" 77 | assert output_files["txt_content"] == "txt content" 78 | assert output_files["json_content"] == "json content" 79 | assert output_files["srt_content"] == "srt content" 80 | assert output_files["vtt_path"] == "test.vtt" 81 | assert output_files["txt_path"] == "test.txt" 82 | assert output_files["json_path"] == "test.json" 83 | assert output_files["srt_path"] == "test.srt" 84 | 85 | 86 | def test_zip_files(): 87 | # Test that the files are zipped correctly 88 | with tempfile.TemporaryDirectory() as temp_dir: 89 | os.chdir(temp_dir) 90 | os.makedirs("./data") 91 | # Create sample files 92 | with open("./data/test.vtt", "w") as f: 93 | f.write("vtt content") 94 | with open("./data/test.txt", "w") as f: 95 | f.write("txt content") 96 | memory_file = zip_files("test.vtt", "test.txt") 97 | assert memory_file.getvalue().startswith( 98 | b"PK" 99 | ) # Check if the file starts with a ZIP file signature 100 | -------------------------------------------------------------------------------- /src/tests/test_main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pytest 3 | import sqlite3 4 | import logging 5 | from fastapi.testclient import TestClient 6 | from src.api.config import MASTER_KEY 7 | from src.api.main import app, auth 8 | import tempfile 9 | 10 | client = TestClient(app) 11 | 12 | 13 | @pytest.fixture(scope="module") 14 | def db(): 15 | with tempfile.TemporaryDirectory() as temp_dir: 16 | db_path = os.path.join(temp_dir, "users.db") 17 | conn = sqlite3.connect(db_path) 18 | conn.execute("CREATE TABLE IF NOT EXISTS users (username TEXT, password TEXT)") 19 | yield conn 20 | conn.close() 21 | 22 | 23 | @pytest.fixture(scope="module") 24 | def auth_token(db): 25 | 
cursor = db.cursor() 26 | cursor.execute( 27 | "INSERT INTO users (username, password) VALUES (?, ?)", ("testuser", "testpass") 28 | ) 29 | db.commit() 30 | 31 | response = client.post( 32 | "/auth", params={"username": "testuser", "password": "testpass"} 33 | ) 34 | return response.json()["access_token"] 35 | 36 | 37 | def test_read_root(): 38 | response = client.get("/") 39 | assert response.status_code == 200 40 | assert response.json() == {"info": "WhisperX API"} 41 | 42 | 43 | def test_create_user_endpoint(db): 44 | response = client.post( 45 | "/create_user", 46 | params={ 47 | "master_key": MASTER_KEY, 48 | "username": "testuser", 49 | "password": "testpass", 50 | }, 51 | ) 52 | assert response.status_code == 200 53 | 54 | 55 | def test_auth_endpoint(db): 56 | response = auth("testuser", "testpass") 57 | assert response["access_token"] 58 | 59 | 60 | def test_create_transcription_job(auth_token): 61 | files = {"file": open("/home/namastex/whisperx-api/data/test.mp4", "rb")} 62 | response = client.post( 63 | "/jobs", files=files, headers={"Authorization": f"Bearer {auth_token}"} 64 | ) 65 | assert response.status_code == 200 66 | assert "task_id" in response.json() 67 | assert response.json()["status"] == "PENDING" 68 | 69 | 70 | def test_list_jobs(auth_token): 71 | response = client.get("/jobs", headers={"Authorization": f"Bearer {auth_token}"}) 72 | assert response.status_code == 200 73 | 74 | 75 | def test_get_job_status(auth_token): 76 | files = {"file": open("/home/namastex/whisperx-api/data/test.mp4", "rb")} 77 | create_response = client.post( 78 | "/jobs", files=files, headers={"Authorization": f"Bearer {auth_token}"} 79 | ) 80 | task_id = create_response.json()["task_id"] 81 | 82 | response = client.get( 83 | f"/jobs/{task_id}", headers={"Authorization": f"Bearer {auth_token}"} 84 | ) 85 | assert response.status_code == 200 86 | 87 | 88 | def test_stop_job(auth_token): 89 | files = {"file": open("/home/namastex/whisperx-api/data/test.mp4", "rb")} 90 | 
create_response = client.post( 91 | "/jobs", files=files, headers={"Authorization": f"Bearer {auth_token}"} 92 | ) 93 | task_id = create_response.json()["task_id"] 94 | 95 | response = client.post( 96 | f"/jobs/{task_id}/stop", headers={"Authorization": f"Bearer {auth_token}"} 97 | ) 98 | assert response.status_code == 200 99 | assert response.json()["status"] == "STOPPED" 100 | -------------------------------------------------------------------------------- /src/tests/test_tasks.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | from unittest.mock import patch 4 | from api.tasks import transcribe_file 5 | 6 | def test_transcribe_file_success(): 7 | with tempfile.NamedTemporaryFile(suffix=".mp4") as temp_video_file: 8 | # Create a temporary video file for testing 9 | temp_video_file.write(b"dummy video content") 10 | temp_video_file.flush() 11 | 12 | with patch("api.tasks.convert_to_mp3") as mock_convert_to_mp3, \ 13 | patch("api.tasks.run_whisperx") as mock_run_whisperx, \ 14 | patch("api.tasks.read_output_files") as mock_read_output_files: 15 | 16 | # Mock the behavior of the utility functions 17 | mock_convert_to_mp3.return_value = "temp_audio.mp3" 18 | mock_read_output_files.return_value = { 19 | "vtt_content": "dummy vtt content", 20 | "txt_content": "dummy txt content", 21 | "json_content": "dummy json content", 22 | "srt_content": "dummy srt content", 23 | "vtt_path": "output.vtt", 24 | "txt_path": "output.txt", 25 | "json_path": "output.json", 26 | "srt_path": "output.srt" 27 | } 28 | 29 | result = transcribe_file(temp_video_file.name, "en", "base", 1, 2) 30 | 31 | assert result["status"] == "success" 32 | assert result["vtt_content"] == "dummy vtt content" 33 | assert result["txt_content"] == "dummy txt content" 34 | assert result["json_content"] == "dummy json content" 35 | assert result["srt_content"] == "dummy srt content" 36 | assert result["vtt_path"] == "output.vtt" 37 | assert 
result["txt_path"] == "output.txt" 38 | assert result["json_path"] == "output.json" 39 | assert result["srt_path"] == "output.srt" 40 | 41 | mock_convert_to_mp3.assert_called_once_with(temp_video_file.name) 42 | mock_run_whisperx.assert_called_once_with("temp_audio.mp3", "en", "base", 1, 2) 43 | mock_read_output_files.assert_called_once() 44 | 45 | def test_transcribe_file_exception(): 46 | with tempfile.NamedTemporaryFile(suffix=".mp4") as temp_video_file: 47 | temp_video_file.write(b"dummy video content") 48 | temp_video_file.flush() 49 | 50 | with patch("api.tasks.convert_to_mp3") as mock_convert_to_mp3: 51 | mock_convert_to_mp3.side_effect = Exception("Test exception") 52 | 53 | try: 54 | transcribe_file(temp_video_file.name, "en", "base", 1, 2) 55 | assert False, "Expected an exception to be raised" 56 | except Exception as e: 57 | assert str(e) == "Test exception" -------------------------------------------------------------------------------- /src/tests/test_transcription_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | from unittest.mock import patch 4 | from src.api.config import HF_TOKEN 5 | from src.utils.transcription_utils import run_whisperx 6 | 7 | 8 | def test_run_whisperx(mocker): 9 | # Test that the whisperx command is executed correctly 10 | with tempfile.TemporaryDirectory() as temp_dir: 11 | os.chdir(temp_dir) 12 | os.makedirs("./data") 13 | # Create a sample MP3 file 14 | with open("test.mp3", "wb") as f: 15 | f.write(b"audio content") 16 | 17 | # Mock the subprocess.run function 18 | mock_run = mocker.patch("subprocess.run") 19 | 20 | run_whisperx("test.mp3", "en", "base", 1, 2) 21 | 22 | expected_cmd = f"whisperx test.mp3 --model base --language en --hf_token {HF_TOKEN} --output_format all --output_dir ./data/ --align_model WAV2VEC2_ASR_LARGE_LV60K_960H --diarize --min_speakers 1 --max_speakers 2" 23 | mock_run.assert_called_once_with(expected_cmd.split(), 
check=True) 24 | -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/namastexlabs/whisperx-api/4beba3b0afd0649d810adc74a844e1e362485982/src/utils/__init__.py -------------------------------------------------------------------------------- /src/utils/file_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | from zipfile import ZipFile 4 | from io import BytesIO 5 | 6 | 7 | def create_directories(): 8 | if not os.path.exists('./temp'): 9 | os.makedirs('./temp') 10 | if not os.path.exists('./data'): 11 | os.makedirs('./data') 12 | 13 | def save_uploaded_file(file): 14 | temp_video_path = f"./temp/{file.filename}" 15 | with open(temp_video_path, "wb") as buffer: 16 | buffer.write(file.file.read()) 17 | return temp_video_path 18 | 19 | def convert_to_mp3(file_path): 20 | temp_mp3_path = os.path.splitext(file_path)[0] + ".mp3" 21 | subprocess.run(["ffmpeg", "-y", "-i", file_path, temp_mp3_path], check=True) 22 | return temp_mp3_path 23 | 24 | 25 | def read_output_files(base_name): 26 | output_dir = "./data/" 27 | vtt_path = f"{base_name}.vtt" 28 | txt_path = f"{base_name}.txt" 29 | json_path = f"{base_name}.json" 30 | srt_path = f"{base_name}.srt" 31 | 32 | with open(os.path.join(output_dir, vtt_path), "r") as vtt_file: 33 | vtt_content = vtt_file.read() 34 | 35 | with open(os.path.join(output_dir, txt_path), "r") as txt_file: 36 | txt_content = txt_file.read() 37 | 38 | with open(os.path.join(output_dir, json_path), "r") as json_file: 39 | json_content = json_file.read() 40 | 41 | with open(os.path.join(output_dir, srt_path), "r") as srt_file: 42 | srt_content = srt_file.read() 43 | 44 | return { 45 | "vtt_content": vtt_content, 46 | "txt_content": txt_content, 47 | "json_content": json_content, 48 | "srt_content": srt_content, 
49 | "vtt_path": vtt_path, 50 | "txt_path": txt_path, 51 | "json_path": json_path, 52 | "srt_path": srt_path 53 | } 54 | 55 | def zip_files(vtt_path, txt_path): 56 | memory_file = BytesIO() 57 | with ZipFile(memory_file, 'w') as zf: 58 | zf.write(os.path.join("./data/", vtt_path), vtt_path) 59 | zf.write(os.path.join("./data/", txt_path), txt_path) 60 | memory_file.seek(0) 61 | return memory_file 62 | 63 | -------------------------------------------------------------------------------- /src/utils/transcription_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | 4 | from src.api.config import HF_TOKEN 5 | 6 | 7 | def run_whisperx(temp_mp3_path, lang, model, min_speakers, max_speakers): 8 | output_dir = "./data/" 9 | cmd = f"whisperx {temp_mp3_path} --model {model} --language {lang} --hf_token {HF_TOKEN} --output_format all --output_dir {output_dir} --align_model WAV2VEC2_ASR_LARGE_LV60K_960H --diarize --min_speakers {min_speakers} --max_speakers {max_speakers}" 10 | subprocess.run(cmd.split(), check=True) 11 | -------------------------------------------------------------------------------- /start.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | from multiprocessing import Process 4 | from src.api.config import API_HOST, API_PORT 5 | import uvicorn 6 | from src.api.main import app 7 | 8 | 9 | def start_celery_worker(): 10 | subprocess.run( 11 | ["celery", "-A", "src.api.tasks.celery_app", "worker", "--loglevel=info"] 12 | ) 13 | 14 | 15 | if __name__ == "__main__": 16 | celery_process = Process(target=start_celery_worker) 17 | celery_process.start() 18 | 19 | uvicorn.run(app, host=API_HOST, port=int(API_PORT)) 20 | --------------------------------------------------------------------------------