├── README.md ├── benchmarking ├── k6_benchmarking.js └── ted_60.wav └── model-server ├── Dockerfile ├── app ├── __init__.py ├── main.py ├── tests │ ├── polyai-minds14-0.wav │ └── test_main.py └── utils │ ├── __init__.py │ ├── config.py │ ├── diarization_utils.py │ ├── log_config.yaml │ └── validation_utils.py └── requirements.txt /README.md: -------------------------------------------------------------------------------- 1 | A blog post with the details of the inner workings: https://huggingface.co/blog/asr-diarization 2 | 3 | Use with a prebuilt image: 4 | ``` 5 | docker run --gpus all -p 7860:7860 --env-file .env ghcr.io/plaggy/asrdiarization-server:latest 6 | ``` 7 | and parametrize via `.env`: 8 | ``` 9 | ASR_MODEL= 10 | DIARIZATION_MODEL= 11 | ASSISTANT_MODEL= 12 | HF_TOKEN= 13 | ``` 14 | Or build your own image from the `model-server` directory. 15 | 16 | Once deployed, send your audio with inference parameters like this: 17 | ```python 18 | import requests 19 | import json 20 | import aiohttp 21 | 22 | # synchronous call 23 | def sync_post(): 24 | files = {"file": open("", "rb")} 25 | data = {"parameters": json.dumps({"batch_size": 1, "assisted": "true"})} 26 | resp = requests.post("", files=files, data=data) 27 | print(resp.json()) 28 | 29 | # asynchronous call 30 | async def async_post(): 31 | data = { 32 | "file": open("", "rb"), 33 | "parameters": json.dumps({"batch_size": 30}) 34 | } 35 | async with aiohttp.ClientSession() as session: 36 | response = await session.post("", data=data) 37 | print(await response.json()) 38 | ``` 39 | -------------------------------------------------------------------------------- /benchmarking/k6_benchmarking.js: -------------------------------------------------------------------------------- 1 | import http from 'k6/http'; 2 | import { Trend } from 'k6/metrics'; 3 | 4 | const shortAssistedTime = new Trend('short_assisted', true); 5 | const shortNotAssistedTime = new Trend('short_not_assisted', true); 6 | const longAssistedTime = new Trend('long_assisted', true); 7 | const longNotAssistedTime = new Trend('long_not_assisted', true); 8 | 9 | const url = 'http://0.0.0.0:7860'; 10 | 11 | const audios = { 12 | 'short': open('model-server/app/tests/polyai-minds14-0.wav', 'rb'), // 8s 13 | 'long': open('ted_60.wav', 'rb') // 60s 14 | }; 15 | 16 | export const options = { 17 | scenarios: { 18 | short_audio_not_assisted: { 19 | executor: 'constant-vus', 20 | vus: 1, 21 | startTime: '0s', 22 | duration: '30s', 23 | env: { 24 | AUDIO: 'short', 25 | BATCH_SIZE: '24', 26 | ASSISTED: 'false' 27 | }, 28 | }, 29 | short_audio_assisted: { 30 | executor: 'constant-vus', 31 | vus: 1, 32 | startTime: '30s', 33 | duration: '30s', 34 | env: { 35 | AUDIO: 'short', 36 | BATCH_SIZE: '1', 37 | ASSISTED: 'true' 38 | }, 39 | }, 40 | long_audio_not_assisted: { 41 | executor: 'constant-vus', 42 | vus: 1, 43 | startTime: '1m', 44 | duration: '1m', 45 | env: { 46 | AUDIO: 'long', 47 | BATCH_SIZE: '24', 48 | ASSISTED: 'false' 49 | }, 50 | }, 51 | long_audio_assisted: { 52 | executor: 'constant-vus', 53 | vus: 1, 54 | startTime: '2m', 55 | duration: '1m', 56 | env: { 57 | AUDIO: 'long', 58 | BATCH_SIZE: '1', 59 | ASSISTED: 'true' 60 | }, 61 | }, 62 | }, 63 | }; 64 | 65 | 66 | export default function() { 67 | let parameters = JSON.stringify({'batch_size': __ENV.BATCH_SIZE, 'assisted': __ENV.ASSISTED}); 68 | const data = { 69 | parameters: parameters, 70 | file: http.file(audios[__ENV.AUDIO], 'filename', 'audio/wav'), 71 | }; 72 | 73 | const resp = http.post(url, data); 74 | if (__ENV.AUDIO == 'short' && __ENV.ASSISTED ==
'false'){ 75 | shortNotAssistedTime.add(resp.timings.duration); 76 | } else if (__ENV.AUDIO == 'short' && __ENV.ASSISTED == 'true'){ 77 | shortAssistedTime.add(resp.timings.duration); 78 | } else if (__ENV.AUDIO == 'long' && __ENV.ASSISTED == 'false'){ 79 | longNotAssistedTime.add(resp.timings.duration); 80 | } else if (__ENV.AUDIO == 'long' && __ENV.ASSISTED == 'true'){ 81 | longAssistedTime.add(resp.timings.duration); 82 | } 83 | console.log(resp.status); 84 | }; -------------------------------------------------------------------------------- /benchmarking/ted_60.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/plaggy/fast-whisper-server/b26cc73d967df9f76012f9eb3efb0ddd9f93d51c/benchmarking/ted_60.wav -------------------------------------------------------------------------------- /model-server/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:12.3.2-devel-ubuntu22.04 2 | 3 | RUN apt-get update 4 | RUN DEBIAN_FRONTEND=noninteractive apt-get install -y \ 5 | build-essential \ 6 | libssl-dev \ 7 | libffi-dev \ 8 | libncurses5-dev \ 9 | zlib1g-dev \ 10 | libreadline-dev \ 11 | libbz2-dev \ 12 | libsqlite3-dev \ 13 | wget \ 14 | ffmpeg \ 15 | git \ 16 | && apt-get clean autoremove --yes \ 17 | && rm -rf /var/lib/{apt,dpkg,cache,log} 18 | 19 | ENV HOME="/root" 20 | WORKDIR ${HOME} 21 | 22 | RUN git clone --depth=1 https://github.com/pyenv/pyenv.git .pyenv 23 | 24 | ENV PYENV_ROOT="${HOME}/.pyenv" 25 | ENV PATH="${PYENV_ROOT}/shims:${PYENV_ROOT}/bin:${PATH}" 26 | 27 | RUN eval "$(pyenv init -)" 28 | 29 | RUN pyenv install 3.10 30 | RUN pyenv global 3.10 31 | 32 | RUN pip install --upgrade pip 33 | COPY requirements.txt . 
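# installing dependencies before copying the app code (below) lets Docker cache this layer across code-only rebuilds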
34 | RUN pip install -r requirements.txt 35 | 36 | WORKDIR /usr/src 37 | 38 | COPY app app 39 | 40 | EXPOSE 7860 41 | 42 | ENTRYPOINT ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "7860", "--log-config", "app/utils/log_config.yaml"] -------------------------------------------------------------------------------- /model-server/app/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/plaggy/fast-whisper-server/b26cc73d967df9f76012f9eb3efb0ddd9f93d51c/model-server/app/__init__.py -------------------------------------------------------------------------------- /model-server/app/main.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import torch 3 | 4 | from typing import Annotated 5 | from pydantic import Json 6 | from fastapi import FastAPI, UploadFile, File, Form, HTTPException 7 | from fastapi.responses import PlainTextResponse 8 | from contextlib import asynccontextmanager 9 | from pyannote.audio import Pipeline 10 | from transformers import pipeline, AutoModelForCausalLM 11 | from huggingface_hub import HfApi 12 | 13 | from app.utils.validation_utils import validate_file, process_params 14 | from app.utils.diarization_utils import diarize 15 | from app.utils.config import model_settings 16 | 17 | logger = logging.getLogger(__name__) 18 | models = {} 19 | 20 | 21 | @asynccontextmanager 22 | async def lifespan(app: FastAPI): 23 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 24 | logger.info(f"Using device: {device.type}") 25 | 26 | torch_dtype = torch.float32 if device.type == "cpu" else torch.float16 27 | 28 | # from pytorch 2.2 sdpa implements flash attention 2 29 | models["asr_pipeline"] = pipeline( 30 | "automatic-speech-recognition", 31 | model=model_settings.asr_model, 32 | torch_dtype=torch_dtype, 33 | device=device 34 | ) 35 | 36 | models["assistant_model"] = AutoModelForCausalLM.from_pretrained( 37 | model_settings.assistant_model, 38 | torch_dtype=torch_dtype, 39 | low_cpu_mem_usage=True, 40 | use_safetensors=True 41 | ) if model_settings.assistant_model else None 42 | 43 | if models["assistant_model"]: 44 | models["assistant_model"].to(device) 45 | 46 | if model_settings.diarization_model: 47 | # diarization pipeline doesn't raise if there is no token 48 | HfApi().whoami(model_settings.hf_token) 49 | models["diarization_pipeline"] = Pipeline.from_pretrained( 50 | checkpoint_path=model_settings.diarization_model, 51 | use_auth_token=model_settings.hf_token, 52 | ) 53 | models["diarization_pipeline"].to(device) 54 | else: 55 | models["diarization_pipeline"] = None 56 | 57 | yield 58 | models.clear() 59 | 60 | 61 | app = FastAPI(lifespan=lifespan) 62 | 63 | 64 | @app.get("/", response_class=PlainTextResponse) 65 | @app.get("/health", response_class=PlainTextResponse) 66 | async def health(): 67 | return "OK" 68 | 69 | 70 | @app.post("/") 71 | @app.post("/predict") 72 | async def predict( 73 | file: Annotated[UploadFile, File()], 74 | parameters: Annotated[Json , Form()] = {} 75 | ): 76 | parameters = process_params(parameters) 77 | file = await validate_file(file) 78 | 79 | logger.info(f"inference parameters: {parameters}") 80 | 81 | generate_kwargs = { 82 | "task": parameters.task, 83 | "language": parameters.language, 84 | "assistant_model": models["assistant_model"] if parameters.assisted else None 85 | } 86 | 87 | try: 88 | logger.info("starting ASR pipeline") 89 | asr_outputs = 
models["asr_pipeline"]( 90 | file, 91 | chunk_length_s=parameters.chunk_length_s, 92 | batch_size=parameters.batch_size, 93 | generate_kwargs=generate_kwargs, 94 | return_timestamps=True, 95 | ) 96 | except RuntimeError as e: 97 | logger.error(f"ASR inference error: {str(e)}") 98 | raise HTTPException(status_code=400, detail=f"ASR inference error: {str(e)}") 99 | except Exception as e: 100 | logger.error(f"Unknown error during ASR inference: {str(e)}") 101 | raise HTTPException(status_code=500, detail=f"Unknown error during ASR inference: {str(e)}") 102 | 103 | if models["diarization_pipeline"]: 104 | try: 105 | transcript = diarize(models["diarization_pipeline"], file, parameters, asr_outputs) 106 | except RuntimeError as e: 107 | logger.error(f"Diarization inference error: {str(e)}") 108 | raise HTTPException(status_code=400, detail=f"Diarization inference error: {str(e)}") 109 | except Exception as e: 110 | logger.error(f"Unknown error during diarization: {str(e)}") 111 | raise HTTPException(status_code=500, detail=f"Unknown error during diarization: {str(e)}") 112 | else: 113 | transcript = [] 114 | 115 | return { 116 | "speakers": transcript, 117 | "chunks": asr_outputs["chunks"], 118 | "text": asr_outputs["text"], 119 | } 120 | -------------------------------------------------------------------------------- /model-server/app/tests/polyai-minds14-0.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/plaggy/fast-whisper-server/b26cc73d967df9f76012f9eb3efb0ddd9f93d51c/model-server/app/tests/polyai-minds14-0.wav -------------------------------------------------------------------------------- /model-server/app/tests/test_main.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import logging 4 | import pytest 5 | from fastapi.testclient import TestClient 6 | 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | @pytest.fixture 12 | def mock_app(monkeypatch): 13 | # setting up environment before import 14 | # make sure to set env HF_TOKEN 15 | monkeypatch.setenv("ASR_MODEL", "openai/whisper-small") 16 | monkeypatch.setenv("ASSISTANT_MODEL", "distil-whisper/distil-small.en") 17 | monkeypatch.setenv("DIARIZATION_MODEL", "pyannote/speaker-diarization-3.1") 18 | monkeypatch.setenv("HF_TOKEN", os.getenv("HF_TOKEN")) 19 | 20 | from app.main import app 21 | return app 22 | 23 | 24 | def test_predict(mock_app): 25 | with TestClient(mock_app) as test_client: 26 | files = {"file": open("app/tests/polyai-minds14-0.wav", "rb")} 27 | data = {"parameters": json.dumps({"batch_size": 12, "sampling_rate": 24000, "non-existent": "dummy"})} 28 | resp = test_client.post("/predict", data=data, files=files) 29 | resp_json = resp.json() 30 | assert resp.status_code == 200 31 | assert resp_json["speakers"] and resp_json["text"] 32 | 33 | 34 | def test_predict_no_params(mock_app): 35 | with TestClient(mock_app) as test_client: 36 | files = {"file": open("app/tests/polyai-minds14-0.wav", "rb")} 37 | resp = test_client.post("/predict", files=files) 38 | resp_json = resp.json() 39 | assert resp.status_code == 200 40 | assert resp_json["speakers"] and resp_json["text"] 41 | 42 | 43 | def test_predict_assisted(mock_app): 44 | with TestClient(mock_app) as test_client: 45 | files = {"file": open("app/tests/polyai-minds14-0.wav", "rb")} 46 | data = {"parameters": json.dumps({"batch_size": 1, "assisted": True})} 47 | resp = test_client.post("/predict", data=data, files=files)
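        # assisted generation is only accepted with batch_size == 1; process_params in app/utils/validation_utils.py rejects larger batches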
48 | resp_json = resp.json() 49 | assert resp.status_code == 200 50 | assert resp_json["speakers"] and resp_json["text"] 51 | 52 | 53 | def test_predict_all_params(mock_app): 54 | with TestClient(mock_app) as test_client: 55 | files = {"file": open("app/tests/polyai-minds14-0.wav", "rb")} 56 | data = { 57 | "parameters": 58 | json.dumps({ 59 | "batch_size": 1, 60 | "assisted": True, 61 | "chunk_length_s": 24, 62 | "sampling_rate": 24000, 63 | "language": "en", 64 | "min_speakers": 1, 65 | "max_speakers": 2 66 | }) 67 | } 68 | resp = test_client.post("/predict", data=data, files=files) 69 | resp_json = resp.json() 70 | assert resp.status_code == 200 71 | assert resp_json["speakers"] and resp_json["text"] -------------------------------------------------------------------------------- /model-server/app/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/plaggy/fast-whisper-server/b26cc73d967df9f76012f9eb3efb0ddd9f93d51c/model-server/app/utils/__init__.py -------------------------------------------------------------------------------- /model-server/app/utils/config.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from pydantic import BaseModel 4 | from pydantic_settings import BaseSettings 5 | 6 | from typing import Optional, Literal 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | audio_types = { 11 | "audio/x-flac", 12 | "audio/flac", 13 | "audio/mpeg", 14 | "audio/x-mpeg-3", 15 | "audio/wave", 16 | "audio/wav", 17 | "audio/x-wav", 18 | "audio/ogg", 19 | "audio/x-audio", 20 | "audio/webm", 21 | "audio/webm;codecs=opus", 22 | "audio/AMR", 23 | "audio/amr", 24 | "audio/AMR-WB", 25 | "audio/AMR-WB+", 26 | "audio/m4a", 27 | "audio/x-m4a" 28 | } 29 | 30 | 31 | class ModelSettings(BaseSettings): 32 | asr_model: str 33 | assistant_model: Optional[str] = None 34 | diarization_model: Optional[str] = None 35 | hf_token: Optional[str] = None 36 | 37 | 38 | class InferenceConfig(BaseModel): 39 | task: Literal["transcribe", "translate"] = "transcribe" 40 | batch_size: int = 24 41 | chunk_length_s: int = 30 42 | sampling_rate: int = 16000 43 | assisted: bool = False 44 | language: Optional[str] = None 45 | num_speakers: Optional[int] = None 46 | min_speakers: Optional[int] = None 47 | max_speakers: Optional[int] = None 48 | 49 | 50 | model_settings = ModelSettings() 51 | 52 | logger.info(f"asr model: {model_settings.asr_model}") 53 | logger.info(f"assist model: {model_settings.assistant_model}") 54 | logger.info(f"diar model: {model_settings.diarization_model}") -------------------------------------------------------------------------------- /model-server/app/utils/diarization_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from torchaudio import functional as F 4 | from transformers.pipelines.audio_utils import ffmpeg_read 5 | from fastapi import HTTPException 6 | import sys 7 | 8 | # Code from insanely-fast-whisper: 9 | # https://github.com/Vaibhavs10/insanely-fast-whisper 10 | 11 | import logging 12 | logger = logging.getLogger(__name__) 13 | 14 | def preprocess_inputs(inputs, sampling_rate): 15 | inputs = ffmpeg_read(inputs, sampling_rate) 16 | 17 | if sampling_rate != 16000: 18 | inputs = F.resample( 19 | torch.from_numpy(inputs), sampling_rate, 16000 20 | ).numpy() 21 | 22 | if len(inputs.shape) != 1: 23 | logger.error(f"Diarization pipeline expects single channel audio,
received {inputs.shape}") 24 | raise HTTPException( 25 | status_code=400, 26 | detail=f"Diarization pipeline expects single channel audio, received {inputs.shape}" 27 | ) 28 | 29 | # diarization model expects float32 torch tensor of shape `(channels, seq_len)` 30 | diarizer_inputs = torch.from_numpy(inputs).float() 31 | diarizer_inputs = diarizer_inputs.unsqueeze(0) 32 | 33 | return inputs, diarizer_inputs 34 | 35 | 36 | def diarize_audio(diarizer_inputs, diarization_pipeline, parameters): 37 | diarization = diarization_pipeline( 38 | {"waveform": diarizer_inputs, "sample_rate": 16000}, # preprocess_inputs has already resampled the waveform to 16 kHz 39 | num_speakers=parameters.num_speakers, 40 | min_speakers=parameters.min_speakers, 41 | max_speakers=parameters.max_speakers, 42 | ) 43 | 44 | segments = [] 45 | for segment, track, label in diarization.itertracks(yield_label=True): 46 | segments.append( 47 | { 48 | "segment": {"start": segment.start, "end": segment.end}, 49 | "track": track, 50 | "label": label, 51 | } 52 | ) 53 | 54 | # diarizer output may contain consecutive segments from the same speaker (e.g. {(0 -> 1, speaker_1), (1 -> 1.5, speaker_1), ...}) 55 | # we combine these segments to give overall timestamps for each speaker's turn (e.g. {(0 -> 1.5, speaker_1), ...}) 56 | new_segments = [] 57 | prev_segment = cur_segment = segments[0] 58 | 59 | for i in range(1, len(segments)): 60 | cur_segment = segments[i] 61 | 62 | # check if we have changed speaker ("label") 63 | if cur_segment["label"] != prev_segment["label"] and i < len(segments): 64 | # add the start/end times for the super-segment to the new list 65 | new_segments.append( 66 | { 67 | "segment": { 68 | "start": prev_segment["segment"]["start"], 69 | "end": cur_segment["segment"]["start"], 70 | }, 71 | "speaker": prev_segment["label"], 72 | } 73 | ) 74 | prev_segment = segments[i] 75 | 76 | # add the last segment(s) if there was no speaker change 77 | new_segments.append( 78 | { 79 | "segment": { 80 | "start": prev_segment["segment"]["start"], 81 | "end": cur_segment["segment"]["end"], 82 | }, 83 | "speaker": prev_segment["label"], 84 | } 85 | ) 86 | 87 | return new_segments 88 | 89 | 90 | def post_process_segments_and_transcripts(new_segments, transcript, group_by_speaker) -> list: 91 | # get the end timestamps for each chunk from the ASR output 92 | end_timestamps = np.array( 93 | [chunk["timestamp"][-1] if chunk["timestamp"][-1] is not None else sys.float_info.max for chunk in transcript]) 94 | segmented_preds = [] 95 | 96 | # align the diarizer timestamps and the ASR timestamps 97 | for segment in new_segments: 98 | # get the diarizer end timestamp 99 | end_time = segment["segment"]["end"] 100 | # find the ASR end timestamp that is closest to the diarizer's end timestamp and cut the transcript to here 101 | upto_idx = np.argmin(np.abs(end_timestamps - end_time)) 102 | 103 | if group_by_speaker: 104 | segmented_preds.append( 105 | { 106 | "speaker": segment["speaker"], 107 | "text": "".join( 108 | [chunk["text"] for chunk in transcript[: upto_idx + 1]] 109 | ), 110 | "timestamp": ( 111 | transcript[0]["timestamp"][0], 112 | transcript[upto_idx]["timestamp"][1], 113 | ), 114 | } 115 | ) 116 | else: 117 | for i in range(upto_idx + 1): 118 | segmented_preds.append({"speaker": segment["speaker"], **transcript[i]}) 119 | 120 | # crop the transcripts and timestamp lists according to the latest timestamp (for faster argmin) 121 | transcript = transcript[upto_idx + 1:] 122 | end_timestamps = end_timestamps[upto_idx + 1:] 123 | 124 | if len(end_timestamps) == 0:
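            # all ASR chunks have been assigned to a speaker segment, nothing left to align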
125 | break 126 | 127 | return segmented_preds 128 | 129 | 130 | def diarize(diarization_pipeline, file, parameters, asr_outputs): 131 | _, diarizer_inputs = preprocess_inputs(file, parameters.sampling_rate) 132 | 133 | segments = diarize_audio( 134 | diarizer_inputs, 135 | diarization_pipeline, 136 | parameters 137 | ) 138 | 139 | return post_process_segments_and_transcripts( 140 | segments, asr_outputs["chunks"], group_by_speaker=False 141 | ) -------------------------------------------------------------------------------- /model-server/app/utils/log_config.yaml: -------------------------------------------------------------------------------- 1 | # Mostly based on uvicorn.config.LOGGING_CONFIG 2 | version: 1 3 | disable_existing_loggers: False 4 | formatters: 5 | default: 6 | "()": uvicorn.logging.DefaultFormatter 7 | format: '%(levelprefix)s %(asctime)s - %(name)s : %(message)s' 8 | use_colors: True 9 | access: 10 | "()": uvicorn.logging.AccessFormatter 11 | format: '%(levelprefix)s %(asctime)s : %(message)s' 12 | use_colors: True 13 | handlers: 14 | default: 15 | formatter: default 16 | class: logging.StreamHandler 17 | stream: ext://sys.stderr 18 | access: 19 | formatter: access 20 | class: logging.StreamHandler 21 | stream: ext://sys.stdout 22 | loggers: 23 | uvicorn.error: 24 | level: INFO 25 | handlers: 26 | - default 27 | propagate: False 28 | uvicorn.access: 29 | level: INFO 30 | handlers: 31 | - access 32 | propagate: False 33 | root: 34 | level: INFO 35 | handlers: 36 | - default 37 | propagate: False 38 | 39 | -------------------------------------------------------------------------------- /model-server/app/utils/validation_utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import mimetypes 3 | 4 | from pydantic import ValidationError 5 | from fastapi import UploadFile, HTTPException 6 | from typing import Any 7 | 8 | from app.utils.config import audio_types, model_settings, InferenceConfig 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | def process_params(parameters: dict[str, Any]) -> InferenceConfig: 14 | default_fields = InferenceConfig.model_fields 15 | unsupported = [k for k in parameters if k not in default_fields] 16 | 17 | try: 18 | parameters = InferenceConfig(**parameters) 19 | except ValidationError as e: 20 | logger.error(f"Error validating parameters: {e}") 21 | raise HTTPException( 22 | status_code=400, 23 | detail=f"Error validating parameters: {e}" 24 | ) 25 | 26 | if parameters.assisted: 27 | if not model_settings.assistant_model: 28 | logger.error("Assisted generation is on but no assistant model was provided") 29 | raise HTTPException( 30 | status_code=400, 31 | detail="Assisted generation is on but no assistant model was provided" 32 | ) 33 | if parameters.batch_size > 1: 34 | logger.error("Batch size must be 1 when assisted generation is on") 35 | raise HTTPException( 36 | status_code=400, 37 | detail="Batch size must be 1 when assisted generation is on" 38 | ) 39 | 40 | if unsupported: 41 | logger.warning(f"Unsupported parameters will be ignored: {unsupported}") 42 | 43 | return parameters 44 | 45 | 46 | async def validate_file(file: UploadFile) -> bytes: 47 | content_type = file.content_type 48 | if not content_type: 49 | content_type = mimetypes.guess_type(file.filename)[0] 50 | logger.warning(f"content type was not provided, guessed as {content_type}") 51 | 52 | if content_type not in audio_types: 53 | logger.error(f"File type {file.content_type} not
supported") 54 | raise HTTPException( 55 | status_code=400, 56 | detail=f"File type {file.content_type} not supported" 57 | ) 58 | 59 | return await file.read() -------------------------------------------------------------------------------- /model-server/requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.27.2 2 | torch==2.2.1 3 | fastapi==0.110.0 4 | pyannote-audio==3.1.1 5 | transformers==4.38.2 6 | numpy==1.26.4 7 | torchaudio==2.2.1 8 | uvicorn==0.27.1 9 | httpx==0.27.0 10 | python-multipart==0.0.9 11 | pydantic-settings==2.2.1 12 | pytest==8.1.1 --------------------------------------------------------------------------------
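A minimal sketch of the full local workflow implied above: build the image from `model-server/`, run it, and point the k6 scenarios at it. The `asr-server` tag is an arbitrary example name; the commands assume Docker with the NVIDIA container runtime, a populated `.env` as described in the README, and a local k6 installation.
```
docker build -t asr-server model-server
docker run --gpus all -p 7860:7860 --env-file .env asr-server
# once the server answers on port 7860 (GET /health), run the load scenarios
k6 run benchmarking/k6_benchmarking.js
```
If k6 cannot locate the audio files referenced in the script, adjust the `open()` paths to match the directory you run it from.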