├── .github ├── FUNDING.yml └── workflows │ ├── build-and-release.yml │ └── python-compatibility.yml ├── .gitignore ├── EXAMPLES.md ├── LICENSE ├── MANIFEST.in ├── README.md ├── figures └── pipeline.png ├── pyproject.toml ├── uv.lock └── whisperx ├── SubtitlesProcessor.py ├── __init__.py ├── __main__.py ├── alignment.py ├── asr.py ├── assets ├── mel_filters.npz └── pytorch_model.bin ├── audio.py ├── conjunctions.py ├── diarize.py ├── transcribe.py ├── types.py ├── utils.py └── vads ├── __init__.py ├── pyannote.py ├── silero.py └── vad.py /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | custom: https://www.buymeacoffee.com/maxhbain 2 | -------------------------------------------------------------------------------- /.github/workflows/build-and-release.yml: -------------------------------------------------------------------------------- 1 | name: Build and release 2 | 3 | on: 4 | release: 5 | types: [published] 6 | 7 | jobs: 8 | build: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - name: Checkout 12 | uses: actions/checkout@v4 13 | 14 | - name: Install uv 15 | uses: astral-sh/setup-uv@v5 16 | with: 17 | version: "0.5.14" 18 | python-version: "3.9" 19 | 20 | - name: Check if lockfile is up to date 21 | run: uv lock --check 22 | 23 | - name: Build package 24 | run: uv build 25 | 26 | - name: Release to Github 27 | uses: softprops/action-gh-release@v2 28 | with: 29 | files: dist/*.whl 30 | 31 | - name: Publish package to PyPi 32 | run: uv publish 33 | env: 34 | UV_PUBLISH_TOKEN: ${{ secrets.PYPI_API_TOKEN }} 35 | -------------------------------------------------------------------------------- /.github/workflows/python-compatibility.yml: -------------------------------------------------------------------------------- 1 | name: Python Compatibility Test 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | pull_request: 7 | branches: [main] 8 | workflow_dispatch: # Allows manual triggering from GitHub UI 9 | 10 | jobs: 11 
| test: 12 | runs-on: ubuntu-latest 13 | strategy: 14 | matrix: 15 | python-version: ["3.9", "3.10", "3.11", "3.12"] 16 | 17 | steps: 18 | - uses: actions/checkout@v4 19 | 20 | - name: Install uv 21 | uses: astral-sh/setup-uv@v5 22 | with: 23 | version: "0.5.14" 24 | python-version: ${{ matrix.python-version }} 25 | 26 | - name: Check if lockfile is up to date 27 | run: uv lock --check 28 | 29 | - name: Install the project 30 | run: uv sync --all-extras 31 | 32 | - name: Test import 33 | run: | 34 | uv run python -c "import whisperx; print('Successfully imported whisperx')" 35 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # UV 98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | #uv.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 
107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | 110 | # pdm 111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 112 | #pdm.lock 113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 114 | # in version control. 115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 116 | .pdm.toml 117 | .pdm-python 118 | .pdm-build/ 119 | 120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 121 | __pypackages__/ 122 | 123 | # Celery stuff 124 | celerybeat-schedule 125 | celerybeat.pid 126 | 127 | # SageMath parsed files 128 | *.sage.py 129 | 130 | # Environments 131 | .env 132 | .venv 133 | env/ 134 | venv/ 135 | ENV/ 136 | env.bak/ 137 | venv.bak/ 138 | 139 | # Spyder project settings 140 | .spyderproject 141 | .spyproject 142 | 143 | # Rope project settings 144 | .ropeproject 145 | 146 | # mkdocs documentation 147 | /site 148 | 149 | # mypy 150 | .mypy_cache/ 151 | .dmypy.json 152 | dmypy.json 153 | 154 | # Pyre type checker 155 | .pyre/ 156 | 157 | # pytype static type analyzer 158 | .pytype/ 159 | 160 | # Cython debug symbols 161 | cython_debug/ 162 | 163 | # PyCharm 164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 166 | # and can be added to the global gitignore or merged into this file. For a more nuclear 167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
168 | #.idea/ 169 | 170 | # PyPI configuration file 171 | .pypirc -------------------------------------------------------------------------------- /EXAMPLES.md: -------------------------------------------------------------------------------- 1 | # More Examples 2 | 3 | ## Other Languages 4 | 5 | For non-English ASR, it is best to use the `large` whisper model. Alignment models are automatically picked based on the chosen language from the default [lists](https://github.com/m-bain/whisperX/blob/main/whisperx/alignment.py#L18). 6 | 7 | Default models are currently provided and tested for {en, fr, de, es, it, ja, zh, nl}. 8 | 9 | 10 | If the detected language is not in this list, you need to find a phoneme-based ASR model from the [huggingface model hub](https://huggingface.co/models) and test it on your data. 11 | 12 | ### French 13 | whisperx --model large --language fr examples/sample_fr_01.wav 14 | 15 | 16 | https://user-images.githubusercontent.com/36994049/208298804-31c49d6f-6787-444e-a53f-e93c52706752.mov 17 | 18 | 19 | ### German 20 | whisperx --model large --language de examples/sample_de_01.wav 21 | 22 | 23 | https://user-images.githubusercontent.com/36994049/208298811-e36002ba-3698-4731-97d4-0aebd07e0eb3.mov 24 | 25 | 26 | ### Italian 27 | whisperx --model large --language it examples/sample_it_01.wav 28 | 29 | 30 | https://user-images.githubusercontent.com/36994049/208298819-6f462b2c-8cae-4c54-b8e1-90855794efc7.mov 31 | 32 | 33 | ### Japanese 34 | whisperx --model large --language ja examples/sample_ja_01.wav 35 | 36 | 37 | https://user-images.githubusercontent.com/19920981/208731743-311f2360-b73b-4c60-809d-aaf3cd7e06f4.mov 38 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 2-Clause License 2 | 3 | Copyright (c) 2024, Max Bain 4 | 5 | Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided
that the following conditions are met: 7 | 8 | 1. Redistributions of source code must retain the above copyright notice, this 9 | list of conditions and the following disclaimer. 10 | 11 | 2. Redistributions in binary form must reproduce the above copyright notice, 12 | this list of conditions and the following disclaimer in the documentation 13 | and/or other materials provided with the distribution. 14 | 15 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 16 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 17 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 18 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 19 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 20 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 21 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 22 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 23 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 24 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 25 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include whisperx/assets/* 2 | include LICENSE 3 | include requirements.txt 4 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

WhisperX

2 | 3 |

4 | 5 | GitHub stars 7 | 8 | 9 | GitHub issues 11 | 12 | 13 | GitHub license 15 | 16 | 17 | ArXiv paper 19 | 20 | 21 | Twitter 22 | 23 |

24 | 25 | whisperx-arch 26 | 27 | 28 | 29 | 30 | 31 | This repository provides fast automatic speech recognition (70x realtime with large-v2) with word-level timestamps and speaker diarization. 32 | 33 | - ⚡️ Batched inference for 70x realtime transcription using whisper large-v2 34 | - 🪶 [faster-whisper](https://github.com/guillaumekln/faster-whisper) backend, requires <8GB gpu memory for large-v2 with beam_size=5 35 | - 🎯 Accurate word-level timestamps using wav2vec2 alignment 36 | - 👯‍♂️ Multispeaker ASR using speaker diarization from [pyannote-audio](https://github.com/pyannote/pyannote-audio) (speaker ID labels) 37 | - 🗣️ VAD preprocessing, which reduces hallucination & enables batching with no WER degradation 38 | 39 | **Whisper** is an ASR model [developed by OpenAI](https://github.com/openai/whisper), trained on a large dataset of diverse audio. Whilst it does produce highly accurate transcriptions, the corresponding timestamps are at the utterance level, not per word, and can be inaccurate by several seconds. OpenAI's whisper does not natively support batching. 40 | 41 | **Phoneme-Based ASR** refers to a suite of models fine-tuned to recognise the smallest unit of speech distinguishing one word from another, e.g. the element p in "tap". A popular example model is [wav2vec2.0](https://huggingface.co/facebook/wav2vec2-large-960h-lv60-self). 42 | 43 | **Forced Alignment** refers to the process by which orthographic transcriptions are aligned to audio recordings to automatically generate phone-level segmentation. 44 | 45 | **Voice Activity Detection (VAD)** is the detection of the presence or absence of human speech. 46 | 47 | **Speaker Diarization** is the process of partitioning an audio stream containing human speech into homogeneous segments according to the identity of each speaker. 48 | 49 |

New🚨

50 | 51 | - 1st place at [Ego4d transcription challenge](https://eval.ai/web/challenges/challenge-page/1637/leaderboard/3931/WER) 🏆 52 | - _WhisperX_ accepted at INTERSPEECH 2023 53 | - v3 transcript segment-per-sentence: using nltk sent_tokenize for better subtitling & better diarization 54 | - v3 released, 70x speed-up open-sourced. Using batched whisper with [faster-whisper](https://github.com/guillaumekln/faster-whisper) backend! 55 | - v2 released, code cleanup, imports whisper library. VAD filtering is now turned on by default, as in the paper. 56 | - Paper drop🎓👨‍🏫! Please see our [arXiv preprint](https://arxiv.org/abs/2303.00747) for benchmarking and details of WhisperX. We also introduce more efficient batch inference resulting in large-v2 with \*60-70x REAL TIME speed. 57 | 58 |

Setup ⚙️

59 | 60 | ### 1. Simple Installation (Recommended) 61 | 62 | The easiest way to install WhisperX is through PyPi: 63 | 64 | ```bash 65 | pip install whisperx 66 | ``` 67 | 68 | Or if using [uvx](https://docs.astral.sh/uv/guides/tools/#running-tools): 69 | 70 | ```bash 71 | uvx whisperx 72 | ``` 73 | 74 | ### 2. Advanced Installation Options 75 | 76 | These installation methods are for developers or users with specific needs. If you're not sure, stick with the simple installation above. 77 | 78 | #### Option A: Install from GitHub 79 | 80 | To install directly from the GitHub repository: 81 | 82 | ```bash 83 | uvx git+https://github.com/m-bain/whisperX.git 84 | ``` 85 | 86 | #### Option B: Developer Installation 87 | 88 | If you want to modify the code or contribute to the project: 89 | 90 | ```bash 91 | git clone https://github.com/m-bain/whisperX.git 92 | cd whisperX 93 | uv sync --all-extras --dev 94 | ``` 95 | 96 | > **Note**: The development version may contain experimental features and bugs. Use the stable PyPI release for production environments. 97 | 98 | You may also need to install ffmpeg, rust etc. Follow openAI instructions here https://github.com/openai/whisper#setup. 99 | 100 | ### Common Issues & Troubleshooting 🔧 101 | 102 | #### libcudnn Dependencies (GPU Users) 103 | 104 | If you're using WhisperX with GPU support and encounter errors like: 105 | 106 | - `Could not load library libcudnn_ops_infer.so.8` 107 | - `Unable to load any of {libcudnn_cnn.so.9.1.0, libcudnn_cnn.so.9.1, libcudnn_cnn.so.9, libcudnn_cnn.so}` 108 | - `libcudnn_ops_infer.so.8: cannot open shared object file: No such file or directory` 109 | 110 | This means your system is missing the CUDA Deep Neural Network library (cuDNN). This library is needed for GPU acceleration but isn't always installed by default. 
111 | 112 | **Install cuDNN (example for apt based systems):** 113 | 114 | ```bash 115 | sudo apt update 116 | sudo apt install libcudnn8 libcudnn8-dev -y 117 | ``` 118 | 119 | ### Speaker Diarization 120 | 121 | To **enable Speaker Diarization**, include your Hugging Face access token (read) that you can generate from [Here](https://huggingface.co/settings/tokens) after the `--hf_token` argument and accept the user agreement for the following models: [Segmentation](https://huggingface.co/pyannote/segmentation-3.0) and [Speaker-Diarization-3.1](https://huggingface.co/pyannote/speaker-diarization-3.1) (if you choose to use Speaker-Diarization 2.x, follow requirements [here](https://huggingface.co/pyannote/speaker-diarization) instead.) 122 | 123 | > **Note**
124 | > As of Oct 11, 2023, there is a known issue regarding slow performance with pyannote/Speaker-Diarization-3.0 in whisperX. It is due to dependency conflicts between faster-whisper and pyannote-audio 3.0.0. Please see [this issue](https://github.com/m-bain/whisperX/issues/499) for more details and potential workarounds. 125 | 126 |

Usage 💬 (command line)

127 | 128 | ### English 129 | 130 | Run whisper on an example segment (using default params, whisper small); add `--highlight_words True` to visualise word timings in the .srt file. 131 | 132 | whisperx path/to/audio.wav 133 | 134 | Result using _WhisperX_ with forced alignment to wav2vec2.0 large: 135 | 136 | https://user-images.githubusercontent.com/36994049/208253969-7e35fe2a-7541-434a-ae91-8e919540555d.mp4 137 | 138 | Compare this to original whisper out of the box, where many transcriptions are out of sync: 139 | 140 | https://user-images.githubusercontent.com/36994049/207743923-b4f0d537-29ae-4be2-b404-bb941db73652.mov 141 | 142 | For increased timestamp accuracy, at the cost of higher GPU memory, use bigger models (a bigger alignment model was not found to be that helpful, see paper), e.g. 143 | 144 | whisperx path/to/audio.wav --model large-v2 --align_model WAV2VEC2_ASR_LARGE_LV60K_960H --batch_size 4 145 | 146 | To label the transcript with speaker IDs (set the number of speakers if known, e.g. `--min_speakers 2` `--max_speakers 2`): 147 | 148 | whisperx path/to/audio.wav --model large-v2 --diarize --highlight_words True 149 | 150 | To run on CPU instead of GPU (and for running on Mac OS X): 151 | 152 | whisperx path/to/audio.wav --compute_type int8 153 | 154 | ### Other languages 155 | 156 | The phoneme ASR alignment model is _language-specific_; for tested languages these models are [automatically picked from torchaudio pipelines or huggingface](https://github.com/m-bain/whisperX/blob/f2da2f858e99e4211fe4f64b5f2938b007827e17/whisperx/alignment.py#L24-L58). 157 | Just pass in the `--language` code, and use the whisper `--model large`. 158 | 159 | Default models are currently provided for `{en, fr, de, es, it}` via torchaudio pipelines and for many other languages via Hugging Face. Please find the list of currently supported languages under `DEFAULT_ALIGN_MODELS_HF` in [alignment.py](https://github.com/m-bain/whisperX/blob/main/whisperx/alignment.py).
If the detected language is not in this list, you need to find a phoneme-based ASR model from [huggingface model hub](https://huggingface.co/models) and test it on your data. 160 | 161 | #### E.g. German 162 | 163 | whisperx --model large-v2 --language de path/to/audio.wav 164 | 165 | https://user-images.githubusercontent.com/36994049/208298811-e36002ba-3698-4731-97d4-0aebd07e0eb3.mov 166 | 167 | See more examples in other languages [here](EXAMPLES.md). 168 | 169 | ## Python usage 🐍 170 | 171 | ```python 172 | import whisperx 173 | import gc 174 | 175 | device = "cuda" 176 | audio_file = "audio.mp3" 177 | batch_size = 16 # reduce if low on GPU mem 178 | compute_type = "float16" # change to "int8" if low on GPU mem (may reduce accuracy) 179 | 180 | # 1. Transcribe with original whisper (batched) 181 | model = whisperx.load_model("large-v2", device, compute_type=compute_type) 182 | 183 | # save model to local path (optional) 184 | # model_dir = "/path/" 185 | # model = whisperx.load_model("large-v2", device, compute_type=compute_type, download_root=model_dir) 186 | 187 | audio = whisperx.load_audio(audio_file) 188 | result = model.transcribe(audio, batch_size=batch_size) 189 | print(result["segments"]) # before alignment 190 | 191 | # delete model if low on GPU resources 192 | # import gc; gc.collect(); torch.cuda.empty_cache(); del model 193 | 194 | # 2. Align whisper output 195 | model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device) 196 | result = whisperx.align(result["segments"], model_a, metadata, audio, device, return_char_alignments=False) 197 | 198 | print(result["segments"]) # after alignment 199 | 200 | # delete model if low on GPU resources 201 | # import gc; gc.collect(); torch.cuda.empty_cache(); del model_a 202 | 203 | # 3. 
Assign speaker labels 204 | diarize_model = whisperx.diarize.DiarizationPipeline(use_auth_token=YOUR_HF_TOKEN, device=device) 205 | 206 | # add min/max number of speakers if known 207 | diarize_segments = diarize_model(audio) 208 | # diarize_model(audio, min_speakers=min_speakers, max_speakers=max_speakers) 209 | 210 | result = whisperx.assign_word_speakers(diarize_segments, result) 211 | print(diarize_segments) 212 | print(result["segments"]) # segments are now assigned speaker IDs 213 | ``` 214 | 215 | ## Demos 🚀 216 | 217 | [![Replicate (large-v3](https://img.shields.io/static/v1?label=Replicate+WhisperX+large-v3&message=Demo+%26+Cloud+API&color=blue)](https://replicate.com/victor-upmeet/whisperx) 218 | [![Replicate (large-v2](https://img.shields.io/static/v1?label=Replicate+WhisperX+large-v2&message=Demo+%26+Cloud+API&color=blue)](https://replicate.com/daanelson/whisperx) 219 | [![Replicate (medium)](https://img.shields.io/static/v1?label=Replicate+WhisperX+medium&message=Demo+%26+Cloud+API&color=blue)](https://replicate.com/carnifexer/whisperx) 220 | 221 | If you don't have access to your own GPUs, use the links above to try out WhisperX. 222 | 223 |

Technical Details 👷‍♂️

224 | 225 | For specific details on the batching and alignment, the effect of VAD, as well as the chosen alignment model, see the preprint [paper](https://www.robots.ox.ac.uk/~vgg/publications/2023/Bain23/bain23.pdf). 226 | 227 | To reduce GPU memory requirements, try any of the following (2. & 3. can affect quality): 228 | 229 | 1. Reduce batch size, e.g. `--batch_size 4` 230 | 2. Use a smaller ASR model, e.g. `--model base` 231 | 3. Use a lighter compute type, e.g. `--compute_type int8` 232 | 233 | Transcription differences from openai's whisper: 234 | 235 | 1. Transcription without timestamps. To enable single pass batching, whisper inference is performed with `--without_timestamps True`; this ensures 1 forward pass per sample in the batch. However, this can cause discrepancies with the default whisper output. 236 | 2. VAD-based segment transcription, unlike the buffered transcription of openai's. In the WhisperX paper we show this reduces WER, and enables accurate batched inference. 237 | 3. `--condition_on_prev_text` is set to `False` by default (reduces hallucination) 238 | 239 |

Limitations ⚠️

240 | 241 | - Transcript words which do not contain characters in the alignment model's dictionary, e.g. "2014." or "£13.60", cannot be aligned and therefore are not given a timing. 242 | - Overlapping speech is not handled particularly well by either whisper or whisperx 243 | - Diarization is far from perfect 244 | - A language-specific wav2vec2 model is needed 245 | 246 |

Contribute 🧑‍🏫

247 | 248 | If you are multilingual, a major way you can contribute to this project is to find phoneme models on huggingface (or train your own) and test them on speech for the target language. If the results look good send a pull request and some examples showing its success. 249 | 250 | Bug finding and pull requests are also highly appreciated to keep this project going, since it's already diverging from the original research scope. 251 | 252 |

TODO 🗓

253 | 254 | - [x] Multilingual init 255 | 256 | - [x] Automatic align model selection based on language detection 257 | 258 | - [x] Python usage 259 | 260 | - [x] Incorporating speaker diarization 261 | 262 | - [x] Model flush, for low gpu mem resources 263 | 264 | - [x] Faster-whisper backend 265 | 266 | - [x] Add max-line etc. see (openai's whisper utils.py) 267 | 268 | - [x] Sentence-level segments (nltk toolbox) 269 | 270 | - [x] Improve alignment logic 271 | 272 | - [ ] update examples with diarization and word highlighting 273 | 274 | - [ ] Subtitle .ass output <- bring this back (removed in v3) 275 | 276 | - [ ] Add benchmarking code (TEDLIUM for spd/WER & word segmentation) 277 | 278 | - [x] Allow silero-vad as alternative VAD option 279 | 280 | - [ ] Improve diarization (word level). _Harder than first thought..._ 281 | 282 |

Contact/Support 📇

283 | 284 | Contact maxhbain@gmail.com for queries. 285 | 286 | Buy Me A Coffee 287 | 288 |

Acknowledgements 🙏

289 | 290 | This work, and my PhD, is supported by the [VGG (Visual Geometry Group)](https://www.robots.ox.ac.uk/~vgg/) and the University of Oxford. 291 | 292 | Of course, this builds on [openAI's whisper](https://github.com/openai/whisper). 293 | Borrows important alignment code from [PyTorch tutorial on forced alignment](https://pytorch.org/tutorials/intermediate/forced_alignment_with_torchaudio_tutorial.html) 294 | And uses the wonderful pyannote VAD / Diarization https://github.com/pyannote/pyannote-audio 295 | 296 | Valuable VAD & Diarization Models from: 297 | 298 | - [pyannote audio](https://github.com/pyannote/pyannote-audio) 299 | - [silero vad](https://github.com/snakers4/silero-vad) 300 | 301 | Great backend from [faster-whisper](https://github.com/guillaumekln/faster-whisper) and [CTranslate2](https://github.com/OpenNMT/CTranslate2) 302 | 303 | Those who have [supported this work financially](https://www.buymeacoffee.com/maxhbain) 🙏 304 | 305 | Finally, thanks to the OS [contributors](https://github.com/m-bain/whisperX/graphs/contributors) of this project, keeping it going and identifying bugs. 306 | 307 |

Citation

308 | If you use this in your research, please cite the paper: 309 | 310 | ```bibtex 311 | @article{bain2022whisperx, 312 | title={WhisperX: Time-Accurate Speech Transcription of Long-Form Audio}, 313 | author={Bain, Max and Huh, Jaesung and Han, Tengda and Zisserman, Andrew}, 314 | journal={INTERSPEECH 2023}, 315 | year={2023} 316 | } 317 | ``` 318 | -------------------------------------------------------------------------------- /figures/pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/m-bain/whisperX/b3432412530ecb0cc5ac923f161da281e41d23d2/figures/pipeline.png -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | urls = { repository = "https://github.com/m-bain/whisperx" } 3 | authors = [{ name = "Max Bain" }] 4 | name = "whisperx" 5 | version = "3.3.4" 6 | description = "Time-Accurate Automatic Speech Recognition using Whisper." 
7 | readme = "README.md" 8 | requires-python = ">=3.9, <3.13" 9 | license = { text = "BSD-2-Clause" } 10 | 11 | dependencies = [ 12 | "ctranslate2<4.5.0", 13 | "faster-whisper>=1.1.1", 14 | "nltk>=3.9.1", 15 | "numpy>=2.0.2", 16 | "onnxruntime>=1.19", 17 | "pandas>=2.2.3", 18 | "pyannote-audio>=3.3.2", 19 | "torch>=2.5.1", 20 | "torchaudio>=2.5.1", 21 | "transformers>=4.48.0", 22 | ] 23 | 24 | 25 | [project.scripts] 26 | whisperx = "whisperx.__main__:cli" 27 | 28 | [build-system] 29 | requires = ["setuptools"] 30 | 31 | [tool.setuptools] 32 | include-package-data = true 33 | 34 | [tool.setuptools.packages.find] 35 | where = ["."] 36 | include = ["whisperx*"] 37 | -------------------------------------------------------------------------------- /whisperx/SubtitlesProcessor.py: -------------------------------------------------------------------------------- 1 | import math 2 | from whisperx.conjunctions import get_conjunctions, get_comma 3 | 4 | def normal_round(n): 5 | if n - math.floor(n) < 0.5: 6 | return math.floor(n) 7 | return math.ceil(n) 8 | 9 | 10 | def format_timestamp(seconds: float, is_vtt: bool = False): 11 | 12 | assert seconds >= 0, "non-negative timestamp expected" 13 | milliseconds = round(seconds * 1000.0) 14 | 15 | hours = milliseconds // 3_600_000 16 | milliseconds -= hours * 3_600_000 17 | 18 | minutes = milliseconds // 60_000 19 | milliseconds -= minutes * 60_000 20 | 21 | seconds = milliseconds // 1_000 22 | milliseconds -= seconds * 1_000 23 | 24 | separator = '.' 
if is_vtt else ',' 25 | 26 | hours_marker = f"{hours:02d}:" 27 | return ( 28 | f"{hours_marker}{minutes:02d}:{seconds:02d}{separator}{milliseconds:03d}" 29 | ) 30 | 31 | 32 | 33 | class SubtitlesProcessor: 34 | def __init__(self, segments, lang, max_line_length = 45, min_char_length_splitter = 30, is_vtt = False): 35 | self.comma = get_comma(lang) 36 | self.conjunctions = set(get_conjunctions(lang)) 37 | self.segments = segments 38 | self.lang = lang 39 | self.max_line_length = max_line_length 40 | self.min_char_length_splitter = min_char_length_splitter 41 | self.is_vtt = is_vtt 42 | complex_script_languages = ['th', 'lo', 'my', 'km', 'am', 'ko', 'ja', 'zh', 'ti', 'ta', 'te', 'kn', 'ml', 'hi', 'ne', 'mr', 'ar', 'fa', 'ur', 'ka'] 43 | if self.lang in complex_script_languages: 44 | self.max_line_length = 30 45 | self.min_char_length_splitter = 20 46 | 47 | def estimate_timestamp_for_word(self, words, i, next_segment_start_time=None): 48 | k = 0.25 49 | has_prev_end = i > 0 and 'end' in words[i - 1] 50 | has_next_start = i < len(words) - 1 and 'start' in words[i + 1] 51 | 52 | if has_prev_end: 53 | words[i]['start'] = words[i - 1]['end'] 54 | if has_next_start: 55 | words[i]['end'] = words[i + 1]['start'] 56 | else: 57 | if next_segment_start_time: 58 | words[i]['end'] = next_segment_start_time if next_segment_start_time - words[i - 1]['end'] <= 1 else next_segment_start_time - 0.5 59 | else: 60 | words[i]['end'] = words[i]['start'] + len(words[i]['word']) * k 61 | 62 | elif has_next_start: 63 | words[i]['start'] = words[i + 1]['start'] - len(words[i]['word']) * k 64 | words[i]['end'] = words[i + 1]['start'] 65 | 66 | else: 67 | if next_segment_start_time: 68 | words[i]['start'] = next_segment_start_time - 1 69 | words[i]['end'] = next_segment_start_time - 0.5 70 | else: 71 | words[i]['start'] = 0 72 | words[i]['end'] = 0 73 | 74 | 75 | 76 | def process_segments(self, advanced_splitting=True): 77 | subtitles = [] 78 | for i, segment in enumerate(self.segments): 79 | 
next_segment_start_time = self.segments[i + 1]['start'] if i + 1 < len(self.segments) else None 80 | 81 | if advanced_splitting: 82 | 83 | split_points = self.determine_advanced_split_points(segment, next_segment_start_time) 84 | subtitles.extend(self.generate_subtitles_from_split_points(segment, split_points, next_segment_start_time)) 85 | else: 86 | words = segment['words'] 87 | for i, word in enumerate(words): 88 | if 'start' not in word or 'end' not in word: 89 | self.estimate_timestamp_for_word(words, i, next_segment_start_time) 90 | 91 | subtitles.append({ 92 | 'start': segment['start'], 93 | 'end': segment['end'], 94 | 'text': segment['text'] 95 | }) 96 | 97 | return subtitles 98 | 99 | def determine_advanced_split_points(self, segment, next_segment_start_time=None): 100 | split_points = [] 101 | last_split_point = 0 102 | char_count = 0 103 | 104 | words = segment.get('words', segment['text'].split()) 105 | add_space = 0 if self.lang in ['zh', 'ja'] else 1 106 | 107 | total_char_count = sum(len(word['word']) if isinstance(word, dict) else len(word) + add_space for word in words) 108 | char_count_after = total_char_count 109 | 110 | for i, word in enumerate(words): 111 | word_text = word['word'] if isinstance(word, dict) else word 112 | word_length = len(word_text) + add_space 113 | char_count += word_length 114 | char_count_after -= word_length 115 | 116 | char_count_before = char_count - word_length 117 | 118 | if isinstance(word, dict) and ('start' not in word or 'end' not in word): 119 | self.estimate_timestamp_for_word(words, i, next_segment_start_time) 120 | 121 | if char_count >= self.max_line_length: 122 | midpoint = normal_round((last_split_point + i) / 2) 123 | if char_count_before >= self.min_char_length_splitter: 124 | split_points.append(midpoint) 125 | last_split_point = midpoint + 1 126 | char_count = sum(len(words[j]['word']) if isinstance(words[j], dict) else len(words[j]) + add_space for j in range(last_split_point, i + 1)) 127 | 128 | elif 
word_text.endswith(self.comma) and char_count_before >= self.min_char_length_splitter and char_count_after >= self.min_char_length_splitter: 129 | split_points.append(i) 130 | last_split_point = i + 1 131 | char_count = 0 132 | 133 | elif word_text.lower() in self.conjunctions and char_count_before >= self.min_char_length_splitter and char_count_after >= self.min_char_length_splitter: 134 | split_points.append(i - 1) 135 | last_split_point = i 136 | char_count = word_length 137 | 138 | return split_points 139 | 140 | 141 | def generate_subtitles_from_split_points(self, segment, split_points, next_start_time=None): 142 | subtitles = [] 143 | 144 | words = segment.get('words', segment['text'].split()) 145 | total_word_count = len(words) 146 | total_time = segment['end'] - segment['start'] 147 | elapsed_time = segment['start'] 148 | prefix = ' ' if self.lang not in ['zh', 'ja'] else '' 149 | start_idx = 0 150 | for split_point in split_points: 151 | 152 | fragment_words = words[start_idx:split_point + 1] 153 | current_word_count = len(fragment_words) 154 | 155 | 156 | if isinstance(fragment_words[0], dict): 157 | start_time = fragment_words[0]['start'] 158 | end_time = fragment_words[-1]['end'] 159 | next_start_time_for_word = words[split_point + 1]['start'] if split_point + 1 < len(words) else None 160 | if next_start_time_for_word and (next_start_time_for_word - end_time) <= 0.8: 161 | end_time = next_start_time_for_word 162 | else: 163 | fragment = prefix.join(fragment_words).strip() 164 | current_duration = (current_word_count / total_word_count) * total_time 165 | start_time = elapsed_time 166 | end_time = elapsed_time + current_duration 167 | elapsed_time += current_duration 168 | 169 | 170 | subtitles.append({ 171 | 'start': start_time, 172 | 'end': end_time, 173 | 'text': fragment if not isinstance(fragment_words[0], dict) else prefix.join(word['word'] for word in fragment_words) 174 | }) 175 | 176 | start_idx = split_point + 1 177 | 178 | # Handle the last 
fragment 179 | if start_idx < len(words): 180 | fragment_words = words[start_idx:] 181 | current_word_count = len(fragment_words) 182 | 183 | if isinstance(fragment_words[0], dict): 184 | start_time = fragment_words[0]['start'] 185 | end_time = fragment_words[-1]['end'] 186 | else: 187 | fragment = prefix.join(fragment_words).strip() 188 | current_duration = (current_word_count / total_word_count) * total_time 189 | start_time = elapsed_time 190 | end_time = elapsed_time + current_duration 191 | 192 | if next_start_time and (next_start_time - end_time) <= 0.8: 193 | end_time = next_start_time 194 | 195 | subtitles.append({ 196 | 'start': start_time, 197 | 'end': end_time if end_time is not None else segment['end'], 198 | 'text': fragment if not isinstance(fragment_words[0], dict) else prefix.join(word['word'] for word in fragment_words) 199 | }) 200 | 201 | return subtitles 202 | 203 | 204 | 205 | def save(self, filename="subtitles.srt", advanced_splitting=True): 206 | 207 | subtitles = self.process_segments(advanced_splitting) 208 | 209 | def write_subtitle(file, idx, start_time, end_time, text): 210 | 211 | file.write(f"{idx}\n") 212 | file.write(f"{start_time} --> {end_time}\n") 213 | file.write(text + "\n\n") 214 | 215 | with open(filename, 'w', encoding='utf-8') as file: 216 | if self.is_vtt: 217 | file.write("WEBVTT\n\n") 218 | 219 | if advanced_splitting: 220 | for idx, subtitle in enumerate(subtitles, 1): 221 | start_time = format_timestamp(subtitle['start'], self.is_vtt) 222 | end_time = format_timestamp(subtitle['end'], self.is_vtt) 223 | text = subtitle['text'].strip() 224 | write_subtitle(file, idx, start_time, end_time, text) 225 | 226 | return len(subtitles) -------------------------------------------------------------------------------- /whisperx/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | 4 | def _lazy_import(name): 5 | module = importlib.import_module(f"whisperx.{name}") 
6 | return module 7 | 8 | 9 | def load_align_model(*args, **kwargs): 10 | alignment = _lazy_import("alignment") 11 | return alignment.load_align_model(*args, **kwargs) 12 | 13 | 14 | def align(*args, **kwargs): 15 | alignment = _lazy_import("alignment") 16 | return alignment.align(*args, **kwargs) 17 | 18 | 19 | def load_model(*args, **kwargs): 20 | asr = _lazy_import("asr") 21 | return asr.load_model(*args, **kwargs) 22 | 23 | 24 | def load_audio(*args, **kwargs): 25 | audio = _lazy_import("audio") 26 | return audio.load_audio(*args, **kwargs) 27 | 28 | 29 | def assign_word_speakers(*args, **kwargs): 30 | diarize = _lazy_import("diarize") 31 | return diarize.assign_word_speakers(*args, **kwargs) 32 | -------------------------------------------------------------------------------- /whisperx/__main__.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import importlib.metadata 3 | import platform 4 | 5 | import torch 6 | 7 | from whisperx.utils import (LANGUAGES, TO_LANGUAGE_CODE, optional_float, 8 | optional_int, str2bool) 9 | 10 | 11 | def cli(): 12 | # fmt: off 13 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 14 | parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe") 15 | parser.add_argument("--model", default="small", help="name of the Whisper model to use") 16 | parser.add_argument("--model_cache_only", type=str2bool, default=False, help="If True, will not attempt to download models, instead using cached models from --model_dir") 17 | parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default") 18 | parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference") 19 | parser.add_argument("--device_index", default=0, type=int, help="device index to use for FasterWhisper inference") 20 | 
parser.add_argument("--batch_size", default=8, type=int, help="the preferred batch size for inference") 21 | parser.add_argument("--compute_type", default="float16", type=str, choices=["float16", "float32", "int8"], help="compute type for computation") 22 | 23 | parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs") 24 | parser.add_argument("--output_format", "-f", type=str, default="all", choices=["all", "srt", "vtt", "txt", "tsv", "json", "aud"], help="format of the output file; if not specified, all available formats will be produced") 25 | parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages") 26 | 27 | parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')") 28 | parser.add_argument("--language", type=str, default=None, choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]), help="language spoken in the audio, specify None to perform language detection") 29 | 30 | # alignment params 31 | parser.add_argument("--align_model", default=None, help="Name of phoneme-level ASR model to do alignment") 32 | parser.add_argument("--interpolate_method", default="nearest", choices=["nearest", "linear", "ignore"], help="For word .srt, method to assign timestamps to non-aligned words, or merge them into neighbouring.") 33 | parser.add_argument("--no_align", action='store_true', help="Do not perform phoneme alignment") 34 | parser.add_argument("--return_char_alignments", action='store_true', help="Return character-level alignments in the output json file") 35 | 36 | # vad params 37 | parser.add_argument("--vad_method", type=str, default="pyannote", choices=["pyannote", "silero"], help="VAD method to be used") 38 | parser.add_argument("--vad_onset", type=float, default=0.500, 
help="Onset threshold for VAD (see pyannote.audio), reduce this if speech is not being detected") 39 | parser.add_argument("--vad_offset", type=float, default=0.363, help="Offset threshold for VAD (see pyannote.audio), reduce this if speech is not being detected.") 40 | parser.add_argument("--chunk_size", type=int, default=30, help="Chunk size for merging VAD segments. Default is 30, reduce this if the chunk is too long.") 41 | 42 | # diarization params 43 | parser.add_argument("--diarize", action="store_true", help="Apply diarization to assign speaker labels to each segment/word") 44 | parser.add_argument("--min_speakers", default=None, type=int, help="Minimum number of speakers to in audio file") 45 | parser.add_argument("--max_speakers", default=None, type=int, help="Maximum number of speakers to in audio file") 46 | parser.add_argument("--diarize_model", default="pyannote/speaker-diarization-3.1", type=str, help="Name of the speaker diarization model to use") 47 | 48 | parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling") 49 | parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature") 50 | parser.add_argument("--beam_size", type=optional_int, default=5, help="number of beams in beam search, only applicable when temperature is zero") 51 | parser.add_argument("--patience", type=float, default=1.0, help="optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search") 52 | parser.add_argument("--length_penalty", type=float, default=1.0, help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple length normalization by default") 53 | 54 | parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters 
except common punctuations") 55 | parser.add_argument("--suppress_numerals", action="store_true", help="whether to suppress numeric symbols and currency symbols during sampling, since wav2vec2 cannot align them correctly") 56 | 57 | parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.") 58 | parser.add_argument("--condition_on_previous_text", type=str2bool, default=False, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop") 59 | parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default") 60 | 61 | parser.add_argument("--temperature_increment_on_fallback", type=optional_float, default=0.2, help="temperature to increase when falling back when the decoding fails to meet either of the thresholds below") 62 | parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, help="if the gzip compression ratio is higher than this value, treat the decoding as failed") 63 | parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed") 64 | parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence") 65 | 66 | parser.add_argument("--max_line_width", type=optional_int, default=None, help="(not possible with --no_align) the maximum number of characters in a line before breaking the line") 67 | parser.add_argument("--max_line_count", type=optional_int, default=None, help="(not possible with --no_align) the maximum number of lines in a segment") 68 | 
parser.add_argument("--highlight_words", type=str2bool, default=False, help="(not possible with --no_align) underline each word as it is spoken in srt and vtt") 69 | parser.add_argument("--segment_resolution", type=str, default="sentence", choices=["sentence", "chunk"], help="(not possible with --no_align) the maximum number of characters in a line before breaking the line") 70 | 71 | parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS") 72 | 73 | parser.add_argument("--hf_token", type=str, default=None, help="Hugging Face Access Token to access PyAnnote gated models") 74 | 75 | parser.add_argument("--print_progress", type=str2bool, default = False, help = "if True, progress will be printed in transcribe() and align() methods.") 76 | parser.add_argument("--version", "-V", action="version", version=f"%(prog)s {importlib.metadata.version('whisperx')}",help="Show whisperx version information and exit") 77 | parser.add_argument("--python-version", "-P", action="version", version=f"Python {platform.python_version()} ({platform.python_implementation()})",help="Show python version information and exit") 78 | # fmt: on 79 | 80 | args = parser.parse_args().__dict__ 81 | 82 | from whisperx.transcribe import transcribe_task 83 | 84 | transcribe_task(args, parser) 85 | 86 | 87 | if __name__ == "__main__": 88 | cli() 89 | -------------------------------------------------------------------------------- /whisperx/alignment.py: -------------------------------------------------------------------------------- 1 | """ 2 | Forced Alignment with Whisper 3 | C. 
Max Bain 4 | """ 5 | import math 6 | 7 | from dataclasses import dataclass 8 | from typing import Iterable, Optional, Union, List 9 | 10 | import numpy as np 11 | import pandas as pd 12 | import torch 13 | import torchaudio 14 | from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor 15 | 16 | from whisperx.audio import SAMPLE_RATE, load_audio 17 | from whisperx.utils import interpolate_nans 18 | from whisperx.types import ( 19 | AlignedTranscriptionResult, 20 | SingleSegment, 21 | SingleAlignedSegment, 22 | SingleWordSegment, 23 | SegmentData, 24 | ) 25 | from nltk.tokenize.punkt import PunktSentenceTokenizer, PunktParameters 26 | 27 | PUNKT_ABBREVIATIONS = ['dr', 'vs', 'mr', 'mrs', 'prof'] 28 | 29 | LANGUAGES_WITHOUT_SPACES = ["ja", "zh"] 30 | 31 | DEFAULT_ALIGN_MODELS_TORCH = { 32 | "en": "WAV2VEC2_ASR_BASE_960H", 33 | "fr": "VOXPOPULI_ASR_BASE_10K_FR", 34 | "de": "VOXPOPULI_ASR_BASE_10K_DE", 35 | "es": "VOXPOPULI_ASR_BASE_10K_ES", 36 | "it": "VOXPOPULI_ASR_BASE_10K_IT", 37 | } 38 | 39 | DEFAULT_ALIGN_MODELS_HF = { 40 | "ja": "jonatasgrosman/wav2vec2-large-xlsr-53-japanese", 41 | "zh": "jonatasgrosman/wav2vec2-large-xlsr-53-chinese-zh-cn", 42 | "nl": "jonatasgrosman/wav2vec2-large-xlsr-53-dutch", 43 | "uk": "Yehor/wav2vec2-xls-r-300m-uk-with-small-lm", 44 | "pt": "jonatasgrosman/wav2vec2-large-xlsr-53-portuguese", 45 | "ar": "jonatasgrosman/wav2vec2-large-xlsr-53-arabic", 46 | "cs": "comodoro/wav2vec2-xls-r-300m-cs-250", 47 | "ru": "jonatasgrosman/wav2vec2-large-xlsr-53-russian", 48 | "pl": "jonatasgrosman/wav2vec2-large-xlsr-53-polish", 49 | "hu": "jonatasgrosman/wav2vec2-large-xlsr-53-hungarian", 50 | "fi": "jonatasgrosman/wav2vec2-large-xlsr-53-finnish", 51 | "fa": "jonatasgrosman/wav2vec2-large-xlsr-53-persian", 52 | "el": "jonatasgrosman/wav2vec2-large-xlsr-53-greek", 53 | "tr": "mpoyraz/wav2vec2-xls-r-300m-cv7-turkish", 54 | "da": "saattrupdan/wav2vec2-xls-r-300m-ftspeech", 55 | "he": "imvladikon/wav2vec2-xls-r-300m-hebrew", 56 | "vi": 
'nguyenvulebinh/wav2vec2-base-vi', 57 | "ko": "kresnik/wav2vec2-large-xlsr-korean", 58 | "ur": "kingabzpro/wav2vec2-large-xls-r-300m-Urdu", 59 | "te": "anuragshas/wav2vec2-large-xlsr-53-telugu", 60 | "hi": "theainerd/Wav2Vec2-large-xlsr-hindi", 61 | "ca": "softcatala/wav2vec2-large-xlsr-catala", 62 | "ml": "gvs/wav2vec2-large-xlsr-malayalam", 63 | "no": "NbAiLab/nb-wav2vec2-1b-bokmaal-v2", 64 | "nn": "NbAiLab/nb-wav2vec2-1b-nynorsk", 65 | "sk": "comodoro/wav2vec2-xls-r-300m-sk-cv8", 66 | "sl": "anton-l/wav2vec2-large-xlsr-53-slovenian", 67 | "hr": "classla/wav2vec2-xls-r-parlaspeech-hr", 68 | "ro": "gigant/romanian-wav2vec2", 69 | "eu": "stefan-it/wav2vec2-large-xlsr-53-basque", 70 | "gl": "ifrz/wav2vec2-large-xlsr-galician", 71 | "ka": "xsway/wav2vec2-large-xlsr-georgian", 72 | "lv": "jimregan/wav2vec2-large-xlsr-latvian-cv", 73 | "tl": "Khalsuu/filipino-wav2vec2-l-xls-r-300m-official", 74 | } 75 | 76 | 77 | def load_align_model(language_code: str, device: str, model_name: Optional[str] = None, model_dir=None): 78 | if model_name is None: 79 | # use default model 80 | if language_code in DEFAULT_ALIGN_MODELS_TORCH: 81 | model_name = DEFAULT_ALIGN_MODELS_TORCH[language_code] 82 | elif language_code in DEFAULT_ALIGN_MODELS_HF: 83 | model_name = DEFAULT_ALIGN_MODELS_HF[language_code] 84 | else: 85 | print(f"There is no default alignment model set for this language ({language_code}).\ 86 | Please find a wav2vec2.0 model finetuned on this language in https://huggingface.co/models, then pass the model name in --align_model [MODEL_NAME]") 87 | raise ValueError(f"No default align-model for language: {language_code}") 88 | 89 | if model_name in torchaudio.pipelines.__all__: 90 | pipeline_type = "torchaudio" 91 | bundle = torchaudio.pipelines.__dict__[model_name] 92 | align_model = bundle.get_model(dl_kwargs={"model_dir": model_dir}).to(device) 93 | labels = bundle.get_labels() 94 | align_dictionary = {c.lower(): i for i, c in enumerate(labels)} 95 | else: 96 | try: 97 | 
processor = Wav2Vec2Processor.from_pretrained(model_name, cache_dir=model_dir) 98 | align_model = Wav2Vec2ForCTC.from_pretrained(model_name, cache_dir=model_dir) 99 | except Exception as e: 100 | print(e) 101 | print(f"Error loading model from huggingface, check https://huggingface.co/models for finetuned wav2vec2.0 models") 102 | raise ValueError(f'The chosen align_model "{model_name}" could not be found in huggingface (https://huggingface.co/models) or torchaudio (https://pytorch.org/audio/stable/pipelines.html#id14)') 103 | pipeline_type = "huggingface" 104 | align_model = align_model.to(device) 105 | labels = processor.tokenizer.get_vocab() 106 | align_dictionary = {char.lower(): code for char,code in processor.tokenizer.get_vocab().items()} 107 | 108 | align_metadata = {"language": language_code, "dictionary": align_dictionary, "type": pipeline_type} 109 | 110 | return align_model, align_metadata 111 | 112 | 113 | def align( 114 | transcript: Iterable[SingleSegment], 115 | model: torch.nn.Module, 116 | align_model_metadata: dict, 117 | audio: Union[str, np.ndarray, torch.Tensor], 118 | device: str, 119 | interpolate_method: str = "nearest", 120 | return_char_alignments: bool = False, 121 | print_progress: bool = False, 122 | combined_progress: bool = False, 123 | ) -> AlignedTranscriptionResult: 124 | """ 125 | Align phoneme recognition predictions to known transcription. 126 | """ 127 | 128 | if not torch.is_tensor(audio): 129 | if isinstance(audio, str): 130 | audio = load_audio(audio) 131 | audio = torch.from_numpy(audio) 132 | if len(audio.shape) == 1: 133 | audio = audio.unsqueeze(0) 134 | 135 | MAX_DURATION = audio.shape[1] / SAMPLE_RATE 136 | 137 | model_dictionary = align_model_metadata["dictionary"] 138 | model_lang = align_model_metadata["language"] 139 | model_type = align_model_metadata["type"] 140 | 141 | # 1. 
Preprocess to keep only characters in dictionary 142 | total_segments = len(transcript) 143 | # Store temporary processing values 144 | segment_data: dict[int, SegmentData] = {} 145 | for sdx, segment in enumerate(transcript): 146 | # strip spaces at beginning / end, but keep track of the amount. 147 | if print_progress: 148 | base_progress = ((sdx + 1) / total_segments) * 100 149 | percent_complete = (50 + base_progress / 2) if combined_progress else base_progress 150 | print(f"Progress: {percent_complete:.2f}%...") 151 | 152 | num_leading = len(segment["text"]) - len(segment["text"].lstrip()) 153 | num_trailing = len(segment["text"]) - len(segment["text"].rstrip()) 154 | text = segment["text"] 155 | 156 | # split into words 157 | if model_lang not in LANGUAGES_WITHOUT_SPACES: 158 | per_word = text.split(" ") 159 | else: 160 | per_word = text 161 | 162 | clean_char, clean_cdx = [], [] 163 | for cdx, char in enumerate(text): 164 | char_ = char.lower() 165 | # wav2vec2 models use "|" character to represent spaces 166 | if model_lang not in LANGUAGES_WITHOUT_SPACES: 167 | char_ = char_.replace(" ", "|") 168 | 169 | # ignore whitespace at beginning and end of transcript 170 | if cdx < num_leading: 171 | pass 172 | elif cdx > len(text) - num_trailing - 1: 173 | pass 174 | elif char_ in model_dictionary.keys(): 175 | clean_char.append(char_) 176 | clean_cdx.append(cdx) 177 | else: 178 | # add placeholder 179 | clean_char.append('*') 180 | clean_cdx.append(cdx) 181 | 182 | clean_wdx = [] 183 | for wdx, wrd in enumerate(per_word): 184 | if any([c in model_dictionary.keys() for c in wrd.lower()]): 185 | clean_wdx.append(wdx) 186 | else: 187 | # index for placeholder 188 | clean_wdx.append(wdx) 189 | 190 | 191 | punkt_param = PunktParameters() 192 | punkt_param.abbrev_types = set(PUNKT_ABBREVIATIONS) 193 | sentence_splitter = PunktSentenceTokenizer(punkt_param) 194 | sentence_spans = list(sentence_splitter.span_tokenize(text)) 195 | 196 | segment_data[sdx] = { 197 | 
"clean_char": clean_char, 198 | "clean_cdx": clean_cdx, 199 | "clean_wdx": clean_wdx, 200 | "sentence_spans": sentence_spans 201 | } 202 | 203 | aligned_segments: List[SingleAlignedSegment] = [] 204 | 205 | # 2. Get prediction matrix from alignment model & align 206 | for sdx, segment in enumerate(transcript): 207 | 208 | t1 = segment["start"] 209 | t2 = segment["end"] 210 | text = segment["text"] 211 | 212 | aligned_seg: SingleAlignedSegment = { 213 | "start": t1, 214 | "end": t2, 215 | "text": text, 216 | "words": [], 217 | "chars": None, 218 | } 219 | 220 | if return_char_alignments: 221 | aligned_seg["chars"] = [] 222 | 223 | # check we can align 224 | if len(segment_data[sdx]["clean_char"]) == 0: 225 | print(f'Failed to align segment ("{segment["text"]}"): no characters in this segment found in model dictionary, resorting to original...') 226 | aligned_segments.append(aligned_seg) 227 | continue 228 | 229 | if t1 >= MAX_DURATION: 230 | print(f'Failed to align segment ("{segment["text"]}"): original start time longer than audio duration, skipping...') 231 | aligned_segments.append(aligned_seg) 232 | continue 233 | 234 | text_clean = "".join(segment_data[sdx]["clean_char"]) 235 | tokens = [model_dictionary.get(c, -1) for c in text_clean] 236 | 237 | f1 = int(t1 * SAMPLE_RATE) 238 | f2 = int(t2 * SAMPLE_RATE) 239 | 240 | # TODO: Probably can get some speedup gain with batched inference here 241 | waveform_segment = audio[:, f1:f2] 242 | # Handle the minimum input length for wav2vec2 models 243 | if waveform_segment.shape[-1] < 400: 244 | lengths = torch.as_tensor([waveform_segment.shape[-1]]).to(device) 245 | waveform_segment = torch.nn.functional.pad( 246 | waveform_segment, (0, 400 - waveform_segment.shape[-1]) 247 | ) 248 | else: 249 | lengths = None 250 | 251 | with torch.inference_mode(): 252 | if model_type == "torchaudio": 253 | emissions, _ = model(waveform_segment.to(device), lengths=lengths) 254 | elif model_type == "huggingface": 255 | emissions = 
model(waveform_segment.to(device)).logits 256 | else: 257 | raise NotImplementedError(f"Align model of type {model_type} not supported.") 258 | emissions = torch.log_softmax(emissions, dim=-1) 259 | 260 | emission = emissions[0].cpu().detach() 261 | 262 | blank_id = 0 263 | for char, code in model_dictionary.items(): 264 | if char == '[pad]' or char == '': 265 | blank_id = code 266 | 267 | trellis = get_trellis(emission, tokens, blank_id) 268 | # path = backtrack(trellis, emission, tokens, blank_id) 269 | path = backtrack_beam(trellis, emission, tokens, blank_id, beam_width=2) 270 | 271 | if path is None: 272 | print(f'Failed to align segment ("{segment["text"]}"): backtrack failed, resorting to original...') 273 | aligned_segments.append(aligned_seg) 274 | continue 275 | 276 | char_segments = merge_repeats(path, text_clean) 277 | 278 | duration = t2 - t1 279 | ratio = duration * waveform_segment.size(0) / (trellis.size(0) - 1) 280 | 281 | # assign timestamps to aligned characters 282 | char_segments_arr = [] 283 | word_idx = 0 284 | for cdx, char in enumerate(text): 285 | start, end, score = None, None, None 286 | if cdx in segment_data[sdx]["clean_cdx"]: 287 | char_seg = char_segments[segment_data[sdx]["clean_cdx"].index(cdx)] 288 | start = round(char_seg.start * ratio + t1, 3) 289 | end = round(char_seg.end * ratio + t1, 3) 290 | score = round(char_seg.score, 3) 291 | 292 | char_segments_arr.append( 293 | { 294 | "char": char, 295 | "start": start, 296 | "end": end, 297 | "score": score, 298 | "word-idx": word_idx, 299 | } 300 | ) 301 | 302 | # increment word_idx, nltk word tokenization would probably be more robust here, but us space for now... 
303 | if model_lang in LANGUAGES_WITHOUT_SPACES: 304 | word_idx += 1 305 | elif cdx == len(text) - 1 or text[cdx+1] == " ": 306 | word_idx += 1 307 | 308 | char_segments_arr = pd.DataFrame(char_segments_arr) 309 | 310 | aligned_subsegments = [] 311 | # assign sentence_idx to each character index 312 | char_segments_arr["sentence-idx"] = None 313 | for sdx2, (sstart, send) in enumerate(segment_data[sdx]["sentence_spans"]): 314 | curr_chars = char_segments_arr.loc[(char_segments_arr.index >= sstart) & (char_segments_arr.index <= send)] 315 | char_segments_arr.loc[(char_segments_arr.index >= sstart) & (char_segments_arr.index <= send), "sentence-idx"] = sdx2 316 | 317 | sentence_text = text[sstart:send] 318 | sentence_start = curr_chars["start"].min() 319 | end_chars = curr_chars[curr_chars["char"] != ' '] 320 | sentence_end = end_chars["end"].max() 321 | sentence_words = [] 322 | 323 | for word_idx in curr_chars["word-idx"].unique(): 324 | word_chars = curr_chars.loc[curr_chars["word-idx"] == word_idx] 325 | word_text = "".join(word_chars["char"].tolist()).strip() 326 | if len(word_text) == 0: 327 | continue 328 | 329 | # dont use space character for alignment 330 | word_chars = word_chars[word_chars["char"] != " "] 331 | 332 | word_start = word_chars["start"].min() 333 | word_end = word_chars["end"].max() 334 | word_score = round(word_chars["score"].mean(), 3) 335 | 336 | # -1 indicates unalignable 337 | word_segment = {"word": word_text} 338 | 339 | if not np.isnan(word_start): 340 | word_segment["start"] = word_start 341 | if not np.isnan(word_end): 342 | word_segment["end"] = word_end 343 | if not np.isnan(word_score): 344 | word_segment["score"] = word_score 345 | 346 | sentence_words.append(word_segment) 347 | 348 | aligned_subsegments.append({ 349 | "text": sentence_text, 350 | "start": sentence_start, 351 | "end": sentence_end, 352 | "words": sentence_words, 353 | }) 354 | 355 | if return_char_alignments: 356 | curr_chars = curr_chars[["char", "start", "end", 
"score"]] 357 | curr_chars.fillna(-1, inplace=True) 358 | curr_chars = curr_chars.to_dict("records") 359 | curr_chars = [{key: val for key, val in char.items() if val != -1} for char in curr_chars] 360 | aligned_subsegments[-1]["chars"] = curr_chars 361 | 362 | aligned_subsegments = pd.DataFrame(aligned_subsegments) 363 | aligned_subsegments["start"] = interpolate_nans(aligned_subsegments["start"], method=interpolate_method) 364 | aligned_subsegments["end"] = interpolate_nans(aligned_subsegments["end"], method=interpolate_method) 365 | # concatenate sentences with same timestamps 366 | agg_dict = {"text": " ".join, "words": "sum"} 367 | if model_lang in LANGUAGES_WITHOUT_SPACES: 368 | agg_dict["text"] = "".join 369 | if return_char_alignments: 370 | agg_dict["chars"] = "sum" 371 | aligned_subsegments= aligned_subsegments.groupby(["start", "end"], as_index=False).agg(agg_dict) 372 | aligned_subsegments = aligned_subsegments.to_dict('records') 373 | aligned_segments += aligned_subsegments 374 | 375 | # create word_segments list 376 | word_segments: List[SingleWordSegment] = [] 377 | for segment in aligned_segments: 378 | word_segments += segment["words"] 379 | 380 | return {"segments": aligned_segments, "word_segments": word_segments} 381 | 382 | """ 383 | source: https://pytorch.org/tutorials/intermediate/forced_alignment_with_torchaudio_tutorial.html 384 | """ 385 | 386 | 387 | def get_trellis(emission, tokens, blank_id=0): 388 | num_frame = emission.size(0) 389 | num_tokens = len(tokens) 390 | 391 | trellis = torch.zeros((num_frame, num_tokens)) 392 | trellis[1:, 0] = torch.cumsum(emission[1:, blank_id], 0) 393 | trellis[0, 1:] = -float("inf") 394 | trellis[-num_tokens + 1:, 0] = float("inf") 395 | 396 | for t in range(num_frame - 1): 397 | trellis[t + 1, 1:] = torch.maximum( 398 | # Score for staying at the same token 399 | trellis[t, 1:] + emission[t, blank_id], 400 | # Score for changing to the next token 401 | # trellis[t, :-1] + emission[t, tokens[1:]], 402 | 
trellis[t, :-1] + get_wildcard_emission(emission[t], tokens[1:], blank_id), 403 | ) 404 | return trellis 405 | 406 | 407 | def get_wildcard_emission(frame_emission, tokens, blank_id): 408 | """Processing token emission scores containing wildcards (vectorized version) 409 | 410 | Args: 411 | frame_emission: Emission probability vector for the current frame 412 | tokens: List of token indices 413 | blank_id: ID of the blank token 414 | 415 | Returns: 416 | tensor: Maximum probability score for each token position 417 | """ 418 | assert 0 <= blank_id < len(frame_emission) 419 | 420 | # Convert tokens to a tensor if they are not already 421 | tokens = torch.tensor(tokens) if not isinstance(tokens, torch.Tensor) else tokens 422 | 423 | # Create a mask to identify wildcard positions 424 | wildcard_mask = (tokens == -1) 425 | 426 | # Get scores for non-wildcard positions 427 | regular_scores = frame_emission[tokens.clamp(min=0)] # clamp to avoid -1 index 428 | 429 | # Create a mask and compute the maximum value without modifying frame_emission 430 | max_valid_score = frame_emission.clone() # Create a copy 431 | max_valid_score[blank_id] = float('-inf') # Modify the copy to exclude the blank token 432 | max_valid_score = max_valid_score.max() 433 | 434 | # Use where operation to combine results 435 | result = torch.where(wildcard_mask, max_valid_score, regular_scores) 436 | 437 | return result 438 | 439 | 440 | @dataclass 441 | class Point: 442 | token_index: int 443 | time_index: int 444 | score: float 445 | 446 | 447 | def backtrack(trellis, emission, tokens, blank_id=0): 448 | t, j = trellis.size(0) - 1, trellis.size(1) - 1 449 | 450 | path = [Point(j, t, emission[t, blank_id].exp().item())] 451 | while j > 0: 452 | # Should not happen but just in case 453 | assert t > 0 454 | 455 | # 1. 
Figure out if the current position was stay or change 456 | # Frame-wise score of stay vs change 457 | p_stay = emission[t - 1, blank_id] 458 | # p_change = emission[t - 1, tokens[j]] 459 | p_change = get_wildcard_emission(emission[t - 1], [tokens[j]], blank_id)[0] 460 | 461 | # Context-aware score for stay vs change 462 | stayed = trellis[t - 1, j] + p_stay 463 | changed = trellis[t - 1, j - 1] + p_change 464 | 465 | # Update position 466 | t -= 1 467 | if changed > stayed: 468 | j -= 1 469 | 470 | # Store the path with frame-wise probability. 471 | prob = (p_change if changed > stayed else p_stay).exp().item() 472 | path.append(Point(j, t, prob)) 473 | 474 | # Now j == 0, which means, it reached the SoS. 475 | # Fill up the rest for the sake of visualization 476 | while t > 0: 477 | prob = emission[t - 1, blank_id].exp().item() 478 | path.append(Point(j, t - 1, prob)) 479 | t -= 1 480 | 481 | return path[::-1] 482 | 483 | 484 | 485 | @dataclass 486 | class Path: 487 | points: List[Point] 488 | score: float 489 | 490 | 491 | @dataclass 492 | class BeamState: 493 | """State in beam search.""" 494 | token_index: int # Current token position 495 | time_index: int # Current time step 496 | score: float # Cumulative score 497 | path: List[Point] # Path history 498 | 499 | 500 | def backtrack_beam(trellis, emission, tokens, blank_id=0, beam_width=5): 501 | """Standard CTC beam search backtracking implementation. 502 | 503 | Args: 504 | trellis (torch.Tensor): The trellis (or lattice) of shape (T, N), where T is the number of time steps 505 | and N is the number of tokens (including the blank token). 506 | emission (torch.Tensor): The emission probabilities of shape (T, N). 507 | tokens (List[int]): List of token indices (excluding the blank token). 508 | blank_id (int, optional): The ID of the blank token. Defaults to 0. 509 | beam_width (int, optional): The number of top paths to keep during beam search. Defaults to 5. 
510 | 511 | Returns: 512 | List[Point]: the best path 513 | """ 514 | T, J = trellis.size(0) - 1, trellis.size(1) - 1 515 | 516 | init_state = BeamState( 517 | token_index=J, 518 | time_index=T, 519 | score=trellis[T, J], 520 | path=[Point(J, T, emission[T, blank_id].exp().item())] 521 | ) 522 | 523 | beams = [init_state] 524 | 525 | while beams and beams[0].token_index > 0: 526 | next_beams = [] 527 | 528 | for beam in beams: 529 | t, j = beam.time_index, beam.token_index 530 | 531 | if t <= 0: 532 | continue 533 | 534 | p_stay = emission[t - 1, blank_id] 535 | p_change = get_wildcard_emission(emission[t - 1], [tokens[j]], blank_id)[0] 536 | 537 | stay_score = trellis[t - 1, j] 538 | change_score = trellis[t - 1, j - 1] if j > 0 else float('-inf') 539 | 540 | # Stay 541 | if not math.isinf(stay_score): 542 | new_path = beam.path.copy() 543 | new_path.append(Point(j, t - 1, p_stay.exp().item())) 544 | next_beams.append(BeamState( 545 | token_index=j, 546 | time_index=t - 1, 547 | score=stay_score, 548 | path=new_path 549 | )) 550 | 551 | # Change 552 | if j > 0 and not math.isinf(change_score): 553 | new_path = beam.path.copy() 554 | new_path.append(Point(j - 1, t - 1, p_change.exp().item())) 555 | next_beams.append(BeamState( 556 | token_index=j - 1, 557 | time_index=t - 1, 558 | score=change_score, 559 | path=new_path 560 | )) 561 | 562 | # sort by score 563 | beams = sorted(next_beams, key=lambda x: x.score, reverse=True)[:beam_width] 564 | 565 | if not beams: 566 | break 567 | 568 | if not beams: 569 | return None 570 | 571 | best_beam = beams[0] 572 | t = best_beam.time_index 573 | j = best_beam.token_index 574 | while t > 0: 575 | prob = emission[t - 1, blank_id].exp().item() 576 | best_beam.path.append(Point(j, t - 1, prob)) 577 | t -= 1 578 | 579 | return best_beam.path[::-1] 580 | 581 | 582 | # Merge the labels 583 | @dataclass 584 | class Segment: 585 | label: str 586 | start: int 587 | end: int 588 | score: float 589 | 590 | def __repr__(self): 591 | 
return f"{self.label}\t({self.score:4.2f}): [{self.start:5d}, {self.end:5d})" 592 | 593 | @property 594 | def length(self): 595 | return self.end - self.start 596 | 597 | def merge_repeats(path, transcript): 598 | i1, i2 = 0, 0 599 | segments = [] 600 | while i1 < len(path): 601 | while i2 < len(path) and path[i1].token_index == path[i2].token_index: 602 | i2 += 1 603 | score = sum(path[k].score for k in range(i1, i2)) / (i2 - i1) 604 | segments.append( 605 | Segment( 606 | transcript[path[i1].token_index], 607 | path[i1].time_index, 608 | path[i2 - 1].time_index + 1, 609 | score, 610 | ) 611 | ) 612 | i1 = i2 613 | return segments 614 | 615 | def merge_words(segments, separator="|"): 616 | words = [] 617 | i1, i2 = 0, 0 618 | while i1 < len(segments): 619 | if i2 >= len(segments) or segments[i2].label == separator: 620 | if i1 != i2: 621 | segs = segments[i1:i2] 622 | word = "".join([seg.label for seg in segs]) 623 | score = sum(seg.score * seg.length for seg in segs) / sum(seg.length for seg in segs) 624 | words.append(Segment(word, segments[i1].start, segments[i2 - 1].end, score)) 625 | i1 = i2 + 1 626 | i2 = i1 627 | else: 628 | i2 += 1 629 | return words 630 | -------------------------------------------------------------------------------- /whisperx/asr.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import List, Optional, Union 3 | from dataclasses import replace 4 | 5 | import ctranslate2 6 | import faster_whisper 7 | import numpy as np 8 | import torch 9 | from faster_whisper.tokenizer import Tokenizer 10 | from faster_whisper.transcribe import TranscriptionOptions, get_ctranslate2_storage 11 | from transformers import Pipeline 12 | from transformers.pipelines.pt_utils import PipelineIterator 13 | 14 | from whisperx.audio import N_SAMPLES, SAMPLE_RATE, load_audio, log_mel_spectrogram 15 | from whisperx.types import SingleSegment, TranscriptionResult 16 | from whisperx.vads import Vad, 
Silero, Pyannote 17 | 18 | 19 | def find_numeral_symbol_tokens(tokenizer): 20 | numeral_symbol_tokens = [] 21 | for i in range(tokenizer.eot): 22 | token = tokenizer.decode([i]).removeprefix(" ") 23 | has_numeral_symbol = any(c in "0123456789%$£" for c in token) 24 | if has_numeral_symbol: 25 | numeral_symbol_tokens.append(i) 26 | return numeral_symbol_tokens 27 | 28 | class WhisperModel(faster_whisper.WhisperModel): 29 | ''' 30 | FasterWhisperModel provides batched inference for faster-whisper. 31 | Currently only works in non-timestamp mode and fixed prompt for all samples in batch. 32 | ''' 33 | 34 | def generate_segment_batched( 35 | self, 36 | features: np.ndarray, 37 | tokenizer: Tokenizer, 38 | options: TranscriptionOptions, 39 | encoder_output=None, 40 | ): 41 | batch_size = features.shape[0] 42 | all_tokens = [] 43 | prompt_reset_since = 0 44 | if options.initial_prompt is not None: 45 | initial_prompt = " " + options.initial_prompt.strip() 46 | initial_prompt_tokens = tokenizer.encode(initial_prompt) 47 | all_tokens.extend(initial_prompt_tokens) 48 | previous_tokens = all_tokens[prompt_reset_since:] 49 | prompt = self.get_prompt( 50 | tokenizer, 51 | previous_tokens, 52 | without_timestamps=options.without_timestamps, 53 | prefix=options.prefix, 54 | hotwords=options.hotwords 55 | ) 56 | 57 | encoder_output = self.encode(features) 58 | 59 | max_initial_timestamp_index = int( 60 | round(options.max_initial_timestamp / self.time_precision) 61 | ) 62 | 63 | result = self.model.generate( 64 | encoder_output, 65 | [prompt] * batch_size, 66 | beam_size=options.beam_size, 67 | patience=options.patience, 68 | length_penalty=options.length_penalty, 69 | max_length=self.max_length, 70 | suppress_blank=options.suppress_blank, 71 | suppress_tokens=options.suppress_tokens, 72 | ) 73 | 74 | tokens_batch = [x.sequences_ids[0] for x in result] 75 | 76 | def decode_batch(tokens: List[List[int]]) -> str: 77 | res = [] 78 | for tk in tokens: 79 | res.append([token for token 
in tk if token < tokenizer.eot]) 80 | # text_tokens = [token for token in tokens if token < self.eot] 81 | return tokenizer.tokenizer.decode_batch(res) 82 | 83 | text = decode_batch(tokens_batch) 84 | 85 | return text 86 | 87 | def encode(self, features: np.ndarray) -> ctranslate2.StorageView: 88 | # When the model is running on multiple GPUs, the encoder output should be moved 89 | # to the CPU since we don't know which GPU will handle the next job. 90 | to_cpu = self.model.device == "cuda" and len(self.model.device_index) > 1 91 | # unsqueeze if batch size = 1 92 | if len(features.shape) == 2: 93 | features = np.expand_dims(features, 0) 94 | features = get_ctranslate2_storage(features) 95 | 96 | return self.model.encode(features, to_cpu=to_cpu) 97 | 98 | class FasterWhisperPipeline(Pipeline): 99 | """ 100 | Huggingface Pipeline wrapper for FasterWhisperModel. 101 | """ 102 | # TODO: 103 | # - add support for timestamp mode 104 | # - add support for custom inference kwargs 105 | 106 | def __init__( 107 | self, 108 | model: WhisperModel, 109 | vad, 110 | vad_params: dict, 111 | options: TranscriptionOptions, 112 | tokenizer: Optional[Tokenizer] = None, 113 | device: Union[int, str, "torch.device"] = -1, 114 | framework="pt", 115 | language: Optional[str] = None, 116 | suppress_numerals: bool = False, 117 | **kwargs, 118 | ): 119 | self.model = model 120 | self.tokenizer = tokenizer 121 | self.options = options 122 | self.preset_language = language 123 | self.suppress_numerals = suppress_numerals 124 | self._batch_size = kwargs.pop("batch_size", None) 125 | self._num_workers = 1 126 | self._preprocess_params, self._forward_params, self._postprocess_params = self._sanitize_parameters(**kwargs) 127 | self.call_count = 0 128 | self.framework = framework 129 | if self.framework == "pt": 130 | if isinstance(device, torch.device): 131 | self.device = device 132 | elif isinstance(device, str): 133 | self.device = torch.device(device) 134 | elif device < 0: 135 | 
self.device = torch.device("cpu") 136 | else: 137 | self.device = torch.device(f"cuda:{device}") 138 | else: 139 | self.device = device 140 | 141 | super(Pipeline, self).__init__() 142 | self.vad_model = vad 143 | self._vad_params = vad_params 144 | 145 | def _sanitize_parameters(self, **kwargs): 146 | preprocess_kwargs = {} 147 | if "tokenizer" in kwargs: 148 | preprocess_kwargs["maybe_arg"] = kwargs["maybe_arg"] 149 | return preprocess_kwargs, {}, {} 150 | 151 | def preprocess(self, audio): 152 | audio = audio['inputs'] 153 | model_n_mels = self.model.feat_kwargs.get("feature_size") 154 | features = log_mel_spectrogram( 155 | audio, 156 | n_mels=model_n_mels if model_n_mels is not None else 80, 157 | padding=N_SAMPLES - audio.shape[0], 158 | ) 159 | return {'inputs': features} 160 | 161 | def _forward(self, model_inputs): 162 | outputs = self.model.generate_segment_batched(model_inputs['inputs'], self.tokenizer, self.options) 163 | return {'text': outputs} 164 | 165 | def postprocess(self, model_outputs): 166 | return model_outputs 167 | 168 | def get_iterator( 169 | self, 170 | inputs, 171 | num_workers: int, 172 | batch_size: int, 173 | preprocess_params: dict, 174 | forward_params: dict, 175 | postprocess_params: dict, 176 | ): 177 | dataset = PipelineIterator(inputs, self.preprocess, preprocess_params) 178 | if "TOKENIZERS_PARALLELISM" not in os.environ: 179 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 180 | # TODO hack by collating feature_extractor and image_processor 181 | 182 | def stack(items): 183 | return {'inputs': torch.stack([x['inputs'] for x in items])} 184 | dataloader = torch.utils.data.DataLoader(dataset, num_workers=num_workers, batch_size=batch_size, collate_fn=stack) 185 | model_iterator = PipelineIterator(dataloader, self.forward, forward_params, loader_batch_size=batch_size) 186 | final_iterator = PipelineIterator(model_iterator, self.postprocess, postprocess_params) 187 | return final_iterator 188 | 189 | def transcribe( 190 | self, 
191 | audio: Union[str, np.ndarray], 192 | batch_size: Optional[int] = None, 193 | num_workers=0, 194 | language: Optional[str] = None, 195 | task: Optional[str] = None, 196 | chunk_size=30, 197 | print_progress=False, 198 | combined_progress=False, 199 | verbose=False, 200 | ) -> TranscriptionResult: 201 | if isinstance(audio, str): 202 | audio = load_audio(audio) 203 | 204 | def data(audio, segments): 205 | for seg in segments: 206 | f1 = int(seg['start'] * SAMPLE_RATE) 207 | f2 = int(seg['end'] * SAMPLE_RATE) 208 | # print(f2-f1) 209 | yield {'inputs': audio[f1:f2]} 210 | 211 | # Pre-process audio and merge chunks as defined by the respective VAD child class 212 | # In case vad_model is manually assigned (see 'load_model') follow the functionality of pyannote toolkit 213 | if issubclass(type(self.vad_model), Vad): 214 | waveform = self.vad_model.preprocess_audio(audio) 215 | merge_chunks = self.vad_model.merge_chunks 216 | else: 217 | waveform = Pyannote.preprocess_audio(audio) 218 | merge_chunks = Pyannote.merge_chunks 219 | 220 | vad_segments = self.vad_model({"waveform": waveform, "sample_rate": SAMPLE_RATE}) 221 | vad_segments = merge_chunks( 222 | vad_segments, 223 | chunk_size, 224 | onset=self._vad_params["vad_onset"], 225 | offset=self._vad_params["vad_offset"], 226 | ) 227 | if self.tokenizer is None: 228 | language = language or self.detect_language(audio) 229 | task = task or "transcribe" 230 | self.tokenizer = Tokenizer( 231 | self.model.hf_tokenizer, 232 | self.model.model.is_multilingual, 233 | task=task, 234 | language=language, 235 | ) 236 | else: 237 | language = language or self.tokenizer.language_code 238 | task = task or self.tokenizer.task 239 | if task != self.tokenizer.task or language != self.tokenizer.language_code: 240 | self.tokenizer = Tokenizer( 241 | self.model.hf_tokenizer, 242 | self.model.model.is_multilingual, 243 | task=task, 244 | language=language, 245 | ) 246 | 247 | if self.suppress_numerals: 248 | previous_suppress_tokens 
= self.options.suppress_tokens 249 | numeral_symbol_tokens = find_numeral_symbol_tokens(self.tokenizer) 250 | print(f"Suppressing numeral and symbol tokens") 251 | new_suppressed_tokens = numeral_symbol_tokens + self.options.suppress_tokens 252 | new_suppressed_tokens = list(set(new_suppressed_tokens)) 253 | self.options = replace(self.options, suppress_tokens=new_suppressed_tokens) 254 | 255 | segments: List[SingleSegment] = [] 256 | batch_size = batch_size or self._batch_size 257 | total_segments = len(vad_segments) 258 | for idx, out in enumerate(self.__call__(data(audio, vad_segments), batch_size=batch_size, num_workers=num_workers)): 259 | if print_progress: 260 | base_progress = ((idx + 1) / total_segments) * 100 261 | percent_complete = base_progress / 2 if combined_progress else base_progress 262 | print(f"Progress: {percent_complete:.2f}%...") 263 | text = out['text'] 264 | if batch_size in [0, 1, None]: 265 | text = text[0] 266 | if verbose: 267 | print(f"Transcript: [{round(vad_segments[idx]['start'], 3)} --> {round(vad_segments[idx]['end'], 3)}] {text}") 268 | segments.append( 269 | { 270 | "text": text, 271 | "start": round(vad_segments[idx]['start'], 3), 272 | "end": round(vad_segments[idx]['end'], 3) 273 | } 274 | ) 275 | 276 | # revert the tokenizer if multilingual inference is enabled 277 | if self.preset_language is None: 278 | self.tokenizer = None 279 | 280 | # revert suppressed tokens if suppress_numerals is enabled 281 | if self.suppress_numerals: 282 | self.options = replace(self.options, suppress_tokens=previous_suppress_tokens) 283 | 284 | return {"segments": segments, "language": language} 285 | 286 | def detect_language(self, audio: np.ndarray) -> str: 287 | if audio.shape[0] < N_SAMPLES: 288 | print("Warning: audio is shorter than 30s, language detection may be inaccurate.") 289 | model_n_mels = self.model.feat_kwargs.get("feature_size") 290 | segment = log_mel_spectrogram(audio[: N_SAMPLES], 291 | n_mels=model_n_mels if model_n_mels is 
not None else 80, 292 | padding=0 if audio.shape[0] >= N_SAMPLES else N_SAMPLES - audio.shape[0]) 293 | encoder_output = self.model.encode(segment) 294 | results = self.model.model.detect_language(encoder_output) 295 | language_token, language_probability = results[0][0] 296 | language = language_token[2:-2] 297 | print(f"Detected language: {language} ({language_probability:.2f}) in first 30s of audio...") 298 | return language 299 | 300 | 301 | def load_model( 302 | whisper_arch: str, 303 | device: str, 304 | device_index=0, 305 | compute_type="float16", 306 | asr_options: Optional[dict] = None, 307 | language: Optional[str] = None, 308 | vad_model: Optional[Vad]= None, 309 | vad_method: Optional[str] = "pyannote", 310 | vad_options: Optional[dict] = None, 311 | model: Optional[WhisperModel] = None, 312 | task="transcribe", 313 | download_root: Optional[str] = None, 314 | local_files_only=False, 315 | threads=4, 316 | ) -> FasterWhisperPipeline: 317 | """Load a Whisper model for inference. 318 | Args: 319 | whisper_arch - The name of the Whisper model to load. 320 | device - The device to load the model on. 321 | compute_type - The compute type to use for the model. 322 | vad_method - The vad method to use. vad_model has higher priority if is not None. 323 | options - A dictionary of options to use for the model. 324 | language - The language of the model. (use English for now) 325 | model - The WhisperModel instance to use. 326 | download_root - The root directory to download the model to. 327 | local_files_only - If `True`, avoid downloading the file and return the path to the local cached file if it exists. 328 | threads - The number of cpu threads to use per worker, e.g. will be multiplied by num workers. 329 | Returns: 330 | A Whisper pipeline. 
331 | """ 332 | 333 | if whisper_arch.endswith(".en"): 334 | language = "en" 335 | 336 | model = model or WhisperModel(whisper_arch, 337 | device=device, 338 | device_index=device_index, 339 | compute_type=compute_type, 340 | download_root=download_root, 341 | local_files_only=local_files_only, 342 | cpu_threads=threads) 343 | if language is not None: 344 | tokenizer = Tokenizer(model.hf_tokenizer, model.model.is_multilingual, task=task, language=language) 345 | else: 346 | print("No language specified, language will be first be detected for each audio file (increases inference time).") 347 | tokenizer = None 348 | 349 | default_asr_options = { 350 | "beam_size": 5, 351 | "best_of": 5, 352 | "patience": 1, 353 | "length_penalty": 1, 354 | "repetition_penalty": 1, 355 | "no_repeat_ngram_size": 0, 356 | "temperatures": [0.0, 0.2, 0.4, 0.6, 0.8, 1.0], 357 | "compression_ratio_threshold": 2.4, 358 | "log_prob_threshold": -1.0, 359 | "no_speech_threshold": 0.6, 360 | "condition_on_previous_text": False, 361 | "prompt_reset_on_temperature": 0.5, 362 | "initial_prompt": None, 363 | "prefix": None, 364 | "suppress_blank": True, 365 | "suppress_tokens": [-1], 366 | "without_timestamps": True, 367 | "max_initial_timestamp": 0.0, 368 | "word_timestamps": False, 369 | "prepend_punctuations": "\"'“¿([{-", 370 | "append_punctuations": "\"'.。,,!!??::”)]}、", 371 | "multilingual": model.model.is_multilingual, 372 | "suppress_numerals": False, 373 | "max_new_tokens": None, 374 | "clip_timestamps": None, 375 | "hallucination_silence_threshold": None, 376 | "hotwords": None, 377 | } 378 | 379 | if asr_options is not None: 380 | default_asr_options.update(asr_options) 381 | 382 | suppress_numerals = default_asr_options["suppress_numerals"] 383 | del default_asr_options["suppress_numerals"] 384 | 385 | default_asr_options = TranscriptionOptions(**default_asr_options) 386 | 387 | default_vad_options = { 388 | "chunk_size": 30, # needed by silero since binarization happens before 
merge_chunks 389 | "vad_onset": 0.500, 390 | "vad_offset": 0.363 391 | } 392 | 393 | if vad_options is not None: 394 | default_vad_options.update(vad_options) 395 | 396 | # Note: manually assigned vad_model has higher priority than vad_method! 397 | if vad_model is not None: 398 | print("Use manually assigned vad_model. vad_method is ignored.") 399 | vad_model = vad_model 400 | else: 401 | if vad_method == "silero": 402 | vad_model = Silero(**default_vad_options) 403 | elif vad_method == "pyannote": 404 | vad_model = Pyannote(torch.device(device), use_auth_token=None, **default_vad_options) 405 | else: 406 | raise ValueError(f"Invalid vad_method: {vad_method}") 407 | 408 | return FasterWhisperPipeline( 409 | model=model, 410 | vad=vad_model, 411 | options=default_asr_options, 412 | tokenizer=tokenizer, 413 | language=language, 414 | suppress_numerals=suppress_numerals, 415 | vad_params=default_vad_options, 416 | ) 417 | -------------------------------------------------------------------------------- /whisperx/assets/mel_filters.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/m-bain/whisperX/b3432412530ecb0cc5ac923f161da281e41d23d2/whisperx/assets/mel_filters.npz -------------------------------------------------------------------------------- /whisperx/assets/pytorch_model.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/m-bain/whisperX/b3432412530ecb0cc5ac923f161da281e41d23d2/whisperx/assets/pytorch_model.bin -------------------------------------------------------------------------------- /whisperx/audio.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | from functools import lru_cache 4 | from typing import Optional, Union 5 | 6 | import numpy as np 7 | import torch 8 | import torch.nn.functional as F 9 | 10 | from whisperx.utils import exact_div 11 | 
12 | # hard-coded audio hyperparameters 13 | SAMPLE_RATE = 16000 14 | N_FFT = 400 15 | HOP_LENGTH = 160 16 | CHUNK_LENGTH = 30 17 | N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000 samples in a 30-second chunk 18 | N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH) # 3000 frames in a mel spectrogram input 19 | 20 | N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2 # the initial convolutions has stride 2 21 | FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH) # 10ms per audio frame 22 | TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN) # 20ms per audio token 23 | 24 | 25 | def load_audio(file: str, sr: int = SAMPLE_RATE) -> np.ndarray: 26 | """ 27 | Open an audio file and read as mono waveform, resampling as necessary 28 | 29 | Parameters 30 | ---------- 31 | file: str 32 | The audio file to open 33 | 34 | sr: int 35 | The sample rate to resample the audio if necessary 36 | 37 | Returns 38 | ------- 39 | A NumPy array containing the audio waveform, in float32 dtype. 40 | """ 41 | try: 42 | # Launches a subprocess to decode audio while down-mixing and resampling as necessary. 43 | # Requires the ffmpeg CLI to be installed. 44 | cmd = [ 45 | "ffmpeg", 46 | "-nostdin", 47 | "-threads", 48 | "0", 49 | "-i", 50 | file, 51 | "-f", 52 | "s16le", 53 | "-ac", 54 | "1", 55 | "-acodec", 56 | "pcm_s16le", 57 | "-ar", 58 | str(sr), 59 | "-", 60 | ] 61 | out = subprocess.run(cmd, capture_output=True, check=True).stdout 62 | except subprocess.CalledProcessError as e: 63 | raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e 64 | 65 | return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0 66 | 67 | 68 | def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1): 69 | """ 70 | Pad or trim the audio array to N_SAMPLES, as expected by the encoder. 
71 | """ 72 | if torch.is_tensor(array): 73 | if array.shape[axis] > length: 74 | array = array.index_select( 75 | dim=axis, index=torch.arange(length, device=array.device) 76 | ) 77 | 78 | if array.shape[axis] < length: 79 | pad_widths = [(0, 0)] * array.ndim 80 | pad_widths[axis] = (0, length - array.shape[axis]) 81 | array = F.pad(array, [pad for sizes in pad_widths[::-1] for pad in sizes]) 82 | else: 83 | if array.shape[axis] > length: 84 | array = array.take(indices=range(length), axis=axis) 85 | 86 | if array.shape[axis] < length: 87 | pad_widths = [(0, 0)] * array.ndim 88 | pad_widths[axis] = (0, length - array.shape[axis]) 89 | array = np.pad(array, pad_widths) 90 | 91 | return array 92 | 93 | 94 | @lru_cache(maxsize=None) 95 | def mel_filters(device, n_mels: int) -> torch.Tensor: 96 | """ 97 | load the mel filterbank matrix for projecting STFT into a Mel spectrogram. 98 | Allows decoupling librosa dependency; saved using: 99 | 100 | np.savez_compressed( 101 | "mel_filters.npz", 102 | mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80), 103 | ) 104 | """ 105 | assert n_mels in [80, 128], f"Unsupported n_mels: {n_mels}" 106 | with np.load( 107 | os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz") 108 | ) as f: 109 | return torch.from_numpy(f[f"mel_{n_mels}"]).to(device) 110 | 111 | 112 | def log_mel_spectrogram( 113 | audio: Union[str, np.ndarray, torch.Tensor], 114 | n_mels: int, 115 | padding: int = 0, 116 | device: Optional[Union[str, torch.device]] = None, 117 | ): 118 | """ 119 | Compute the log-Mel spectrogram of 120 | 121 | Parameters 122 | ---------- 123 | audio: Union[str, np.ndarray, torch.Tensor], shape = (*) 124 | The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz 125 | 126 | n_mels: int 127 | The number of Mel-frequency filters, only 80 is supported 128 | 129 | padding: int 130 | Number of zero samples to pad to the right 131 | 132 | device: Optional[Union[str, torch.device]] 133 | 
If given, the audio tensor is moved to this device before STFT 134 | 135 | Returns 136 | ------- 137 | torch.Tensor, shape = (80, n_frames) 138 | A Tensor that contains the Mel spectrogram 139 | """ 140 | if not torch.is_tensor(audio): 141 | if isinstance(audio, str): 142 | audio = load_audio(audio) 143 | audio = torch.from_numpy(audio) 144 | 145 | if device is not None: 146 | audio = audio.to(device) 147 | if padding > 0: 148 | audio = F.pad(audio, (0, padding)) 149 | window = torch.hann_window(N_FFT).to(audio.device) 150 | stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True) 151 | magnitudes = stft[..., :-1].abs() ** 2 152 | 153 | filters = mel_filters(audio.device, n_mels) 154 | mel_spec = filters @ magnitudes 155 | 156 | log_spec = torch.clamp(mel_spec, min=1e-10).log10() 157 | log_spec = torch.maximum(log_spec, log_spec.max() - 8.0) 158 | log_spec = (log_spec + 4.0) / 4.0 159 | return log_spec 160 | -------------------------------------------------------------------------------- /whisperx/conjunctions.py: -------------------------------------------------------------------------------- 1 | # conjunctions.py 2 | 3 | from typing import Set 4 | 5 | 6 | conjunctions_by_language = { 7 | 'en': {'and', 'whether', 'or', 'as', 'but', 'so', 'for', 'nor', 'which', 'yet', 'although', 'since', 'unless', 'when', 'while', 'because', 'if', 'how', 'that', 'than', 'who', 'where', 'what', 'near', 'before', 'after', 'across', 'through', 'until', 'once', 'whereas', 'even', 'both', 'either', 'neither', 'though'}, 8 | 'fr': {'et', 'ou', 'mais', 'parce', 'bien', 'pendant', 'quand', 'où', 'comme', 'si', 'que', 'avant', 'après', 'aussitôt', 'jusqu’à', 'à', 'malgré', 'donc', 'tant', 'puisque', 'ni', 'soit', 'bien', 'encore', 'dès', 'lorsque'}, 9 | 'de': {'und', 'oder', 'aber', 'weil', 'obwohl', 'während', 'wenn', 'wo', 'wie', 'dass', 'bevor', 'nachdem', 'sobald', 'bis', 'außer', 'trotzdem', 'also', 'sowie', 'indem', 'weder', 'sowohl', 'zwar', 'jedoch'}, 10 | 
'es': {'y', 'o', 'pero', 'porque', 'aunque', 'sin', 'mientras', 'cuando', 'donde', 'como', 'si', 'que', 'antes', 'después', 'tan', 'hasta', 'a', 'a', 'por', 'ya', 'ni', 'sino'}, 11 | 'it': {'e', 'o', 'ma', 'perché', 'anche', 'mentre', 'quando', 'dove', 'come', 'se', 'che', 'prima', 'dopo', 'appena', 'fino', 'a', 'nonostante', 'quindi', 'poiché', 'né', 'ossia', 'cioè'}, 12 | 'ja': {'そして', 'または', 'しかし', 'なぜなら', 'もし', 'それとも', 'だから', 'それに', 'なのに', 'そのため', 'かつ', 'それゆえに', 'ならば', 'もしくは', 'ため'}, 13 | 'zh': {'和', '或', '但是', '因为', '任何', '也', '虽然', '而且', '所以', '如果', '除非', '尽管', '既然', '即使', '只要', '直到', '然后', '因此', '不但', '而是', '不过'}, 14 | 'nl': {'en', 'of', 'maar', 'omdat', 'hoewel', 'terwijl', 'wanneer', 'waar', 'zoals', 'als', 'dat', 'voordat', 'nadat', 'zodra', 'totdat', 'tenzij', 'ondanks', 'dus', 'zowel', 'noch', 'echter', 'toch'}, 15 | 'uk': {'та', 'або', 'але', 'тому', 'хоча', 'поки', 'бо', 'коли', 'де', 'як', 'якщо', 'що', 'перш', 'після', 'доки', 'незважаючи', 'тому', 'ані'}, 16 | 'pt': {'e', 'ou', 'mas', 'porque', 'embora', 'enquanto', 'quando', 'onde', 'como', 'se', 'que', 'antes', 'depois', 'assim', 'até', 'a', 'apesar', 'portanto', 'já', 'pois', 'nem', 'senão'}, 17 | 'ar': {'و', 'أو', 'لكن', 'لأن', 'مع', 'بينما', 'عندما', 'حيث', 'كما', 'إذا', 'الذي', 'قبل', 'بعد', 'فور', 'حتى', 'إلا', 'رغم', 'لذلك', 'بما'}, 18 | 'cs': {'a', 'nebo', 'ale', 'protože', 'ačkoli', 'zatímco', 'když', 'kde', 'jako', 'pokud', 'že', 'než', 'poté', 'jakmile', 'dokud', 'pokud ne', 'navzdory', 'tak', 'stejně', 'ani', 'tudíž'}, 19 | 'ru': {'и', 'или', 'но', 'потому', 'хотя', 'пока', 'когда', 'где', 'как', 'если', 'что', 'перед', 'после', 'несмотря', 'таким', 'также', 'ни', 'зато'}, 20 | 'pl': {'i', 'lub', 'ale', 'ponieważ', 'chociaż', 'podczas', 'kiedy', 'gdzie', 'jak', 'jeśli', 'że', 'zanim', 'po', 'jak tylko', 'dopóki', 'chyba', 'pomimo', 'więc', 'tak', 'ani', 'czyli'}, 21 | 'hu': {'és', 'vagy', 'de', 'mert', 'habár', 'míg', 'amikor', 'ahol', 'ahogy', 'ha', 'hogy', 'mielőtt', 'miután', 
'amint', 'amíg', 'hacsak', 'ellenére', 'tehát', 'úgy', 'sem', 'vagyis'}, 22 | 'fi': {'ja', 'tai', 'mutta', 'koska', 'vaikka', 'kun', 'missä', 'kuten', 'jos', 'että', 'ennen', 'sen jälkeen', 'heti', 'kunnes', 'ellei', 'huolimatta', 'siis', 'sekä', 'eikä', 'vaan'}, 23 | 'fa': {'و', 'یا', 'اما', 'چون', 'اگرچه', 'در حالی', 'وقتی', 'کجا', 'چگونه', 'اگر', 'که', 'قبل', 'پس', 'به محض', 'تا زمانی', 'مگر', 'با وجود', 'پس', 'همچنین', 'نه'}, 24 | 'el': {'και', 'ή', 'αλλά', 'επειδή', 'αν', 'ενώ', 'όταν', 'όπου', 'όπως', 'αν', 'που', 'προτού', 'αφού', 'μόλις', 'μέχρι', 'εκτός', 'παρά', 'έτσι', 'όπως', 'ούτε', 'δηλαδή'}, 25 | 'tr': {'ve', 'veya', 'ama', 'çünkü', 'her ne', 'iken', 'nerede', 'nasıl', 'eğer', 'ki', 'önce', 'sonra', 'hemen', 'kadar', 'rağmen', 'hem', 'ne', 'yani'}, 26 | 'da': {'og', 'eller', 'men', 'fordi', 'selvom', 'mens', 'når', 'hvor', 'som', 'hvis', 'at', 'før', 'efter', 'indtil', 'medmindre', 'således', 'ligesom', 'hverken', 'altså'}, 27 | 'he': {'ו', 'או', 'אבל', 'כי', 'אף', 'בזמן', 'כאשר', 'היכן', 'כיצד', 'אם', 'ש', 'לפני', 'אחרי', 'ברגע', 'עד', 'אלא', 'למרות', 'לכן', 'כמו', 'לא', 'אז'}, 28 | 'vi': {'và', 'hoặc', 'nhưng', 'bởi', 'mặc', 'trong', 'khi', 'ở', 'như', 'nếu', 'rằng', 'trước', 'sau', 'ngay', 'cho', 'trừ', 'mặc', 'vì', 'giống', 'cũng', 'tức'}, 29 | 'ko': {'그리고', '또는','그런데','그래도', '이나', '결국', '마지막으로', '마찬가지로', '반면에', '아니면', '거나', '또는', '그럼에도', '그렇기', '때문에', '덧붙이자면', '게다가', '그러나', '고', '그래서', '랑', '한다면', '하지만', '무엇', '왜냐하면', '비록', '동안', '언제', '어디서', '어떻게', '만약', '그', '전에', '후에', '즉시', '까지', '아니라면', '불구하고', '따라서', '같은', '도'}, 30 | 'ur': {'اور', 'یا', 'مگر', 'کیونکہ', 'اگرچہ', 'جبکہ', 'جب', 'کہاں', 'کس طرح', 'اگر', 'کہ', 'سے پہلے', 'کے بعد', 'جیسے ہی', 'تک', 'اگر نہیں تو', 'کے باوجود', 'اس لئے', 'جیسے', 'نہ'}, 31 | 'hi': {'और', 'या', 'पर', 'तो', 'न', 'फिर', 'हालांकि', 'चूंकि', 'अगर', 'कैसे', 'वह', 'से', 'जो', 'जहां', 'क्या', 'नजदीक', 'पहले', 'बाद', 'के', 'पार', 'माध्यम', 'तक', 'एक', 'जबकि', 'यहां', 'तक', 'दोनों', 'या', 'न', 'हालांकि'} 32 | 33 | } 34 | 35 
| commas_by_language = { 36 | 'ja': '、', 37 | 'zh': ',', 38 | 'fa': '،', 39 | 'ur': '،' 40 | } 41 | 42 | def get_conjunctions(lang_code: str) -> Set[str]: 43 | return conjunctions_by_language.get(lang_code, set()) 44 | 45 | 46 | def get_comma(lang_code: str) -> str: 47 | return commas_by_language.get(lang_code, ",") 48 | -------------------------------------------------------------------------------- /whisperx/diarize.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | from pyannote.audio import Pipeline 4 | from typing import Optional, Union 5 | import torch 6 | 7 | from whisperx.audio import load_audio, SAMPLE_RATE 8 | from whisperx.types import TranscriptionResult, AlignedTranscriptionResult 9 | 10 | 11 | class DiarizationPipeline: 12 | def __init__( 13 | self, 14 | model_name=None, 15 | use_auth_token=None, 16 | device: Optional[Union[str, torch.device]] = "cpu", 17 | ): 18 | if isinstance(device, str): 19 | device = torch.device(device) 20 | model_config = model_name or "pyannote/speaker-diarization-3.1" 21 | self.model = Pipeline.from_pretrained(model_config, use_auth_token=use_auth_token).to(device) 22 | 23 | def __call__( 24 | self, 25 | audio: Union[str, np.ndarray], 26 | num_speakers: Optional[int] = None, 27 | min_speakers: Optional[int] = None, 28 | max_speakers: Optional[int] = None, 29 | ): 30 | if isinstance(audio, str): 31 | audio = load_audio(audio) 32 | audio_data = { 33 | 'waveform': torch.from_numpy(audio[None, :]), 34 | 'sample_rate': SAMPLE_RATE 35 | } 36 | segments = self.model(audio_data, num_speakers = num_speakers, min_speakers=min_speakers, max_speakers=max_speakers) 37 | diarize_df = pd.DataFrame(segments.itertracks(yield_label=True), columns=['segment', 'label', 'speaker']) 38 | diarize_df['start'] = diarize_df['segment'].apply(lambda x: x.start) 39 | diarize_df['end'] = diarize_df['segment'].apply(lambda x: x.end) 40 | return diarize_df 41 | 42 | 43 | 
def assign_word_speakers(
    diarize_df: pd.DataFrame,
    transcript_result: "Union[AlignedTranscriptionResult, TranscriptionResult]",
    fill_nearest=False,
) -> dict:
    """Attach a ``speaker`` label to each transcript segment and timed word.

    For every segment, and every word that carries a ``start`` key, the overlap
    with each diarization turn is ``min(turn.end, x.end) - max(turn.start, x.start)``.
    Overlaps are summed per speaker and the speaker with the largest total is
    assigned.

    Args:
        diarize_df: Diarization turns with ``start``, ``end`` and ``speaker``
            columns (as returned by ``DiarizationPipeline``). It is NOT modified.
        transcript_result: Result whose ``segments`` (and their ``words``) are
            updated in place.
        fill_nearest: If False, only turns with a strictly positive overlap are
            considered, so items with no overlapping turn get no ``speaker``
            key. If True, every turn is considered, including non-overlapping
            ones (whose overlap is negative).

    Returns:
        The same ``transcript_result`` object, mutated in place.
    """

    def _dominant_speaker(start, end):
        """Return the speaker with the largest summed overlap with [start, end], or None."""
        intersection = np.minimum(diarize_df["end"], end) - np.maximum(diarize_df["start"], start)
        # Work on a copy: previously scratch 'intersection'/'union' columns were
        # written into the caller's DataFrame on every call ('union' was unused).
        overlaps = diarize_df[["speaker"]].assign(intersection=intersection)
        if not fill_nearest:
            # drop turns that do not overlap the span at all
            overlaps = overlaps[overlaps["intersection"] > 0]
        if len(overlaps) == 0:
            return None
        totals = overlaps.groupby("speaker")["intersection"].sum()
        return totals.sort_values(ascending=False).index[0]

    for seg in transcript_result["segments"]:
        speaker = _dominant_speaker(seg["start"], seg["end"])
        if speaker is not None:
            seg["speaker"] = speaker

        for word in seg.get("words", []):
            # words the aligner could not time carry no 'start'; leave them unlabeled
            if "start" not in word:
                continue
            speaker = _dominant_speaker(word["start"], word["end"])
            if speaker is not None:
                word["speaker"] = speaker

    return transcript_result


class Segment:
    """A time span (seconds) with an optional speaker label."""

    def __init__(self, start: float, end: float, speaker: Optional[str] = None):
        self.start = start
        self.end = end
        self.speaker = speaker

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}(start={self.start!r}, "
            f"end={self.end!r}, speaker={self.speaker!r})"
        )


# --------------------------------------------------------------------------------
# /whisperx/transcribe.py:
-------------------------------------------------------------------------------- 1 | import argparse 2 | import gc 3 | import os 4 | import warnings 5 | 6 | import numpy as np 7 | import torch 8 | 9 | from whisperx.alignment import align, load_align_model 10 | from whisperx.asr import load_model 11 | from whisperx.audio import load_audio 12 | from whisperx.diarize import DiarizationPipeline, assign_word_speakers 13 | from whisperx.types import AlignedTranscriptionResult, TranscriptionResult 14 | from whisperx.utils import LANGUAGES, TO_LANGUAGE_CODE, get_writer 15 | 16 | 17 | def transcribe_task(args: dict, parser: argparse.ArgumentParser): 18 | """Transcription task to be called from CLI. 19 | 20 | Args: 21 | args: Dictionary of command-line arguments. 22 | parser: argparse.ArgumentParser object. 23 | """ 24 | # fmt: off 25 | 26 | model_name: str = args.pop("model") 27 | batch_size: int = args.pop("batch_size") 28 | model_dir: str = args.pop("model_dir") 29 | model_cache_only: bool = args.pop("model_cache_only") 30 | output_dir: str = args.pop("output_dir") 31 | output_format: str = args.pop("output_format") 32 | device: str = args.pop("device") 33 | device_index: int = args.pop("device_index") 34 | compute_type: str = args.pop("compute_type") 35 | verbose: bool = args.pop("verbose") 36 | 37 | # model_flush: bool = args.pop("model_flush") 38 | os.makedirs(output_dir, exist_ok=True) 39 | 40 | align_model: str = args.pop("align_model") 41 | interpolate_method: str = args.pop("interpolate_method") 42 | no_align: bool = args.pop("no_align") 43 | task: str = args.pop("task") 44 | if task == "translate": 45 | # translation cannot be aligned 46 | no_align = True 47 | 48 | return_char_alignments: bool = args.pop("return_char_alignments") 49 | 50 | hf_token: str = args.pop("hf_token") 51 | vad_method: str = args.pop("vad_method") 52 | vad_onset: float = args.pop("vad_onset") 53 | vad_offset: float = args.pop("vad_offset") 54 | 55 | chunk_size: int = args.pop("chunk_size") 
56 | 57 | diarize: bool = args.pop("diarize") 58 | min_speakers: int = args.pop("min_speakers") 59 | max_speakers: int = args.pop("max_speakers") 60 | diarize_model_name: str = args.pop("diarize_model") 61 | print_progress: bool = args.pop("print_progress") 62 | 63 | if args["language"] is not None: 64 | args["language"] = args["language"].lower() 65 | if args["language"] not in LANGUAGES: 66 | if args["language"] in TO_LANGUAGE_CODE: 67 | args["language"] = TO_LANGUAGE_CODE[args["language"]] 68 | else: 69 | raise ValueError(f"Unsupported language: {args['language']}") 70 | 71 | if model_name.endswith(".en") and args["language"] != "en": 72 | if args["language"] is not None: 73 | warnings.warn( 74 | f"{model_name} is an English-only model but received '{args['language']}'; using English instead." 75 | ) 76 | args["language"] = "en" 77 | align_language = ( 78 | args["language"] if args["language"] is not None else "en" 79 | ) # default to loading english if not specified 80 | 81 | temperature = args.pop("temperature") 82 | if (increment := args.pop("temperature_increment_on_fallback")) is not None: 83 | temperature = tuple(np.arange(temperature, 1.0 + 1e-6, increment)) 84 | else: 85 | temperature = [temperature] 86 | 87 | faster_whisper_threads = 4 88 | if (threads := args.pop("threads")) > 0: 89 | torch.set_num_threads(threads) 90 | faster_whisper_threads = threads 91 | 92 | asr_options = { 93 | "beam_size": args.pop("beam_size"), 94 | "patience": args.pop("patience"), 95 | "length_penalty": args.pop("length_penalty"), 96 | "temperatures": temperature, 97 | "compression_ratio_threshold": args.pop("compression_ratio_threshold"), 98 | "log_prob_threshold": args.pop("logprob_threshold"), 99 | "no_speech_threshold": args.pop("no_speech_threshold"), 100 | "condition_on_previous_text": False, 101 | "initial_prompt": args.pop("initial_prompt"), 102 | "suppress_tokens": [int(x) for x in args.pop("suppress_tokens").split(",")], 103 | "suppress_numerals": 
args.pop("suppress_numerals"), 104 | } 105 | 106 | writer = get_writer(output_format, output_dir) 107 | word_options = ["highlight_words", "max_line_count", "max_line_width"] 108 | if no_align: 109 | for option in word_options: 110 | if args[option]: 111 | parser.error(f"--{option} not possible with --no_align") 112 | if args["max_line_count"] and not args["max_line_width"]: 113 | warnings.warn("--max_line_count has no effect without --max_line_width") 114 | writer_args = {arg: args.pop(arg) for arg in word_options} 115 | 116 | # Part 1: VAD & ASR Loop 117 | results = [] 118 | tmp_results = [] 119 | # model = load_model(model_name, device=device, download_root=model_dir) 120 | model = load_model( 121 | model_name, 122 | device=device, 123 | device_index=device_index, 124 | download_root=model_dir, 125 | compute_type=compute_type, 126 | language=args["language"], 127 | asr_options=asr_options, 128 | vad_method=vad_method, 129 | vad_options={ 130 | "chunk_size": chunk_size, 131 | "vad_onset": vad_onset, 132 | "vad_offset": vad_offset, 133 | }, 134 | task=task, 135 | local_files_only=model_cache_only, 136 | threads=faster_whisper_threads, 137 | ) 138 | 139 | for audio_path in args.pop("audio"): 140 | audio = load_audio(audio_path) 141 | # >> VAD & ASR 142 | print(">>Performing transcription...") 143 | result: TranscriptionResult = model.transcribe( 144 | audio, 145 | batch_size=batch_size, 146 | chunk_size=chunk_size, 147 | print_progress=print_progress, 148 | verbose=verbose, 149 | ) 150 | results.append((result, audio_path)) 151 | 152 | # Unload Whisper and VAD 153 | del model 154 | gc.collect() 155 | torch.cuda.empty_cache() 156 | 157 | # Part 2: Align Loop 158 | if not no_align: 159 | tmp_results = results 160 | results = [] 161 | align_model, align_metadata = load_align_model( 162 | align_language, device, model_name=align_model 163 | ) 164 | for result, audio_path in tmp_results: 165 | # >> Align 166 | if len(tmp_results) > 1: 167 | input_audio = audio_path 168 
| else: 169 | # lazily load audio from part 1 170 | input_audio = audio 171 | 172 | if align_model is not None and len(result["segments"]) > 0: 173 | if result.get("language", "en") != align_metadata["language"]: 174 | # load new language 175 | print( 176 | f"New language found ({result['language']})! Previous was ({align_metadata['language']}), loading new alignment model for new language..." 177 | ) 178 | align_model, align_metadata = load_align_model( 179 | result["language"], device 180 | ) 181 | print(">>Performing alignment...") 182 | result: AlignedTranscriptionResult = align( 183 | result["segments"], 184 | align_model, 185 | align_metadata, 186 | input_audio, 187 | device, 188 | interpolate_method=interpolate_method, 189 | return_char_alignments=return_char_alignments, 190 | print_progress=print_progress, 191 | ) 192 | 193 | results.append((result, audio_path)) 194 | 195 | # Unload align model 196 | del align_model 197 | gc.collect() 198 | torch.cuda.empty_cache() 199 | 200 | # >> Diarize 201 | if diarize: 202 | if hf_token is None: 203 | print( 204 | "Warning, no --hf_token used, needs to be saved in environment variable, otherwise will throw error loading diarization model..." 
205 | ) 206 | tmp_results = results 207 | print(">>Performing diarization...") 208 | print(">>Using model:", diarize_model_name) 209 | results = [] 210 | diarize_model = DiarizationPipeline(model_name=diarize_model_name, use_auth_token=hf_token, device=device) 211 | for result, input_audio_path in tmp_results: 212 | diarize_segments = diarize_model( 213 | input_audio_path, min_speakers=min_speakers, max_speakers=max_speakers 214 | ) 215 | result = assign_word_speakers(diarize_segments, result) 216 | results.append((result, input_audio_path)) 217 | # >> Write 218 | for result, audio_path in results: 219 | result["language"] = align_language 220 | writer(result, audio_path, writer_args) 221 | -------------------------------------------------------------------------------- /whisperx/types.py: -------------------------------------------------------------------------------- 1 | from typing import TypedDict, Optional, List, Tuple 2 | 3 | 4 | class SingleWordSegment(TypedDict): 5 | """ 6 | A single word of a speech. 7 | """ 8 | word: str 9 | start: float 10 | end: float 11 | score: float 12 | 13 | class SingleCharSegment(TypedDict): 14 | """ 15 | A single char of a speech. 16 | """ 17 | char: str 18 | start: float 19 | end: float 20 | score: float 21 | 22 | 23 | class SingleSegment(TypedDict): 24 | """ 25 | A single segment (up to multiple sentences) of a speech. 26 | """ 27 | 28 | start: float 29 | end: float 30 | text: str 31 | 32 | 33 | class SegmentData(TypedDict): 34 | """ 35 | Temporary processing data used during alignment. 36 | Contains cleaned and preprocessed data for each segment. 
37 | """ 38 | clean_char: List[str] # Cleaned characters that exist in model dictionary 39 | clean_cdx: List[int] # Original indices of cleaned characters 40 | clean_wdx: List[int] # Indices of words containing valid characters 41 | sentence_spans: List[Tuple[int, int]] # Start and end indices of sentences 42 | 43 | 44 | class SingleAlignedSegment(TypedDict): 45 | """ 46 | A single segment (up to multiple sentences) of a speech with word alignment. 47 | """ 48 | 49 | start: float 50 | end: float 51 | text: str 52 | words: List[SingleWordSegment] 53 | chars: Optional[List[SingleCharSegment]] 54 | 55 | 56 | class TranscriptionResult(TypedDict): 57 | """ 58 | A list of segments and word segments of a speech. 59 | """ 60 | segments: List[SingleSegment] 61 | language: str 62 | 63 | 64 | class AlignedTranscriptionResult(TypedDict): 65 | """ 66 | A list of segments and word segments of a speech. 67 | """ 68 | segments: List[SingleAlignedSegment] 69 | word_segments: List[SingleWordSegment] 70 | -------------------------------------------------------------------------------- /whisperx/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import re 4 | import sys 5 | import zlib 6 | from typing import Callable, Optional, TextIO 7 | 8 | LANGUAGES = { 9 | "en": "english", 10 | "zh": "chinese", 11 | "de": "german", 12 | "es": "spanish", 13 | "ru": "russian", 14 | "ko": "korean", 15 | "fr": "french", 16 | "ja": "japanese", 17 | "pt": "portuguese", 18 | "tr": "turkish", 19 | "pl": "polish", 20 | "ca": "catalan", 21 | "nl": "dutch", 22 | "ar": "arabic", 23 | "sv": "swedish", 24 | "it": "italian", 25 | "id": "indonesian", 26 | "hi": "hindi", 27 | "fi": "finnish", 28 | "vi": "vietnamese", 29 | "he": "hebrew", 30 | "uk": "ukrainian", 31 | "el": "greek", 32 | "ms": "malay", 33 | "cs": "czech", 34 | "ro": "romanian", 35 | "da": "danish", 36 | "hu": "hungarian", 37 | "ta": "tamil", 38 | "no": "norwegian", 39 | "th": 
"thai", 40 | "ur": "urdu", 41 | "hr": "croatian", 42 | "bg": "bulgarian", 43 | "lt": "lithuanian", 44 | "la": "latin", 45 | "mi": "maori", 46 | "ml": "malayalam", 47 | "cy": "welsh", 48 | "sk": "slovak", 49 | "te": "telugu", 50 | "fa": "persian", 51 | "lv": "latvian", 52 | "bn": "bengali", 53 | "sr": "serbian", 54 | "az": "azerbaijani", 55 | "sl": "slovenian", 56 | "kn": "kannada", 57 | "et": "estonian", 58 | "mk": "macedonian", 59 | "br": "breton", 60 | "eu": "basque", 61 | "is": "icelandic", 62 | "hy": "armenian", 63 | "ne": "nepali", 64 | "mn": "mongolian", 65 | "bs": "bosnian", 66 | "kk": "kazakh", 67 | "sq": "albanian", 68 | "sw": "swahili", 69 | "gl": "galician", 70 | "mr": "marathi", 71 | "pa": "punjabi", 72 | "si": "sinhala", 73 | "km": "khmer", 74 | "sn": "shona", 75 | "yo": "yoruba", 76 | "so": "somali", 77 | "af": "afrikaans", 78 | "oc": "occitan", 79 | "ka": "georgian", 80 | "be": "belarusian", 81 | "tg": "tajik", 82 | "sd": "sindhi", 83 | "gu": "gujarati", 84 | "am": "amharic", 85 | "yi": "yiddish", 86 | "lo": "lao", 87 | "uz": "uzbek", 88 | "fo": "faroese", 89 | "ht": "haitian creole", 90 | "ps": "pashto", 91 | "tk": "turkmen", 92 | "nn": "nynorsk", 93 | "mt": "maltese", 94 | "sa": "sanskrit", 95 | "lb": "luxembourgish", 96 | "my": "myanmar", 97 | "bo": "tibetan", 98 | "tl": "tagalog", 99 | "mg": "malagasy", 100 | "as": "assamese", 101 | "tt": "tatar", 102 | "haw": "hawaiian", 103 | "ln": "lingala", 104 | "ha": "hausa", 105 | "ba": "bashkir", 106 | "jw": "javanese", 107 | "su": "sundanese", 108 | "yue": "cantonese", 109 | } 110 | 111 | # language code lookup by name, with a few language aliases 112 | TO_LANGUAGE_CODE = { 113 | **{language: code for code, language in LANGUAGES.items()}, 114 | "burmese": "my", 115 | "valencian": "ca", 116 | "flemish": "nl", 117 | "haitian": "ht", 118 | "letzeburgesch": "lb", 119 | "pushto": "ps", 120 | "panjabi": "pa", 121 | "moldavian": "ro", 122 | "moldovan": "ro", 123 | "sinhalese": "si", 124 | "castilian": "es", 125 
| } 126 | 127 | LANGUAGES_WITHOUT_SPACES = ["ja", "zh"] 128 | 129 | system_encoding = sys.getdefaultencoding() 130 | 131 | if system_encoding != "utf-8": 132 | 133 | def make_safe(string): 134 | # replaces any character not representable using the system default encoding with an '?', 135 | # avoiding UnicodeEncodeError (https://github.com/openai/whisper/discussions/729). 136 | return string.encode(system_encoding, errors="replace").decode(system_encoding) 137 | 138 | else: 139 | 140 | def make_safe(string): 141 | # utf-8 can encode any Unicode code point, so no need to do the round-trip encoding 142 | return string 143 | 144 | 145 | def exact_div(x, y): 146 | assert x % y == 0 147 | return x // y 148 | 149 | 150 | def str2bool(string): 151 | str2val = {"True": True, "False": False} 152 | if string in str2val: 153 | return str2val[string] 154 | else: 155 | raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}") 156 | 157 | 158 | def optional_int(string): 159 | return None if string == "None" else int(string) 160 | 161 | 162 | def optional_float(string): 163 | return None if string == "None" else float(string) 164 | 165 | 166 | def compression_ratio(text) -> float: 167 | text_bytes = text.encode("utf-8") 168 | return len(text_bytes) / len(zlib.compress(text_bytes)) 169 | 170 | 171 | def format_timestamp( 172 | seconds: float, always_include_hours: bool = False, decimal_marker: str = "." 
173 | ): 174 | assert seconds >= 0, "non-negative timestamp expected" 175 | milliseconds = round(seconds * 1000.0) 176 | 177 | hours = milliseconds // 3_600_000 178 | milliseconds -= hours * 3_600_000 179 | 180 | minutes = milliseconds // 60_000 181 | milliseconds -= minutes * 60_000 182 | 183 | seconds = milliseconds // 1_000 184 | milliseconds -= seconds * 1_000 185 | 186 | hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else "" 187 | return ( 188 | f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}" 189 | ) 190 | 191 | 192 | class ResultWriter: 193 | extension: str 194 | 195 | def __init__(self, output_dir: str): 196 | self.output_dir = output_dir 197 | 198 | def __call__(self, result: dict, audio_path: str, options: dict): 199 | audio_basename = os.path.basename(audio_path) 200 | audio_basename = os.path.splitext(audio_basename)[0] 201 | output_path = os.path.join( 202 | self.output_dir, audio_basename + "." + self.extension 203 | ) 204 | 205 | with open(output_path, "w", encoding="utf-8") as f: 206 | self.write_result(result, file=f, options=options) 207 | 208 | def write_result(self, result: dict, file: TextIO, options: dict): 209 | raise NotImplementedError 210 | 211 | 212 | class WriteTXT(ResultWriter): 213 | extension: str = "txt" 214 | 215 | def write_result(self, result: dict, file: TextIO, options: dict): 216 | for segment in result["segments"]: 217 | speaker = segment.get("speaker") 218 | text = segment["text"].strip() 219 | if speaker is not None: 220 | print(f"[{speaker}]: {text}", file=file, flush=True) 221 | else: 222 | print(text, file=file, flush=True) 223 | 224 | 225 | class SubtitlesWriter(ResultWriter): 226 | always_include_hours: bool 227 | decimal_marker: str 228 | 229 | def iterate_result(self, result: dict, options: dict): 230 | raw_max_line_width: Optional[int] = options["max_line_width"] 231 | max_line_count: Optional[int] = options["max_line_count"] 232 | highlight_words: bool = 
options["highlight_words"] 233 | max_line_width = 1000 if raw_max_line_width is None else raw_max_line_width 234 | preserve_segments = max_line_count is None or raw_max_line_width is None 235 | 236 | if len(result["segments"]) == 0: 237 | return 238 | 239 | def iterate_subtitles(): 240 | line_len = 0 241 | line_count = 1 242 | # the next subtitle to yield (a list of word timings with whitespace) 243 | subtitle: list[dict] = [] 244 | times: list[tuple] = [] 245 | last = result["segments"][0]["start"] 246 | for segment in result["segments"]: 247 | for i, original_timing in enumerate(segment["words"]): 248 | timing = original_timing.copy() 249 | long_pause = not preserve_segments 250 | if "start" in timing: 251 | long_pause = long_pause and timing["start"] - last > 3.0 252 | else: 253 | long_pause = False 254 | has_room = line_len + len(timing["word"]) <= max_line_width 255 | seg_break = i == 0 and len(subtitle) > 0 and preserve_segments 256 | if line_len > 0 and has_room and not long_pause and not seg_break: 257 | # line continuation 258 | line_len += len(timing["word"]) 259 | else: 260 | # new line 261 | timing["word"] = timing["word"].strip() 262 | if ( 263 | len(subtitle) > 0 264 | and max_line_count is not None 265 | and (long_pause or line_count >= max_line_count) 266 | or seg_break 267 | ): 268 | # subtitle break 269 | yield subtitle, times 270 | subtitle = [] 271 | times = [] 272 | line_count = 1 273 | elif line_len > 0: 274 | # line break 275 | line_count += 1 276 | timing["word"] = "\n" + timing["word"] 277 | line_len = len(timing["word"].strip()) 278 | subtitle.append(timing) 279 | times.append((segment["start"], segment["end"], segment.get("speaker"))) 280 | if "start" in timing: 281 | last = timing["start"] 282 | if len(subtitle) > 0: 283 | yield subtitle, times 284 | 285 | if "words" in result["segments"][0]: 286 | for subtitle, _ in iterate_subtitles(): 287 | sstart, ssend, speaker = _[0] 288 | subtitle_start = self.format_timestamp(sstart) 289 | 
subtitle_end = self.format_timestamp(ssend) 290 | if result["language"] in LANGUAGES_WITHOUT_SPACES: 291 | subtitle_text = "".join([word["word"] for word in subtitle]) 292 | else: 293 | subtitle_text = " ".join([word["word"] for word in subtitle]) 294 | has_timing = any(["start" in word for word in subtitle]) 295 | 296 | # add [$SPEAKER_ID]: to each subtitle if speaker is available 297 | prefix = "" 298 | if speaker is not None: 299 | prefix = f"[{speaker}]: " 300 | 301 | if highlight_words and has_timing: 302 | last = subtitle_start 303 | all_words = [timing["word"] for timing in subtitle] 304 | for i, this_word in enumerate(subtitle): 305 | if "start" in this_word: 306 | start = self.format_timestamp(this_word["start"]) 307 | end = self.format_timestamp(this_word["end"]) 308 | if last != start: 309 | yield last, start, prefix + subtitle_text 310 | 311 | yield start, end, prefix + " ".join( 312 | [ 313 | re.sub(r"^(\s*)(.*)$", r"\1\2", word) 314 | if j == i 315 | else word 316 | for j, word in enumerate(all_words) 317 | ] 318 | ) 319 | last = end 320 | else: 321 | yield subtitle_start, subtitle_end, prefix + subtitle_text 322 | else: 323 | for segment in result["segments"]: 324 | segment_start = self.format_timestamp(segment["start"]) 325 | segment_end = self.format_timestamp(segment["end"]) 326 | segment_text = segment["text"].strip().replace("-->", "->") 327 | if "speaker" in segment: 328 | segment_text = f"[{segment['speaker']}]: {segment_text}" 329 | yield segment_start, segment_end, segment_text 330 | 331 | def format_timestamp(self, seconds: float): 332 | return format_timestamp( 333 | seconds=seconds, 334 | always_include_hours=self.always_include_hours, 335 | decimal_marker=self.decimal_marker, 336 | ) 337 | 338 | 339 | class WriteVTT(SubtitlesWriter): 340 | extension: str = "vtt" 341 | always_include_hours: bool = False 342 | decimal_marker: str = "." 
343 | 344 | def write_result(self, result: dict, file: TextIO, options: dict): 345 | print("WEBVTT\n", file=file) 346 | for start, end, text in self.iterate_result(result, options): 347 | print(f"{start} --> {end}\n{text}\n", file=file, flush=True) 348 | 349 | 350 | class WriteSRT(SubtitlesWriter): 351 | extension: str = "srt" 352 | always_include_hours: bool = True 353 | decimal_marker: str = "," 354 | 355 | def write_result(self, result: dict, file: TextIO, options: dict): 356 | for i, (start, end, text) in enumerate( 357 | self.iterate_result(result, options), start=1 358 | ): 359 | print(f"{i}\n{start} --> {end}\n{text}\n", file=file, flush=True) 360 | 361 | 362 | class WriteTSV(ResultWriter): 363 | """ 364 | Write a transcript to a file in TSV (tab-separated values) format containing lines like: 365 | \t\t 366 | 367 | Using integer milliseconds as start and end times means there's no chance of interference from 368 | an environment setting a language encoding that causes the decimal in a floating point number 369 | to appear as a comma; also is faster and more efficient to parse & store, e.g., in C++. 370 | """ 371 | 372 | extension: str = "tsv" 373 | 374 | def write_result(self, result: dict, file: TextIO, options: dict): 375 | print("start", "end", "text", sep="\t", file=file) 376 | for segment in result["segments"]: 377 | print(round(1000 * segment["start"]), file=file, end="\t") 378 | print(round(1000 * segment["end"]), file=file, end="\t") 379 | print(segment["text"].strip().replace("\t", " "), file=file, flush=True) 380 | 381 | class WriteAudacity(ResultWriter): 382 | """ 383 | Write a transcript to a text file that audacity can import as labels. 384 | The extension used is "aud" to distinguish it from the txt file produced by WriteTXT. 385 | Yet this is not an audacity project but only a label file! 386 | 387 | Please note : Audacity uses seconds in timestamps not ms! 388 | Also there is no header expected. 
389 | 390 | If speaker is provided it is prepended to the text between double square brackets [[]]. 391 | """ 392 | 393 | extension: str = "aud" 394 | 395 | def write_result(self, result: dict, file: TextIO, options: dict): 396 | ARROW = " " 397 | for segment in result["segments"]: 398 | print(segment["start"], file=file, end=ARROW) 399 | print(segment["end"], file=file, end=ARROW) 400 | print( ( ("[[" + segment["speaker"] + "]]") if "speaker" in segment else "") + segment["text"].strip().replace("\t", " "), file=file, flush=True) 401 | 402 | 403 | 404 | class WriteJSON(ResultWriter): 405 | extension: str = "json" 406 | 407 | def write_result(self, result: dict, file: TextIO, options: dict): 408 | json.dump(result, file, ensure_ascii=False) 409 | 410 | 411 | def get_writer( 412 | output_format: str, output_dir: str 413 | ) -> Callable[[dict, TextIO, dict], None]: 414 | writers = { 415 | "txt": WriteTXT, 416 | "vtt": WriteVTT, 417 | "srt": WriteSRT, 418 | "tsv": WriteTSV, 419 | "json": WriteJSON, 420 | } 421 | optional_writers = { 422 | "aud": WriteAudacity, 423 | } 424 | 425 | if output_format == "all": 426 | all_writers = [writer(output_dir) for writer in writers.values()] 427 | 428 | def write_all(result: dict, file: TextIO, options: dict): 429 | for writer in all_writers: 430 | writer(result, file, options) 431 | 432 | return write_all 433 | 434 | if output_format in optional_writers: 435 | return optional_writers[output_format](output_dir) 436 | return writers[output_format](output_dir) 437 | 438 | def interpolate_nans(x, method='nearest'): 439 | if x.notnull().sum() > 1: 440 | return x.interpolate(method=method).ffill().bfill() 441 | else: 442 | return x.ffill().bfill() 443 | -------------------------------------------------------------------------------- /whisperx/vads/__init__.py: -------------------------------------------------------------------------------- 1 | from whisperx.vads.pyannote import Pyannote as Pyannote 2 | from whisperx.vads.silero import 
Silero as Silero 3 | from whisperx.vads.vad import Vad as Vad 4 | -------------------------------------------------------------------------------- /whisperx/vads/pyannote.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Callable, Text, Union 3 | from typing import Optional 4 | 5 | import numpy as np 6 | import torch 7 | from pyannote.audio import Model 8 | from pyannote.audio.core.io import AudioFile 9 | from pyannote.audio.pipelines import VoiceActivityDetection 10 | from pyannote.audio.pipelines.utils import PipelineModel 11 | from pyannote.core import Annotation, SlidingWindowFeature 12 | from pyannote.core import Segment 13 | 14 | from whisperx.diarize import Segment as SegmentX 15 | from whisperx.vads.vad import Vad 16 | 17 | 18 | def load_vad_model(device, vad_onset=0.500, vad_offset=0.363, use_auth_token=None, model_fp=None): 19 | model_dir = torch.hub._get_torch_home() 20 | 21 | main_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 22 | 23 | os.makedirs(model_dir, exist_ok = True) 24 | if model_fp is None: 25 | # Dynamically resolve the path to the model file 26 | model_fp = os.path.join(main_dir, "assets", "pytorch_model.bin") 27 | model_fp = os.path.abspath(model_fp) # Ensure the path is absolute 28 | else: 29 | model_fp = os.path.abspath(model_fp) # Ensure any provided path is absolute 30 | 31 | # Check if the resolved model file exists 32 | if not os.path.exists(model_fp): 33 | raise FileNotFoundError(f"Model file not found at {model_fp}") 34 | 35 | if os.path.exists(model_fp) and not os.path.isfile(model_fp): 36 | raise RuntimeError(f"{model_fp} exists and is not a regular file") 37 | 38 | model_bytes = open(model_fp, "rb").read() 39 | 40 | vad_model = Model.from_pretrained(model_fp, use_auth_token=use_auth_token) 41 | hyperparameters = {"onset": vad_onset, 42 | "offset": vad_offset, 43 | "min_duration_on": 0.1, 44 | "min_duration_off": 0.1} 45 | vad_pipeline = 
VoiceActivitySegmentation(segmentation=vad_model, device=torch.device(device)) 46 | vad_pipeline.instantiate(hyperparameters) 47 | 48 | return vad_pipeline 49 | 50 | class Binarize: 51 | """Binarize detection scores using hysteresis thresholding, with min-cut operation 52 | to ensure not segments are longer than max_duration. 53 | 54 | Parameters 55 | ---------- 56 | onset : float, optional 57 | Onset threshold. Defaults to 0.5. 58 | offset : float, optional 59 | Offset threshold. Defaults to `onset`. 60 | min_duration_on : float, optional 61 | Remove active regions shorter than that many seconds. Defaults to 0s. 62 | min_duration_off : float, optional 63 | Fill inactive regions shorter than that many seconds. Defaults to 0s. 64 | pad_onset : float, optional 65 | Extend active regions by moving their start time by that many seconds. 66 | Defaults to 0s. 67 | pad_offset : float, optional 68 | Extend active regions by moving their end time by that many seconds. 69 | Defaults to 0s. 70 | max_duration: float 71 | The maximum length of an active segment, divides segment at timestamp with lowest score. 72 | Reference 73 | --------- 74 | Gregory Gelly and Jean-Luc Gauvain. "Minimum Word Error Training of 75 | RNN-based Voice Activity Detection", InterSpeech 2015. 
76 | 77 | Modified by Max Bain to include WhisperX's min-cut operation 78 | https://arxiv.org/abs/2303.00747 79 | 80 | Pyannote-audio 81 | """ 82 | 83 | def __init__( 84 | self, 85 | onset: float = 0.5, 86 | offset: Optional[float] = None, 87 | min_duration_on: float = 0.0, 88 | min_duration_off: float = 0.0, 89 | pad_onset: float = 0.0, 90 | pad_offset: float = 0.0, 91 | max_duration: float = float('inf') 92 | ): 93 | 94 | super().__init__() 95 | 96 | self.onset = onset 97 | self.offset = offset or onset 98 | 99 | self.pad_onset = pad_onset 100 | self.pad_offset = pad_offset 101 | 102 | self.min_duration_on = min_duration_on 103 | self.min_duration_off = min_duration_off 104 | 105 | self.max_duration = max_duration 106 | 107 | def __call__(self, scores: SlidingWindowFeature) -> Annotation: 108 | """Binarize detection scores 109 | Parameters 110 | ---------- 111 | scores : SlidingWindowFeature 112 | Detection scores. 113 | Returns 114 | ------- 115 | active : Annotation 116 | Binarized scores. 
117 | """ 118 | 119 | num_frames, num_classes = scores.data.shape 120 | frames = scores.sliding_window 121 | timestamps = [frames[i].middle for i in range(num_frames)] 122 | 123 | # annotation meant to store 'active' regions 124 | active = Annotation() 125 | for k, k_scores in enumerate(scores.data.T): 126 | 127 | label = k if scores.labels is None else scores.labels[k] 128 | 129 | # initial state 130 | start = timestamps[0] 131 | is_active = k_scores[0] > self.onset 132 | curr_scores = [k_scores[0]] 133 | curr_timestamps = [start] 134 | t = start 135 | for t, y in zip(timestamps[1:], k_scores[1:]): 136 | # currently active 137 | if is_active: 138 | curr_duration = t - start 139 | if curr_duration > self.max_duration: 140 | search_after = len(curr_scores) // 2 141 | # divide segment 142 | min_score_div_idx = search_after + np.argmin(curr_scores[search_after:]) 143 | min_score_t = curr_timestamps[min_score_div_idx] 144 | region = Segment(start - self.pad_onset, min_score_t + self.pad_offset) 145 | active[region, k] = label 146 | start = curr_timestamps[min_score_div_idx] 147 | curr_scores = curr_scores[min_score_div_idx + 1:] 148 | curr_timestamps = curr_timestamps[min_score_div_idx + 1:] 149 | # switching from active to inactive 150 | elif y < self.offset: 151 | region = Segment(start - self.pad_onset, t + self.pad_offset) 152 | active[region, k] = label 153 | start = t 154 | is_active = False 155 | curr_scores = [] 156 | curr_timestamps = [] 157 | curr_scores.append(y) 158 | curr_timestamps.append(t) 159 | # currently inactive 160 | else: 161 | # switching from inactive to active 162 | if y > self.onset: 163 | start = t 164 | is_active = True 165 | 166 | # if active at the end, add final region 167 | if is_active: 168 | region = Segment(start - self.pad_onset, t + self.pad_offset) 169 | active[region, k] = label 170 | 171 | # because of padding, some active regions might be overlapping: merge them. 
172 | # also: fill same speaker gaps shorter than min_duration_off 173 | if self.pad_offset > 0.0 or self.pad_onset > 0.0 or self.min_duration_off > 0.0: 174 | if self.max_duration < float("inf"): 175 | raise NotImplementedError(f"This would break current max_duration param") 176 | active = active.support(collar=self.min_duration_off) 177 | 178 | # remove tracks shorter than min_duration_on 179 | if self.min_duration_on > 0: 180 | for segment, track in list(active.itertracks()): 181 | if segment.duration < self.min_duration_on: 182 | del active[segment, track] 183 | 184 | return active 185 | 186 | 187 | class VoiceActivitySegmentation(VoiceActivityDetection): 188 | def __init__( 189 | self, 190 | segmentation: PipelineModel = "pyannote/segmentation", 191 | fscore: bool = False, 192 | use_auth_token: Union[Text, None] = None, 193 | **inference_kwargs, 194 | ): 195 | 196 | super().__init__(segmentation=segmentation, fscore=fscore, use_auth_token=use_auth_token, **inference_kwargs) 197 | 198 | def apply(self, file: AudioFile, hook: Optional[Callable] = None) -> Annotation: 199 | """Apply voice activity detection 200 | 201 | Parameters 202 | ---------- 203 | file : AudioFile 204 | Processed file. 205 | hook : callable, optional 206 | Hook called after each major step of the pipeline with the following 207 | signature: hook("step_name", step_artefact, file=file) 208 | 209 | Returns 210 | ------- 211 | speech : Annotation 212 | Speech regions. 213 | """ 214 | 215 | # setup hook (e.g. 
for debugging purposes) 216 | hook = self.setup_hook(file, hook=hook) 217 | 218 | # apply segmentation model (only if needed) 219 | # output shape is (num_chunks, num_frames, 1) 220 | if self.training: 221 | if self.CACHED_SEGMENTATION in file: 222 | segmentations = file[self.CACHED_SEGMENTATION] 223 | else: 224 | segmentations = self._segmentation(file) 225 | file[self.CACHED_SEGMENTATION] = segmentations 226 | else: 227 | segmentations: SlidingWindowFeature = self._segmentation(file) 228 | 229 | return segmentations 230 | 231 | 232 | class Pyannote(Vad): 233 | 234 | def __init__(self, device, use_auth_token=None, model_fp=None, **kwargs): 235 | print(">>Performing voice activity detection using Pyannote...") 236 | super().__init__(kwargs['vad_onset']) 237 | self.vad_pipeline = load_vad_model(device, use_auth_token=use_auth_token, model_fp=model_fp) 238 | 239 | def __call__(self, audio: AudioFile, **kwargs): 240 | return self.vad_pipeline(audio) 241 | 242 | @staticmethod 243 | def preprocess_audio(audio): 244 | return torch.from_numpy(audio).unsqueeze(0) 245 | 246 | @staticmethod 247 | def merge_chunks(segments, 248 | chunk_size, 249 | onset: float = 0.5, 250 | offset: Optional[float] = None, 251 | ): 252 | assert chunk_size > 0 253 | binarize = Binarize(max_duration=chunk_size, onset=onset, offset=offset) 254 | segments = binarize(segments) 255 | segments_list = [] 256 | for speech_turn in segments.get_timeline(): 257 | segments_list.append(SegmentX(speech_turn.start, speech_turn.end, "UNKNOWN")) 258 | 259 | if len(segments_list) == 0: 260 | print("No active speech found in audio") 261 | return [] 262 | assert segments_list, "segments_list is empty." 
263 | return Vad.merge_chunks(segments_list, chunk_size, onset, offset) 264 | -------------------------------------------------------------------------------- /whisperx/vads/silero.py: -------------------------------------------------------------------------------- 1 | from io import IOBase 2 | from pathlib import Path 3 | from typing import Mapping, Text 4 | from typing import Optional 5 | from typing import Union 6 | 7 | import torch 8 | 9 | from whisperx.diarize import Segment as SegmentX 10 | from whisperx.vads.vad import Vad 11 | 12 | AudioFile = Union[Text, Path, IOBase, Mapping] 13 | 14 | 15 | class Silero(Vad): 16 | # check again default values 17 | def __init__(self, **kwargs): 18 | print(">>Performing voice activity detection using Silero...") 19 | super().__init__(kwargs['vad_onset']) 20 | 21 | self.vad_onset = kwargs['vad_onset'] 22 | self.chunk_size = kwargs['chunk_size'] 23 | self.vad_pipeline, vad_utils = torch.hub.load(repo_or_dir='snakers4/silero-vad', 24 | model='silero_vad', 25 | force_reload=False, 26 | onnx=False, 27 | trust_repo=True) 28 | (self.get_speech_timestamps, _, self.read_audio, _, _) = vad_utils 29 | 30 | def __call__(self, audio: AudioFile, **kwargs): 31 | """use silero to get segments of speech""" 32 | # Only accept 16000 Hz for now. 33 | # Note: Silero models support both 8000 and 16000 Hz. Although other values are not directly supported, 34 | # multiples of 16000 (e.g. 32000 or 48000) are cast to 16000 inside of the JIT model! 35 | sample_rate = audio["sample_rate"] 36 | if sample_rate != 16000: 37 | raise ValueError("Only 16000Hz sample rate is allowed") 38 | 39 | timestamps = self.get_speech_timestamps(audio["waveform"], 40 | model=self.vad_pipeline, 41 | sampling_rate=sample_rate, 42 | max_speech_duration_s=self.chunk_size, 43 | threshold=self.vad_onset 44 | # min_silence_duration_ms = self.min_duration_off/1000 45 | # min_speech_duration_ms = self.min_duration_on/1000 46 | # ... 
47 | # See silero documentation for full option list 48 | ) 49 | return [SegmentX(i['start'] / sample_rate, i['end'] / sample_rate, "UNKNOWN") for i in timestamps] 50 | 51 | @staticmethod 52 | def preprocess_audio(audio): 53 | return audio 54 | 55 | @staticmethod 56 | def merge_chunks(segments_list, 57 | chunk_size, 58 | onset: float = 0.5, 59 | offset: Optional[float] = None, 60 | ): 61 | assert chunk_size > 0 62 | if len(segments_list) == 0: 63 | print("No active speech found in audio") 64 | return [] 65 | assert segments_list, "segments_list is empty." 66 | return Vad.merge_chunks(segments_list, chunk_size, onset, offset) 67 | -------------------------------------------------------------------------------- /whisperx/vads/vad.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import pandas as pd 4 | from pyannote.core import Annotation, Segment 5 | 6 | 7 | class Vad: 8 | def __init__(self, vad_onset): 9 | if not (0 < vad_onset < 1): 10 | raise ValueError( 11 | "vad_onset is a decimal value between 0 and 1." 
12 | ) 13 | 14 | @staticmethod 15 | def preprocess_audio(audio): 16 | pass 17 | 18 | # keep merge_chunks as static so it can be also used by manually assigned vad_model (see 'load_model') 19 | @staticmethod 20 | def merge_chunks(segments, 21 | chunk_size, 22 | onset: float, 23 | offset: Optional[float]): 24 | """ 25 | Merge operation described in paper 26 | """ 27 | curr_end = 0 28 | merged_segments = [] 29 | seg_idxs: list[tuple]= [] 30 | speaker_idxs: list[Optional[str]] = [] 31 | 32 | curr_start = segments[0].start 33 | for seg in segments: 34 | if seg.end - curr_start > chunk_size and curr_end - curr_start > 0: 35 | merged_segments.append({ 36 | "start": curr_start, 37 | "end": curr_end, 38 | "segments": seg_idxs, 39 | }) 40 | curr_start = seg.start 41 | seg_idxs = [] 42 | speaker_idxs = [] 43 | curr_end = seg.end 44 | seg_idxs.append((seg.start, seg.end)) 45 | speaker_idxs.append(seg.speaker) 46 | # add final 47 | merged_segments.append({ 48 | "start": curr_start, 49 | "end": curr_end, 50 | "segments": seg_idxs, 51 | }) 52 | 53 | return merged_segments 54 | 55 | # Unused function 56 | @staticmethod 57 | def merge_vad(vad_arr, pad_onset=0.0, pad_offset=0.0, min_duration_off=0.0, min_duration_on=0.0): 58 | active = Annotation() 59 | for k, vad_t in enumerate(vad_arr): 60 | region = Segment(vad_t[0] - pad_onset, vad_t[1] + pad_offset) 61 | active[region, k] = 1 62 | 63 | if pad_offset > 0.0 or pad_onset > 0.0 or min_duration_off > 0.0: 64 | active = active.support(collar=min_duration_off) 65 | 66 | # remove tracks shorter than min_duration_on 67 | if min_duration_on > 0: 68 | for segment, track in list(active.itertracks()): 69 | if segment.duration < min_duration_on: 70 | del active[segment, track] 71 | 72 | active = active.for_json() 73 | active_segs = pd.DataFrame([x['segment'] for x in active['content']]) 74 | return active_segs 75 | --------------------------------------------------------------------------------