├── .dockerignore
├── .gitignore
├── README.md
├── cog.yaml
└── predict.py

/.dockerignore:
--------------------------------------------------------------------------------
**.mp3
**.wav

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
hf_models/**
torch_models/**
**.mp3
**.wav
.cog/**
**/__pycache__/**

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
Cog implementation of whisperX, a library that adds batched inference and word-level alignment on top of Whisper (via faster-whisper), enabling very fast audio transcription.

--------------------------------------------------------------------------------
/cog.yaml:
--------------------------------------------------------------------------------
# Configuration for Cog ⚙️
# Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md

build:
  # set to true if your model requires a GPU
  gpu: true
  cuda: "11.8"

  system_packages:
    - "ffmpeg"

  # python version in the form '3.8' or '3.8.12'
  python_version: "3.11"

  # a list of packages in the format <package-name>==<version>
  python_packages:
    - "torch==2.0"
    - "torchaudio==2.0.0"
    - "git+https://github.com/m-bain/whisperX.git@befe2b242eb59dcd7a8a122d127614d5c63d36e9"

  run:
    - "pip install ipython"

predict: 'predict.py:Predictor'

--------------------------------------------------------------------------------
/predict.py:
--------------------------------------------------------------------------------
# Prediction interface for Cog ⚙️
# https://github.com/replicate/cog/blob/main/docs/python.md
import os

# Cache model downloads under /src (see .gitignore) instead of the default home
# directory; set these before importing the libraries that read them.
os.environ['HF_HOME'] = '/src/hf_models'
os.environ['TORCH_HOME'] = '/src/torch_models'

from cog import BasePredictor, Input, Path
import torch
import whisperx
import json


compute_type = "float16"


class Predictor(BasePredictor):
    def setup(self):
        """Load the model into memory to make running multiple predictions efficient"""
        self.device = "cuda"
        self.model = whisperx.load_model("large-v2", self.device, language="en", compute_type=compute_type)
        self.alignment_model, self.metadata = whisperx.load_align_model(language_code="en", device=self.device)

    def predict(
        self,
        audio: Path = Input(description="Audio file"),
        batch_size: int = Input(description="Parallelization of input audio transcription", default=32),
        align_output: bool = Input(description="Use if you need word-level timing and not just batched transcription", default=False),
        only_text: bool = Input(description="Set if you only want to return text; otherwise, segment metadata will be returned as well.", default=False),
        debug: bool = Input(description="Print out memory usage information.", default=False),
    ) -> str:
        """Run a single prediction on the model"""
        with torch.inference_mode():
            result = self.model.transcribe(str(audio), batch_size=batch_size)
            # result is a dict with keys ['segments', 'language'];
            # 'segments' is a list of dicts, each with {'text': ..., 'start': ..., 'end': ...}
            if align_output:
                # NOTE - "only_text" defeats the purpose of alignment (the word-level
                # timings are discarded), but the combination is honoured anyway
                result = whisperx.align(result['segments'], self.alignment_model, self.metadata, str(audio), self.device, return_char_alignments=False)
                # result is now a dict with keys ['segments', 'word_segments'];
                # 'word_segments' is a time-sorted list of dicts, each with {'word': ..., 'start': ..., 'end': ..., 'score': probability}
                # 'segments' matches the transcription segments, with an added ['words'] key carrying the timing information above
                # return_char_alignments=True would also add character-level alignments, which is more detail than needed here
            if only_text:
                return ''.join([val['text'] for val in result['segments']])
            if debug:
                print(f"max gpu memory reserved over runtime: {torch.cuda.max_memory_reserved() / (1024 ** 3):.2f} GB")
            return json.dumps(result['segments'])

--------------------------------------------------------------------------------
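
Unless only_text is set, the string returned by predict() is a JSON-encoded list of segment dicts. Below is a minimal sketch of how a caller might turn that output into timestamped transcript lines; the function name and formatting are illustrative, and the 'text'/'start'/'end' field names follow the segment structure documented in predict.py above.

import json

def to_transcript(segments_json: str) -> str:
    """Format the predictor's JSON output as '[start - end] text' lines."""
    segments = json.loads(segments_json)
    lines = []
    for seg in segments:
        # each segment dict carries 'text' plus 'start'/'end' offsets in seconds
        lines.append(f"[{seg['start']:.2f} - {seg['end']:.2f}] {seg['text'].strip()}")
    return "\n".join(lines)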