├── .flake8 ├── .gitattributes ├── .github ├── dependabot.yml └── workflows │ ├── python-publish.yml │ └── test.yml ├── .gitignore ├── .pre-commit-config.yaml ├── CHANGELOG.md ├── LICENSE ├── MANIFEST.in ├── README.md ├── approach.png ├── data ├── README.md └── meanwhile.json ├── language-breakdown.svg ├── model-card.md ├── notebooks ├── LibriSpeech.ipynb └── Multilingual_ASR.ipynb ├── pyproject.toml ├── requirements.txt ├── tests ├── conftest.py ├── jfk.flac ├── test_audio.py ├── test_normalizer.py ├── test_timing.py ├── test_tokenizer.py └── test_transcribe.py └── whisper ├── __init__.py ├── __main__.py ├── assets ├── gpt2.tiktoken ├── mel_filters.npz └── multilingual.tiktoken ├── audio.py ├── decoding.py ├── model.py ├── normalizers ├── __init__.py ├── basic.py ├── english.json └── english.py ├── timing.py ├── tokenizer.py ├── transcribe.py ├── triton_ops.py ├── utils.py └── version.py /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | per-file-ignores = 3 | */__init__.py: F401 4 | 5 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | # Override jupyter in Github language stats for more accurate estimate of repo code languages 2 | # reference: https://github.com/github/linguist/blob/master/docs/overrides.md#generated-code 3 | *.ipynb linguist-generated 4 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | # Keep GitHub Actions up to date with GitHub's Dependabot... 
2 | # https://docs.github.com/en/code-security/dependabot/working-with-dependabot/keeping-your-actions-up-to-date-with-dependabot 3 | # https://docs.github.com/en/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file#package-ecosystem 4 | version: 2 5 | updates: 6 | - package-ecosystem: github-actions 7 | directory: / 8 | groups: 9 | github-actions: 10 | patterns: 11 | - "*" # Group all Actions updates into a single larger pull request 12 | schedule: 13 | interval: weekly 14 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | name: Release 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | jobs: 8 | deploy: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - uses: actions/checkout@v4 12 | - uses: actions-ecosystem/action-regex-match@v2 13 | id: regex-match 14 | with: 15 | text: ${{ github.event.head_commit.message }} 16 | regex: '^Release ([^ ]+)' 17 | - name: Set up Python 18 | uses: actions/setup-python@v5 19 | with: 20 | python-version: '3.8' 21 | - name: Install dependencies 22 | run: | 23 | python -m pip install --upgrade pip 24 | pip install setuptools wheel twine 25 | - name: Release 26 | if: ${{ steps.regex-match.outputs.match != '' }} 27 | uses: softprops/action-gh-release@v2 28 | with: 29 | tag_name: v${{ steps.regex-match.outputs.group1 }} 30 | - name: Build and publish 31 | if: ${{ steps.regex-match.outputs.match != '' }} 32 | env: 33 | TWINE_USERNAME: __token__ 34 | TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }} 35 | run: | 36 | python -m build --sdist 37 | twine upload dist/* 38 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: test 2 | on: 3 | push: 4 | branches: 5 | - main 6 | pull_request: 7 | branches: 8 
| - main 9 | 10 | jobs: 11 | pre-commit: 12 | runs-on: ubuntu-latest 13 | steps: 14 | - uses: actions/checkout@v4 15 | - name: Fetch base branch 16 | run: git fetch origin ${{ github.base_ref }} 17 | - uses: actions/setup-python@v5 18 | with: 19 | python-version: "3.9" 20 | architecture: x64 21 | - name: Get pip cache dir 22 | id: pip-cache 23 | run: | 24 | echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT 25 | - name: pip/pre-commit cache 26 | uses: actions/cache@v4 27 | with: 28 | path: | 29 | ${{ steps.pip-cache.outputs.dir }} 30 | ~/.cache/pre-commit 31 | key: ${{ runner.os }}-pip-pre-commit-${{ hashFiles('**/.pre-commit-config.yaml') }} 32 | restore-keys: | 33 | ${{ runner.os }}-pip-pre-commit 34 | - name: pre-commit 35 | run: | 36 | pip install --upgrade pre-commit 37 | pre-commit install --install-hooks 38 | pre-commit run --all-files 39 | whisper-test: 40 | needs: pre-commit 41 | runs-on: ubuntu-latest 42 | strategy: 43 | fail-fast: false 44 | matrix: 45 | include: 46 | - python-version: '3.8' 47 | pytorch-version: 1.10.1 48 | numpy-requirement: "'numpy<2'" 49 | - python-version: '3.8' 50 | pytorch-version: 1.13.1 51 | numpy-requirement: "'numpy<2'" 52 | - python-version: '3.8' 53 | pytorch-version: 2.0.1 54 | numpy-requirement: "'numpy<2'" 55 | - python-version: '3.9' 56 | pytorch-version: 2.1.2 57 | numpy-requirement: "'numpy<2'" 58 | - python-version: '3.10' 59 | pytorch-version: 2.2.2 60 | numpy-requirement: "'numpy<2'" 61 | - python-version: '3.11' 62 | pytorch-version: 2.3.1 63 | numpy-requirement: "'numpy'" 64 | - python-version: '3.12' 65 | pytorch-version: 2.4.1 66 | numpy-requirement: "'numpy'" 67 | - python-version: '3.12' 68 | pytorch-version: 2.5.1 69 | numpy-requirement: "'numpy'" 70 | - python-version: '3.13' 71 | pytorch-version: 2.5.1 72 | numpy-requirement: "'numpy'" 73 | steps: 74 | - uses: conda-incubator/setup-miniconda@v3 75 | - run: conda install -n test ffmpeg python=${{ matrix.python-version }} 76 | - uses: actions/checkout@v4 77 | - 
run: echo "$CONDA/envs/test/bin" >> $GITHUB_PATH 78 | - run: pip3 install .["dev"] ${{ matrix.numpy-requirement }} torch==${{ matrix.pytorch-version }}+cpu --index-url https://download.pytorch.org/whl/cpu --extra-index-url https://pypi.org/simple 79 | - run: pytest --durations=0 -vv -k 'not test_transcribe or test_transcribe[tiny] or test_transcribe[tiny.en]' -m 'not requires_cuda' 80 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | *.py[cod] 3 | *$py.class 4 | *.egg-info 5 | .pytest_cache 6 | .ipynb_checkpoints 7 | 8 | thumbs.db 9 | .DS_Store 10 | .idea 11 | 12 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v5.0.0 4 | hooks: 5 | - id: check-json 6 | - id: end-of-file-fixer 7 | types: [file, python] 8 | - id: trailing-whitespace 9 | types: [file, python] 10 | - id: mixed-line-ending 11 | - id: check-added-large-files 12 | args: [--maxkb=4096] 13 | - repo: https://github.com/psf/black 14 | rev: 25.1.0 15 | hooks: 16 | - id: black 17 | - repo: https://github.com/pycqa/isort 18 | rev: 6.0.0 19 | hooks: 20 | - id: isort 21 | name: isort (python) 22 | args: ["--profile", "black", "-l", "88", "--trailing-comma", "--multi-line", "3"] 23 | - repo: https://github.com/pycqa/flake8.git 24 | rev: 7.1.1 25 | hooks: 26 | - id: flake8 27 | types: [python] 28 | args: ["--max-line-length", "88", "--ignore", "E203,E501,W503,W504"] 29 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # CHANGELOG 2 | 3 | ## [v20240930](https://github.com/openai/whisper/releases/tag/v20240930) 4 | 5 | * 
allowing numpy 2 in tests ([#2362](https://github.com/openai/whisper/pull/2362)) 6 | * large-v3-turbo model ([#2361](https://github.com/openai/whisper/pull/2361)) 7 | * test on python/pytorch versions up to 3.12 and 2.4.1 ([#2360](https://github.com/openai/whisper/pull/2360)) 8 | * using sdpa if available ([#2359](https://github.com/openai/whisper/pull/2359)) 9 | 10 | ## [v20240927](https://github.com/openai/whisper/releases/tag/v20240927) 11 | 12 | * pinning numpy<2 in tests ([#2332](https://github.com/openai/whisper/pull/2332)) 13 | * Relax triton requirements for compatibility with pytorch 2.4 and newer ([#2307](https://github.com/openai/whisper/pull/2307)) 14 | * Skip silence around hallucinations ([#1838](https://github.com/openai/whisper/pull/1838)) 15 | * Fix triton env marker ([#1887](https://github.com/openai/whisper/pull/1887)) 16 | 17 | ## [v20231117](https://github.com/openai/whisper/releases/tag/v20231117) 18 | 19 | * Relax triton requirements for compatibility with pytorch 2.1 and newer ([#1802](https://github.com/openai/whisper/pull/1802)) 20 | 21 | ## [v20231106](https://github.com/openai/whisper/releases/tag/v20231106) 22 | 23 | * large-v3 ([#1761](https://github.com/openai/whisper/pull/1761)) 24 | 25 | ## [v20231105](https://github.com/openai/whisper/releases/tag/v20231105) 26 | 27 | * remove tiktoken pin ([#1759](https://github.com/openai/whisper/pull/1759)) 28 | * docs: Disambiguation of the term "relative speed" in the README ([#1751](https://github.com/openai/whisper/pull/1751)) 29 | * allow_pickle=False while loading of mel matrix IN audio.py ([#1511](https://github.com/openai/whisper/pull/1511)) 30 | * handling transcribe exceptions. 
([#1682](https://github.com/openai/whisper/pull/1682)) 31 | * Add new option to generate subtitles by a specific number of words ([#1729](https://github.com/openai/whisper/pull/1729)) 32 | * Fix exception when an audio file with no speech is provided ([#1396](https://github.com/openai/whisper/pull/1396)) 33 | 34 | ## [v20230918](https://github.com/openai/whisper/releases/tag/v20230918) 35 | 36 | * Add .pre-commit-config.yaml ([#1528](https://github.com/openai/whisper/pull/1528)) 37 | * fix doc of TextDecoder ([#1526](https://github.com/openai/whisper/pull/1526)) 38 | * Update model-card.md ([#1643](https://github.com/openai/whisper/pull/1643)) 39 | * word timing tweaks ([#1559](https://github.com/openai/whisper/pull/1559)) 40 | * Avoid rearranging all caches ([#1483](https://github.com/openai/whisper/pull/1483)) 41 | * Improve timestamp heuristics. ([#1461](https://github.com/openai/whisper/pull/1461)) 42 | * fix condition_on_previous_text ([#1224](https://github.com/openai/whisper/pull/1224)) 43 | * Fix numba deprecation notice ([#1233](https://github.com/openai/whisper/pull/1233)) 44 | * Updated README.md to provide more insight on BLEU and specific appendices ([#1236](https://github.com/openai/whisper/pull/1236)) 45 | * Avoid computing higher temperatures on no_speech segments ([#1279](https://github.com/openai/whisper/pull/1279)) 46 | * Dropped unused execute bit from mel_filters.npz. ([#1254](https://github.com/openai/whisper/pull/1254)) 47 | * Drop ffmpeg-python dependency and call ffmpeg directly.
([#1242](https://github.com/openai/whisper/pull/1242)) 48 | * Python 3.11 ([#1171](https://github.com/openai/whisper/pull/1171)) 49 | * Update decoding.py ([#1219](https://github.com/openai/whisper/pull/1219)) 50 | * Update decoding.py ([#1155](https://github.com/openai/whisper/pull/1155)) 51 | * Update README.md to reference tiktoken ([#1105](https://github.com/openai/whisper/pull/1105)) 52 | * Implement max line width and max line count, and make word highlighting optional ([#1184](https://github.com/openai/whisper/pull/1184)) 53 | * Squash long words at window and sentence boundaries. ([#1114](https://github.com/openai/whisper/pull/1114)) 54 | * python-publish.yml: bump actions version to fix node warning ([#1211](https://github.com/openai/whisper/pull/1211)) 55 | * Update tokenizer.py ([#1163](https://github.com/openai/whisper/pull/1163)) 56 | 57 | ## [v20230314](https://github.com/openai/whisper/releases/tag/v20230314) 58 | 59 | * abort find_alignment on empty input ([#1090](https://github.com/openai/whisper/pull/1090)) 60 | * Fix truncated words list when the replacement character is decoded ([#1089](https://github.com/openai/whisper/pull/1089)) 61 | * fix github language stats getting dominated by jupyter notebook ([#1076](https://github.com/openai/whisper/pull/1076)) 62 | * Fix alignment between the segments and the list of words ([#1087](https://github.com/openai/whisper/pull/1087)) 63 | * Use tiktoken ([#1044](https://github.com/openai/whisper/pull/1044)) 64 | 65 | ## [v20230308](https://github.com/openai/whisper/releases/tag/v20230308) 66 | 67 | * kwargs in decode() for convenience ([#1061](https://github.com/openai/whisper/pull/1061)) 68 | * fix all_tokens handling that caused more repetitions and discrepancy in JSON ([#1060](https://github.com/openai/whisper/pull/1060)) 69 | * fix typo in CHANGELOG.md 70 | 71 | ## [v20230307](https://github.com/openai/whisper/releases/tag/v20230307) 72 | 73 | * Fix the repetition/hallucination issue identified in #1046 
([#1052](https://github.com/openai/whisper/pull/1052)) 74 | * Use triton==2.0.0 ([#1053](https://github.com/openai/whisper/pull/1053)) 75 | * Install triton in x86_64 linux only ([#1051](https://github.com/openai/whisper/pull/1051)) 76 | * update setup.py to specify python >= 3.8 requirement 77 | 78 | ## [v20230306](https://github.com/openai/whisper/releases/tag/v20230306) 79 | 80 | * remove auxiliary audio extension ([#1021](https://github.com/openai/whisper/pull/1021)) 81 | * apply formatting with `black`, `isort`, and `flake8` ([#1038](https://github.com/openai/whisper/pull/1038)) 82 | * word-level timestamps in `transcribe()` ([#869](https://github.com/openai/whisper/pull/869)) 83 | * Decoding improvements ([#1033](https://github.com/openai/whisper/pull/1033)) 84 | * Update README.md ([#894](https://github.com/openai/whisper/pull/894)) 85 | * Fix infinite loop caused by incorrect timestamp tokens prediction ([#914](https://github.com/openai/whisper/pull/914)) 86 | * drop python 3.7 support ([#889](https://github.com/openai/whisper/pull/889)) 87 | 88 | ## [v20230124](https://github.com/openai/whisper/releases/tag/v20230124) 89 | 90 | * handle printing even if sys.stdout.buffer is not available ([#887](https://github.com/openai/whisper/pull/887)) 91 | * Add TSV formatted output in transcript, using integer start/end time in milliseconds ([#228](https://github.com/openai/whisper/pull/228)) 92 | * Added `--output_format` option ([#333](https://github.com/openai/whisper/pull/333)) 93 | * Handle `XDG_CACHE_HOME` properly for `download_root` ([#864](https://github.com/openai/whisper/pull/864)) 94 | * use stdout for printing transcription progress ([#867](https://github.com/openai/whisper/pull/867)) 95 | * Fix bug where mm is mistakenly replaced with hmm in e.g. 20mm ([#659](https://github.com/openai/whisper/pull/659)) 96 | * print '?' 
if a letter can't be encoded using the system default encoding ([#859](https://github.com/openai/whisper/pull/859)) 97 | 98 | ## [v20230117](https://github.com/openai/whisper/releases/tag/v20230117) 99 | 100 | The first versioned release available on [PyPI](https://pypi.org/project/openai-whisper/) 101 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 OpenAI 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include requirements.txt 2 | include README.md 3 | include LICENSE 4 | include whisper/assets/* 5 | include whisper/normalizers/english.json 6 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Whisper 2 | 3 | [[Blog]](https://openai.com/blog/whisper) 4 | [[Paper]](https://arxiv.org/abs/2212.04356) 5 | [[Model card]](https://github.com/openai/whisper/blob/main/model-card.md) 6 | [[Colab example]](https://colab.research.google.com/github/openai/whisper/blob/main/notebooks/LibriSpeech.ipynb) 7 | 8 | Whisper is a general-purpose speech recognition model. It is trained on a large dataset of diverse audio and is also a multitasking model that can perform multilingual speech recognition, speech translation, and language identification. 9 | 10 | 11 | ## Approach 12 | 13 | ![Approach](https://raw.githubusercontent.com/openai/whisper/main/approach.png) 14 | 15 | A Transformer sequence-to-sequence model is trained on various speech processing tasks, including multilingual speech recognition, speech translation, spoken language identification, and voice activity detection. These tasks are jointly represented as a sequence of tokens to be predicted by the decoder, allowing a single model to replace many stages of a traditional speech-processing pipeline. The multitask training format uses a set of special tokens that serve as task specifiers or classification targets. 16 | 17 | 18 | ## Setup 19 | 20 | We used Python 3.9.9 and [PyTorch](https://pytorch.org/) 1.10.1 to train and test our models, but the codebase is expected to be compatible with Python 3.8-3.13 and recent PyTorch versions.
The codebase also depends on a few Python packages, most notably [OpenAI's tiktoken](https://github.com/openai/tiktoken) for its fast tokenizer implementation. You can download and install (or update to) the latest release of Whisper with the following command: 21 | 22 | pip install -U openai-whisper 23 | 24 | Alternatively, the following command will pull and install the latest commit from this repository, along with its Python dependencies: 25 | 26 | pip install git+https://github.com/openai/whisper.git 27 | 28 | To update the package to the latest version of this repository, please run: 29 | 30 | pip install --upgrade --no-deps --force-reinstall git+https://github.com/openai/whisper.git 31 | 32 | It also requires the command-line tool [`ffmpeg`](https://ffmpeg.org/) to be installed on your system, which is available from most package managers: 33 | 34 | ```bash 35 | # on Ubuntu or Debian 36 | sudo apt update && sudo apt install ffmpeg 37 | 38 | # on Arch Linux 39 | sudo pacman -S ffmpeg 40 | 41 | # on macOS using Homebrew (https://brew.sh/) 42 | brew install ffmpeg 43 | 44 | # on Windows using Chocolatey (https://chocolatey.org/) 45 | choco install ffmpeg 46 | 47 | # on Windows using Scoop (https://scoop.sh/) 48 | scoop install ffmpeg 49 | ``` 50 | 51 | You may need [`rust`](http://rust-lang.org) installed as well, in case [tiktoken](https://github.com/openai/tiktoken) does not provide a pre-built wheel for your platform. If you see installation errors during the `pip install` command above, please follow the [Getting started page](https://www.rust-lang.org/learn/get-started) to install the Rust development environment. Additionally, you may need to configure the `PATH` environment variable, e.g. `export PATH="$HOME/.cargo/bin:$PATH"`. If the installation fails with `No module named 'setuptools_rust'`, you need to install `setuptools_rust`, e.g.
by running: 52 | 53 | ```bash 54 | pip install setuptools-rust 55 | ``` 56 | 57 | 58 | ## Available models and languages 59 | 60 | There are six model sizes, four with English-only versions, offering speed and accuracy tradeoffs. 61 | Below are the names of the available models and their approximate memory requirements and inference speed relative to the large model. 62 | The relative speeds below are measured by transcribing English speech on an A100, and the real-world speed may vary significantly depending on many factors, including the language, the speaking speed, and the available hardware. 63 | 64 | | Size | Parameters | English-only model | Multilingual model | Required VRAM | Relative speed | 65 | |:------:|:----------:|:------------------:|:------------------:|:-------------:|:--------------:| 66 | | tiny | 39 M | `tiny.en` | `tiny` | ~1 GB | ~10x | 67 | | base | 74 M | `base.en` | `base` | ~1 GB | ~7x | 68 | | small | 244 M | `small.en` | `small` | ~2 GB | ~4x | 69 | | medium | 769 M | `medium.en` | `medium` | ~5 GB | ~2x | 70 | | large | 1550 M | N/A | `large` | ~10 GB | 1x | 71 | | turbo | 809 M | N/A | `turbo` | ~6 GB | ~8x | 72 | 73 | The `.en` models for English-only applications tend to perform better, especially for the `tiny.en` and `base.en` models. We observed that the difference becomes less significant for the `small.en` and `medium.en` models. 74 | Additionally, the `turbo` model is an optimized version of `large-v3` that offers faster transcription speed with minimal degradation in accuracy. 75 | 76 | Whisper's performance varies widely depending on the language. The figure below shows a performance breakdown of `large-v3` and `large-v2` models by language, using WERs (word error rates) or CERs (character error rates, shown in *italics*) evaluated on the Common Voice 15 and Fleurs datasets.
Additional WER/CER metrics corresponding to the other models and datasets can be found in Appendix D.1, D.2, and D.4 of [the paper](https://arxiv.org/abs/2212.04356), as well as the BLEU (Bilingual Evaluation Understudy) scores for translation in Appendix D.3. 77 | 78 | ![WER breakdown by language](https://github.com/openai/whisper/assets/266841/f4619d66-1058-4005-8f67-a9d811b77c62) 79 | 80 | 81 | 82 | ## Command-line usage 83 | 84 | The following command will transcribe speech in audio files, using the `turbo` model: 85 | 86 | whisper audio.flac audio.mp3 audio.wav --model turbo 87 | 88 | The default setting (which selects the `turbo` model) works well for transcribing English. To transcribe an audio file containing non-English speech, you can specify the language using the `--language` option: 89 | 90 | whisper japanese.wav --language Japanese 91 | 92 | Adding `--task translate` will translate the speech into English: 93 | 94 | whisper japanese.wav --language Japanese --task translate 95 | 96 | Run the following to view all available options: 97 | 98 | whisper --help 99 | 100 | See [tokenizer.py](https://github.com/openai/whisper/blob/main/whisper/tokenizer.py) for the list of all available languages. 101 | 102 | 103 | ## Python usage 104 | 105 | Transcription can also be performed within Python: 106 | 107 | ```python 108 | import whisper 109 | 110 | model = whisper.load_model("turbo") 111 | result = model.transcribe("audio.mp3") 112 | print(result["text"]) 113 | ``` 114 | 115 | Internally, the `transcribe()` method reads the entire file and processes the audio with a sliding 30-second window, performing autoregressive sequence-to-sequence predictions on each window. 116 | 117 | Below is an example usage of `whisper.detect_language()` and `whisper.decode()` which provide lower-level access to the model. 
118 | 119 | ```python 120 | import whisper 121 | 122 | model = whisper.load_model("turbo") 123 | 124 | # load audio and pad/trim it to fit 30 seconds 125 | audio = whisper.load_audio("audio.mp3") 126 | audio = whisper.pad_or_trim(audio) 127 | 128 | # make log-Mel spectrogram and move to the same device as the model 129 | mel = whisper.log_mel_spectrogram(audio, n_mels=model.dims.n_mels).to(model.device) 130 | 131 | # detect the spoken language 132 | _, probs = model.detect_language(mel) 133 | print(f"Detected language: {max(probs, key=probs.get)}") 134 | 135 | # decode the audio 136 | options = whisper.DecodingOptions() 137 | result = whisper.decode(model, mel, options) 138 | 139 | # print the recognized text 140 | print(result.text) 141 | ``` 142 | 143 | ## More examples 144 | 145 | Please use the [🙌 Show and tell](https://github.com/openai/whisper/discussions/categories/show-and-tell) category in Discussions for sharing more example usages of Whisper and third-party extensions such as web demos, integrations with other tools, ports for different platforms, etc. 146 | 147 | 148 | ## License 149 | 150 | Whisper's code and model weights are released under the MIT License. See [LICENSE](https://github.com/openai/whisper/blob/main/LICENSE) for further details. 151 | -------------------------------------------------------------------------------- /approach.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/whisper/dd985ac4b90cafeef8712f2998d62c59c3e62d22/approach.png -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | This directory supplements the paper with more details on how we prepared the data for evaluation, to help replicate our experiments. 
2 | 3 | ## Short-form English-only datasets 4 | 5 | ### LibriSpeech 6 | 7 | We used the test-clean and test-other splits from the [LibriSpeech ASR corpus](https://www.openslr.org/12). 8 | 9 | ### TED-LIUM 3 10 | 11 | We used the test split of [TED-LIUM Release 3](https://www.openslr.org/51/), using the segmented manual transcripts included in the release. 12 | 13 | ### Common Voice 5.1 14 | 15 | We downloaded the English subset of Common Voice Corpus 5.1 from [the official website](https://commonvoice.mozilla.org/en/datasets). 16 | 17 | ### Artie 18 | 19 | We used the [Artie bias corpus](https://github.com/artie-inc/artie-bias-corpus). This is a subset of the Common Voice dataset. 20 | 21 | ### CallHome & Switchboard 22 | 23 | We used the two corpora from [LDC2002S09](https://catalog.ldc.upenn.edu/LDC2002S09) and [LDC2002T43](https://catalog.ldc.upenn.edu/LDC2002T43) and followed the [eval2000_data_prep.sh](https://github.com/kaldi-asr/kaldi/blob/master/egs/fisher_swbd/s5/local/eval2000_data_prep.sh) script for preprocessing. The `wav.scp` files can be converted to WAV files with the following bash commands: 24 | 25 | ```bash 26 | mkdir -p wav 27 | while read name cmd; do 28 | echo $name 29 | echo ${cmd/\|/} wav/$name.wav | bash 30 | done < wav.scp 31 | ``` 32 | 33 | 34 | ### WSJ 35 | 36 | We used [LDC93S6B](https://catalog.ldc.upenn.edu/LDC93S6B) and [LDC94S13B](https://catalog.ldc.upenn.edu/LDC94S13B) and followed the [s5 recipe](https://github.com/kaldi-asr/kaldi/tree/master/egs/wsj/s5) to preprocess the dataset. 37 | 38 | ### CORAAL 39 | 40 | We used the 231 interviews from [CORAAL (v. 2021.07)](https://oraal.uoregon.edu/coraal) and used the segmentations from [the FairSpeech project](https://github.com/stanford-policylab/asr-disparities/blob/master/input/CORAAL_transcripts.csv).
41 | 42 | ### CHiME-6 43 | 44 | We downloaded the [CHiME-5 dataset](https://spandh.dcs.shef.ac.uk//chime_challenge/CHiME5/download.html) and followed stage 0 of the [s5_track1 recipe](https://github.com/kaldi-asr/kaldi/tree/master/egs/chime6/s5_track1) to create the CHiME-6 dataset, which fixes synchronization. We then used the binaural recordings (`*_P??.wav`) and the corresponding transcripts. 45 | 46 | ### AMI-IHM, AMI-SDM1 47 | 48 | We preprocessed the [AMI Corpus](https://groups.inf.ed.ac.uk/ami/corpus/overview.shtml) by following stages 0 and 2 of the [s5b recipe](https://github.com/kaldi-asr/kaldi/tree/master/egs/ami/s5b). 49 | 50 | 51 | ## Long-form English-only datasets 52 | 53 | ### TED-LIUM 3 54 | 55 | To create a long-form transcription dataset from the [TED-LIUM3](https://www.openslr.org/51/) dataset, we sliced the audio between the beginning of the first labeled segment and the end of the last labeled segment of each talk, and we used the concatenated text as the label. Below are the timestamps used for slicing each of the 11 TED talks in the test split. 56 | 57 | | Filename | Begin time (s) | End time (s) | 58 | |---------------------|----------------|--------------| 59 | | DanBarber_2010 | 16.09 | 1116.24 | 60 | | JaneMcGonigal_2010 | 15.476 | 1187.61 | 61 | | BillGates_2010 | 15.861 | 1656.94 | 62 | | TomWujec_2010U | 16.26 | 402.17 | 63 | | GaryFlake_2010 | 16.06 | 367.14 | 64 | | EricMead_2009P | 18.434 | 536.44 | 65 | | MichaelSpecter_2010 | 16.11 | 979.312 | 66 | | DanielKahneman_2010 | 15.8 | 1199.44 | 67 | | AimeeMullins_2009P | 17.82 | 1296.59 | 68 | | JamesCameron_2010 | 16.75 | 1010.65 | 69 | | RobertGupta_2010U | 16.8 | 387.03 | 70 | 71 | ### Meanwhile 72 | 73 | This dataset consists of 64 segments from The Late Show with Stephen Colbert. The YouTube video ID, start and end timestamps, and the labels can be found in [meanwhile.json](meanwhile.json).
The labels are collected from the closed-caption data for each video and corrected with manual inspection. 74 | 75 | ### Rev16 76 | 77 | We use a subset of 16 files from the 30 podcast episodes in [Rev.AI's Podcast Transcription Benchmark](https://www.rev.ai/blog/podcast-transcription-benchmark-part-1/), after finding that there are multiple cases where a significant portion of the audio and the labels did not match, mostly in the parts introducing the sponsors. We selected 16 episodes that do not have this error, whose "file number" values are: 78 | 79 | 3 4 9 10 11 14 17 18 20 21 23 24 26 27 29 32 80 | 81 | ### Kincaid46 82 | 83 | This dataset consists of 46 audio files and the corresponding transcripts compiled in the blog article [Which automatic transcription service is the most accurate - 2018](https://medium.com/descript/which-automatic-transcription-service-is-the-most-accurate-2018-2e859b23ed19) by Jason Kincaid. We used the 46 audio files and reference transcripts from the Airtable widget in the article. 84 | 85 | For the human transcription benchmark in the paper, we use a subset of 25 examples from this data, whose "Ref ID" values are: 86 | 87 | 2 4 5 8 9 10 12 13 14 16 19 21 23 25 26 28 29 30 33 35 36 37 42 43 45 88 | 89 | ### Earnings-21, Earnings-22 90 | 91 | For these datasets, we used the files available in [the speech-datasets repository](https://github.com/revdotcom/speech-datasets), as of their `202206` version. 92 | 93 | ### CORAAL 94 | 95 | We used the 231 interviews from [CORAAL (v. 2021.07)](https://oraal.uoregon.edu/coraal) and used the full-length interview files and transcripts. 96 | 97 | 98 | ## Multilingual datasets 99 | 100 | ### Multilingual LibriSpeech 101 | 102 | We used the test splits from each language in [the Multilingual LibriSpeech (MLS) corpus](https://www.openslr.org/94/).
103 | 104 | ### Fleurs 105 | 106 | We collected audio files and transcripts using the implementation available as [HuggingFace datasets](https://huggingface.co/datasets/google/fleurs/blob/main/fleurs.py). To use it as a translation dataset, we matched the numerical utterance IDs to find the corresponding transcript in English. 107 | 108 | ### VoxPopuli 109 | 110 | We used the `get_asr_data.py` script from [the official repository](https://github.com/facebookresearch/voxpopuli) to collect the ASR data in 14 languages. 111 | 112 | ### Common Voice 9 113 | 114 | We downloaded the Common Voice Corpus 9 from [the official website](https://commonvoice.mozilla.org/en/datasets). 115 | 116 | ### CoVoST 2 117 | 118 | We collected the `X into English` data using [the official repository](https://github.com/facebookresearch/covost). 119 | -------------------------------------------------------------------------------- /model-card.md: -------------------------------------------------------------------------------- 1 | # Model Card: Whisper 2 | 3 | This is the official codebase for running the automatic speech recognition (ASR) models (Whisper models) trained and released by OpenAI. 4 | 5 | Following [Model Cards for Model Reporting (Mitchell et al.)](https://arxiv.org/abs/1810.03993), we're providing some information about the automatic speech recognition model. More information on how these models were trained and evaluated can be found [in the paper](https://arxiv.org/abs/2212.04356). 6 | 7 | 8 | ## Model Details 9 | 10 | The Whisper models are trained for speech recognition and translation tasks, capable of transcribing speech audio into text in the language in which it is spoken (ASR), as well as translating it into English (speech translation). Researchers at OpenAI developed the models to study the robustness of speech processing systems trained under large-scale weak supervision. There are 10 models of different sizes and capabilities (six sizes, four of which also have English-only variants), summarized in the following table.
11 | 12 | | Size | Parameters | English-only model | Multilingual model | 13 | |:------:|:----------:|:------------------:|:------------------:| 14 | | tiny | 39 M | ✓ | ✓ | 15 | | base | 74 M | ✓ | ✓ | 16 | | small | 244 M | ✓ | ✓ | 17 | | medium | 769 M | ✓ | ✓ | 18 | | large | 1550 M | | ✓ | 19 | | turbo | 798 M | | ✓ | 20 | 21 | In December 2022, we [released an improved large model named `large-v2`](https://github.com/openai/whisper/discussions/661), and `large-v3` in November 2023. 22 | Additionally, we added a `turbo` model in September 2024, which is optimized for inference speed. 23 | 24 | 25 | ### Release date 26 | 27 | September 2022 (original series), December 2022 (`large-v2`), November 2023 (`large-v3`), September 2024 (`large-v3-turbo`) 28 | 29 | ### Model type 30 | 31 | Sequence-to-sequence ASR (automatic speech recognition) and speech translation model 32 | 33 | ### Paper & samples 34 | 35 | [Paper](https://arxiv.org/abs/2212.04356) / [Blog](https://openai.com/blog/whisper) 36 | 37 | 38 | ## Model Use 39 | 40 | ### Evaluated Use 41 | 42 | The primary intended users of these models are AI researchers studying the robustness, generalization, capabilities, biases, and constraints of the current model. However, Whisper is also potentially quite useful as an ASR solution for developers, especially for English speech recognition. We recognize that once models are released, it is impossible to restrict access to only “intended” uses or to draw reasonable guidelines around what is or is not research. 43 | 44 | The models are primarily trained and evaluated on ASR and speech translation to English tasks. They show strong ASR results in ~10 languages. They may exhibit additional capabilities, particularly if fine-tuned on certain tasks like voice activity detection, speaker classification, or speaker diarization, but have not been robustly evaluated in these areas.
We strongly recommend that users perform robust evaluations of the models in a particular context and domain before deploying them. 45 | 46 | In particular, we caution against using Whisper models to transcribe recordings of individuals taken without their consent or purporting to use these models for any kind of subjective classification. We recommend against use in high-risk domains like decision-making contexts, where flaws in accuracy can lead to pronounced flaws in outcomes. The models are intended to transcribe and translate speech; use of the model for classification is not only not evaluated but also not appropriate, particularly to infer human attributes. 47 | 48 | 49 | ## Training Data 50 | 51 | The models are trained on 680,000 hours of audio and the corresponding transcripts collected from the internet. 65% of this data (or 438,000 hours) represents English-language audio and matched English transcripts, roughly 18% (or 126,000 hours) represents non-English audio and English transcripts, while the final 17% (or 117,000 hours) represents non-English audio and the corresponding transcripts. This non-English data represents 98 different languages. 52 | 53 | As discussed in [the accompanying paper](https://arxiv.org/abs/2212.04356), we see that performance on transcription in a given language is directly correlated with the amount of training data we employ in that language. 54 | 55 | 56 | ## Performance and Limitations 57 | 58 | Our studies show that, over many existing ASR systems, the models exhibit improved robustness to accents, background noise, and technical language, as well as zero-shot translation from multiple languages into English; and that accuracy on speech recognition and translation is near the state-of-the-art level. 59 | 60 | However, because the models are trained in a weakly supervised manner using large-scale noisy data, the predictions may include texts that are not actually spoken in the audio input (i.e. hallucination).
We hypothesize that this happens because, given their general knowledge of language, the models combine trying to predict the next word in audio with trying to transcribe the audio itself. 61 | 62 | Our models perform unevenly across languages, and we observe lower accuracy on low-resource and/or low-discoverability languages or languages where we have less training data. The models also exhibit disparate performance on different accents and dialects of particular languages, which may include a higher word error rate across speakers of different genders, races, ages, or other demographic criteria. Our full evaluation results are presented in [the paper accompanying this release](https://arxiv.org/abs/2212.04356). 63 | 64 | In addition, the sequence-to-sequence architecture of the model makes it prone to generating repetitive texts, which can be mitigated to some degree by beam search and temperature scheduling but not perfectly. Further analysis of these limitations is provided in [the paper](https://arxiv.org/abs/2212.04356). It is likely that this behavior and hallucinations may be worse in lower-resource and/or lower-discoverability languages. 65 | 66 | 67 | ## Broader Implications 68 | 69 | We anticipate that Whisper models’ transcription capabilities may be used for improving accessibility tools. While Whisper models cannot be used for real-time transcription out of the box, their speed and size suggest that others may be able to build applications on top of them that allow for near-real-time speech recognition and translation. The real value of beneficial applications built on top of Whisper models suggests that the disparate performance of these models may have real economic implications. 70 | 71 | There are also potential dual-use concerns that come with releasing Whisper.
While we hope the technology will be used primarily for beneficial purposes, making ASR technology more accessible could enable more actors to build capable surveillance technologies or scale up existing surveillance efforts, as the speed and accuracy allow for affordable automatic transcription and translation of large volumes of audio communication. Moreover, these models may have some capabilities to recognize specific individuals out of the box, which in turn presents safety concerns related both to dual use and disparate performance. In practice, we expect that the cost of transcription is not the limiting factor of scaling up surveillance projects. 72 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | build-backend = "setuptools.build_meta" 3 | 4 | requires = [ "setuptools>=61.2" ] 5 | 6 | [project] 7 | name = "openai-whisper" 8 | description = "Robust Speech Recognition via Large-Scale Weak Supervision" 9 | readme.content-type = "text/markdown" 10 | readme.file = "README.md" 11 | license = { text = "MIT" } 12 | authors = [ { name = "OpenAI" } ] 13 | requires-python = ">=3.8" 14 | classifiers = [ 15 | "Programming Language :: Python :: 3 :: Only", 16 | "Programming Language :: Python :: 3.8", 17 | "Programming Language :: Python :: 3.9", 18 | "Programming Language :: Python :: 3.10", 19 | "Programming Language :: Python :: 3.11", 20 | "Programming Language :: Python :: 3.12", 21 | "Programming Language :: Python :: 3.13", 22 | ] 23 | dynamic = [ "version" ] 24 | dependencies = [ 25 | "more-itertools", 26 | "numba", 27 | "numpy", 28 | "tiktoken", 29 | "torch", 30 | "tqdm", 31 | "triton>=2; (platform_machine=='x86_64' and sys_platform=='linux') or sys_platform=='linux2'", 32 | ] 33 | optional-dependencies.dev = [ "black", "flake8", "isort", "pytest", "scipy" ] 34 | urls = { Homepage = 
"https://github.com/openai/whisper" } 35 | scripts.whisper = "whisper.transcribe:cli" 36 | 37 | [tool.setuptools] 38 | py-modules = [ "whisper" ] 39 | include-package-data = true 40 | 41 | [tool.setuptools.dynamic] 42 | version = { attr = "whisper.version.__version__" } 43 | 44 | [tool.setuptools.packages.find] 45 | exclude = [ "tests*" ] 46 | namespaces = false 47 | 48 | [tool.black] 49 | 50 | [tool.isort] 51 | profile = "black" 52 | include_trailing_comma = true 53 | line_length = 88 54 | multi_line_output = 3 55 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numba 2 | numpy 3 | torch 4 | tqdm 5 | more-itertools 6 | tiktoken 7 | triton>=2.0.0;platform_machine=="x86_64" and sys_platform=="linux" or sys_platform=="linux2" 8 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import random as rand 2 | 3 | import numpy 4 | import pytest 5 | 6 | 7 | def pytest_configure(config): 8 | config.addinivalue_line("markers", "requires_cuda") 9 | 10 | 11 | @pytest.fixture 12 | def random(): 13 | rand.seed(42) 14 | numpy.random.seed(42) 15 | -------------------------------------------------------------------------------- /tests/jfk.flac: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/whisper/dd985ac4b90cafeef8712f2998d62c59c3e62d22/tests/jfk.flac -------------------------------------------------------------------------------- /tests/test_audio.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | 3 | import numpy as np 4 | 5 | from whisper.audio import SAMPLE_RATE, load_audio, log_mel_spectrogram 6 | 7 | 8 | def test_audio(): 9 | audio_path = 
os.path.join(os.path.dirname(__file__), "jfk.flac") 10 | audio = load_audio(audio_path) 11 | assert audio.ndim == 1 12 | assert SAMPLE_RATE * 10 < audio.shape[0] < SAMPLE_RATE * 12 13 | assert 0 < audio.std() < 1 14 | 15 | mel_from_audio = log_mel_spectrogram(audio) 16 | mel_from_file = log_mel_spectrogram(audio_path) 17 | 18 | assert np.allclose(mel_from_audio, mel_from_file) 19 | assert mel_from_audio.max() - mel_from_audio.min() <= 2.0 20 | -------------------------------------------------------------------------------- /tests/test_normalizer.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from whisper.normalizers import EnglishTextNormalizer 4 | from whisper.normalizers.english import ( 5 | EnglishNumberNormalizer, 6 | EnglishSpellingNormalizer, 7 | ) 8 | 9 | 10 | @pytest.mark.parametrize("std", [EnglishNumberNormalizer(), EnglishTextNormalizer()]) 11 | def test_number_normalizer(std): 12 | assert std("two") == "2" 13 | assert std("thirty one") == "31" 14 | assert std("five twenty four") == "524" 15 | assert std("nineteen ninety nine") == "1999" 16 | assert std("twenty nineteen") == "2019" 17 | 18 | assert std("two point five million") == "2500000" 19 | assert std("four point two billions") == "4200000000s" 20 | assert std("200 thousand") == "200000" 21 | assert std("200 thousand dollars") == "$200000" 22 | assert std("$20 million") == "$20000000" 23 | assert std("€52.4 million") == "€52400000" 24 | assert std("£77 thousands") == "£77000s" 25 | 26 | assert std("two double o eight") == "2008" 27 | 28 | assert std("three thousand twenty nine") == "3029" 29 | assert std("forty three thousand two hundred sixty") == "43260" 30 | assert std("forty three thousand two hundred and sixty") == "43260" 31 | 32 | assert std("nineteen fifties") == "1950s" 33 | assert std("thirty first") == "31st" 34 | assert std("thirty three thousand and three hundred and thirty third") == "33333rd" 35 | 36 | assert std("three 
billion") == "3000000000" 37 | assert std("millions") == "1000000s" 38 | 39 | assert std("july third twenty twenty") == "july 3rd 2020" 40 | assert std("august twenty sixth twenty twenty one") == "august 26th 2021" 41 | assert std("3 14") == "3 14" 42 | assert std("3.14") == "3.14" 43 | assert std("3 point 2") == "3.2" 44 | assert std("3 point 14") == "3.14" 45 | assert std("fourteen point 4") == "14.4" 46 | assert std("two point two five dollars") == "$2.25" 47 | assert std("two hundred million dollars") == "$200000000" 48 | assert std("$20.1 million") == "$20100000" 49 | 50 | assert std("ninety percent") == "90%" 51 | assert std("seventy six per cent") == "76%" 52 | 53 | assert std("double oh seven") == "007" 54 | assert std("double zero seven") == "007" 55 | assert std("nine one one") == "911" 56 | assert std("nine double one") == "911" 57 | assert std("one triple oh one") == "10001" 58 | 59 | assert std("two thousandth") == "2000th" 60 | assert std("thirty two thousandth") == "32000th" 61 | 62 | assert std("minus 500") == "-500" 63 | assert std("positive twenty thousand") == "+20000" 64 | 65 | assert std("two dollars and seventy cents") == "$2.70" 66 | assert std("3 cents") == "¢3" 67 | assert std("$0.36") == "¢36" 68 | assert std("three euros and sixty five cents") == "€3.65" 69 | 70 | assert std("three and a half million") == "3500000" 71 | assert std("forty eight and a half dollars") == "$48.5" 72 | assert std("b747") == "b 747" 73 | assert std("10 th") == "10th" 74 | assert std("10th") == "10th" 75 | 76 | 77 | def test_spelling_normalizer(): 78 | std = EnglishSpellingNormalizer() 79 | 80 | assert std("mobilisation") == "mobilization" 81 | assert std("cancelation") == "cancellation" 82 | 83 | 84 | def test_text_normalizer(): 85 | std = EnglishTextNormalizer() 86 | assert std("Let's") == "let us" 87 | assert std("he's like") == "he is like" 88 | assert std("she's been like") == "she has been like" 89 | assert std("10km") == "10 km" 90 | assert std("10mm") == 
"10 mm" 91 | assert std("RC232") == "rc 232" 92 | 93 | assert ( 94 | std("Mr. Park visited Assoc. Prof. Kim Jr.") 95 | == "mister park visited associate professor kim junior" 96 | ) 97 | -------------------------------------------------------------------------------- /tests/test_timing.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | import scipy.ndimage 4 | import torch 5 | 6 | from whisper.timing import dtw_cpu, dtw_cuda, median_filter 7 | 8 | sizes = [ 9 | (10, 20), 10 | (32, 16), 11 | (123, 1500), 12 | (234, 189), 13 | ] 14 | shapes = [ 15 | (10,), 16 | (1, 15), 17 | (4, 5, 345), 18 | (6, 12, 240, 512), 19 | ] 20 | 21 | 22 | @pytest.mark.parametrize("N, M", sizes) 23 | def test_dtw(N: int, M: int): 24 | steps = np.concatenate([np.zeros(N - 1), np.ones(M - 1)]) 25 | np.random.shuffle(steps) 26 | x = np.random.random((N, M)).astype(np.float32) 27 | 28 | i, j, k = 0, 0, 0 29 | trace = [] 30 | while True: 31 | x[i, j] -= 1 32 | trace.append((i, j)) 33 | 34 | if k == len(steps): 35 | break 36 | 37 | if k + 1 < len(steps) and steps[k] != steps[k + 1]: 38 | i += 1 39 | j += 1 40 | k += 2 41 | continue 42 | 43 | if steps[k] == 0: 44 | i += 1 45 | if steps[k] == 1: 46 | j += 1 47 | k += 1 48 | 49 | trace = np.array(trace).T 50 | dtw_trace = dtw_cpu(x) 51 | 52 | assert np.allclose(trace, dtw_trace) 53 | 54 | 55 | @pytest.mark.requires_cuda 56 | @pytest.mark.parametrize("N, M", sizes) 57 | def test_dtw_cuda_equivalence(N: int, M: int): 58 | x_numpy = np.random.randn(N, M).astype(np.float32) 59 | x_cuda = torch.from_numpy(x_numpy).cuda() 60 | 61 | trace_cpu = dtw_cpu(x_numpy) 62 | trace_cuda = dtw_cuda(x_cuda) 63 | 64 | assert np.allclose(trace_cpu, trace_cuda) 65 | 66 | 67 | @pytest.mark.parametrize("shape", shapes) 68 | def test_median_filter(shape): 69 | x = torch.randn(*shape) 70 | 71 | for filter_width in [3, 5, 7, 13]: 72 | filtered = median_filter(x, filter_width) 73 | 74 | # using 
np.pad to reflect-pad, because Scipy's behavior is different near the edges. 75 | pad_width = filter_width // 2 76 | padded_x = np.pad( 77 | x, [(0, 0)] * (x.ndim - 1) + [(pad_width, pad_width)], mode="reflect" 78 | ) 79 | scipy_filtered = scipy.ndimage.median_filter( 80 | padded_x, [1] * (x.ndim - 1) + [filter_width] 81 | ) 82 | scipy_filtered = scipy_filtered[..., pad_width:-pad_width] 83 | 84 | assert np.allclose(filtered, scipy_filtered) 85 | 86 | 87 | @pytest.mark.requires_cuda 88 | @pytest.mark.parametrize("shape", shapes) 89 | def test_median_filter_equivalence(shape): 90 | x = torch.randn(*shape) 91 | 92 | for filter_width in [3, 5, 7, 13]: 93 | filtered_cpu = median_filter(x, filter_width) 94 | filtered_gpu = median_filter(x.cuda(), filter_width).cpu() 95 | 96 | assert np.allclose(filtered_cpu, filtered_gpu) 97 | -------------------------------------------------------------------------------- /tests/test_tokenizer.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from whisper.tokenizer import get_tokenizer 4 | 5 | 6 | @pytest.mark.parametrize("multilingual", [True, False]) 7 | def test_tokenizer(multilingual): 8 | tokenizer = get_tokenizer(multilingual=False) 9 | assert tokenizer.sot in tokenizer.sot_sequence 10 | assert len(tokenizer.all_language_codes) == len(tokenizer.all_language_tokens) 11 | assert all(c < tokenizer.timestamp_begin for c in tokenizer.all_language_tokens) 12 | 13 | 14 | def test_multilingual_tokenizer(): 15 | gpt2_tokenizer = get_tokenizer(multilingual=False) 16 | multilingual_tokenizer = get_tokenizer(multilingual=True) 17 | 18 | text = "다람쥐 헌 쳇바퀴에 타고파" 19 | gpt2_tokens = gpt2_tokenizer.encode(text) 20 | multilingual_tokens = multilingual_tokenizer.encode(text) 21 | 22 | assert gpt2_tokenizer.decode(gpt2_tokens) == text 23 | assert multilingual_tokenizer.decode(multilingual_tokens) == text 24 | assert len(gpt2_tokens) > len(multilingual_tokens) 25 | 26 | 27 | def 
test_split_on_unicode(): 28 | multilingual_tokenizer = get_tokenizer(multilingual=True) 29 | 30 | tokens = [8404, 871, 287, 6, 246, 526, 3210, 20378] 31 | words, word_tokens = multilingual_tokenizer.split_tokens_on_unicode(tokens) 32 | 33 | assert words == [" elle", " est", " l", "'", "\ufffd", "é", "rit", "oire"] 34 | assert word_tokens == [[8404], [871], [287], [6], [246], [526], [3210], [20378]] 35 | -------------------------------------------------------------------------------- /tests/test_transcribe.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pytest 4 | import torch 5 | 6 | import whisper 7 | from whisper.tokenizer import get_tokenizer 8 | 9 | 10 | @pytest.mark.parametrize("model_name", whisper.available_models()) 11 | def test_transcribe(model_name: str): 12 | device = "cuda" if torch.cuda.is_available() else "cpu" 13 | model = whisper.load_model(model_name).to(device) 14 | audio_path = os.path.join(os.path.dirname(__file__), "jfk.flac") 15 | 16 | language = "en" if model_name.endswith(".en") else None 17 | result = model.transcribe( 18 | audio_path, language=language, temperature=0.0, word_timestamps=True 19 | ) 20 | assert result["language"] == "en" 21 | assert result["text"] == "".join([s["text"] for s in result["segments"]]) 22 | 23 | transcription = result["text"].lower() 24 | assert "my fellow americans" in transcription 25 | assert "your country" in transcription 26 | assert "do for you" in transcription 27 | 28 | tokenizer = get_tokenizer(model.is_multilingual, num_languages=model.num_languages) 29 | all_tokens = [t for s in result["segments"] for t in s["tokens"]] 30 | assert tokenizer.decode(all_tokens) == result["text"] 31 | assert tokenizer.decode_with_timestamps(all_tokens).startswith("<|0.00|>") 32 | 33 | timing_checked = False 34 | for segment in result["segments"]: 35 | for timing in segment["words"]: 36 | assert timing["start"] < timing["end"] 37 | if 
timing["word"].strip(" ,") == "Americans": 38 | assert timing["start"] <= 1.8 39 | assert timing["end"] >= 1.8 40 | timing_checked = True 41 | 42 | assert timing_checked 43 | -------------------------------------------------------------------------------- /whisper/__init__.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import io 3 | import os 4 | import urllib 5 | import warnings 6 | from typing import List, Optional, Union 7 | 8 | import torch 9 | from tqdm import tqdm 10 | 11 | from .audio import load_audio, log_mel_spectrogram, pad_or_trim 12 | from .decoding import DecodingOptions, DecodingResult, decode, detect_language 13 | from .model import ModelDimensions, Whisper 14 | from .transcribe import transcribe 15 | from .version import __version__ 16 | 17 | _MODELS = { 18 | "tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt", 19 | "tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt", 20 | "base.en": "https://openaipublic.azureedge.net/main/whisper/models/25a8566e1d0c1e2231d1c762132cd20e0f96a85d16145c3a00adf5d1ac670ead/base.en.pt", 21 | "base": "https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt", 22 | "small.en": "https://openaipublic.azureedge.net/main/whisper/models/f953ad0fd29cacd07d5a9eda5624af0f6bcf2258be67c92b79389873d91e0872/small.en.pt", 23 | "small": "https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt", 24 | "medium.en": "https://openaipublic.azureedge.net/main/whisper/models/d7440d1dc186f76616474e0ff0b3b6b879abc9d1a4926b7adfa41db2d497ab4f/medium.en.pt", 25 | "medium": 
"https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt", 26 | "large-v1": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large-v1.pt", 27 | "large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt", 28 | "large-v3": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt", 29 | "large": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt", 30 | "large-v3-turbo": "https://openaipublic.azureedge.net/main/whisper/models/aff26ae408abcba5fbf8813c21e62b0941638c5f6eebfb145be0c9839262a19a/large-v3-turbo.pt", 31 | "turbo": "https://openaipublic.azureedge.net/main/whisper/models/aff26ae408abcba5fbf8813c21e62b0941638c5f6eebfb145be0c9839262a19a/large-v3-turbo.pt", 32 | } 33 | 34 | # base85-encoded (n_layers, n_heads) boolean arrays indicating the cross-attention heads that are 35 | # highly correlated to the word-level timing, i.e. the alignment between audio and text tokens. 
36 | _ALIGNMENT_HEADS = { 37 | "tiny.en": b"ABzY8J1N>@0{>%R00Bk>$p{7v037`oCl~+#00", 38 | "tiny": b"ABzY8bu8Lr0{>%RKn9Fp%m@SkK7Kt=7ytkO", 39 | "base.en": b"ABzY8;40c<0{>%RzzG;p*o+Vo09|#PsxSZm00", 40 | "base": b"ABzY8KQ!870{>%RzyTQH3`Q^yNP!>##QT-?_)10{>%RpeA61k&I|OI3I$65C{;;pbCHh0B{qLQ;+}v00", 42 | "small": b"ABzY8DmU6=0{>%Rpa?J`kvJ6qF(V^F86#Xh7JUGMK}P%R7%R7}kK1fFL7w6%<-Pf*t^=N)Qr&0RR9", 45 | "large-v1": b"ABzY8r9j$a0{>%R7#4sLmoOs{s)o3~84-RPdcFk!JR%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj", 47 | "large-v3": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00", 48 | "large": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00", 49 | "large-v3-turbo": b"ABzY8j^C+e0{>%RARaKHP%t(lGR*)0g!tONPyhe`", 50 | "turbo": b"ABzY8j^C+e0{>%RARaKHP%t(lGR*)0g!tONPyhe`", 51 | } 52 | 53 | 54 | def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]: 55 | os.makedirs(root, exist_ok=True) 56 | 57 | expected_sha256 = url.split("/")[-2] 58 | download_target = os.path.join(root, os.path.basename(url)) 59 | 60 | if os.path.exists(download_target) and not os.path.isfile(download_target): 61 | raise RuntimeError(f"{download_target} exists and is not a regular file") 62 | 63 | if os.path.isfile(download_target): 64 | with open(download_target, "rb") as f: 65 | model_bytes = f.read() 66 | if hashlib.sha256(model_bytes).hexdigest() == expected_sha256: 67 | return model_bytes if in_memory else download_target 68 | else: 69 | warnings.warn( 70 | f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file" 71 | ) 72 | 73 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 74 | with tqdm( 75 | total=int(source.info().get("Content-Length")), 76 | ncols=80, 77 | unit="iB", 78 | unit_scale=True, 79 | unit_divisor=1024, 80 | ) as loop: 81 | while True: 82 | buffer = source.read(8192) 83 | if not buffer: 84 | break 85 | 86 | output.write(buffer) 87 | loop.update(len(buffer)) 88 | 89 | 
model_bytes = open(download_target, "rb").read() 90 | if hashlib.sha256(model_bytes).hexdigest() != expected_sha256: 91 | raise RuntimeError( 92 | "Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model." 93 | ) 94 | 95 | return model_bytes if in_memory else download_target 96 | 97 | 98 | def available_models() -> List[str]: 99 | """Returns the names of available models""" 100 | return list(_MODELS.keys()) 101 | 102 | 103 | def load_model( 104 | name: str, 105 | device: Optional[Union[str, torch.device]] = None, 106 | download_root: str = None, 107 | in_memory: bool = False, 108 | ) -> Whisper: 109 | """ 110 | Load a Whisper ASR model 111 | 112 | Parameters 113 | ---------- 114 | name : str 115 | one of the official model names listed by `whisper.available_models()`, or 116 | path to a model checkpoint containing the model dimensions and the model state_dict. 117 | device : Union[str, torch.device] 118 | the PyTorch device to put the model into 119 | download_root: str 120 | path to download the model files; by default, it uses "~/.cache/whisper" 121 | in_memory: bool 122 | whether to preload the model weights into host memory 123 | 124 | Returns 125 | ------- 126 | model : Whisper 127 | The Whisper ASR model instance 128 | """ 129 | 130 | if device is None: 131 | device = "cuda" if torch.cuda.is_available() else "cpu" 132 | if download_root is None: 133 | default = os.path.join(os.path.expanduser("~"), ".cache") 134 | download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper") 135 | 136 | if name in _MODELS: 137 | checkpoint_file = _download(_MODELS[name], download_root, in_memory) 138 | alignment_heads = _ALIGNMENT_HEADS[name] 139 | elif os.path.isfile(name): 140 | checkpoint_file = open(name, "rb").read() if in_memory else name 141 | alignment_heads = None 142 | else: 143 | raise RuntimeError( 144 | f"Model {name} not found; available models = {available_models()}" 145 | ) 146 | 147 | with ( 148 | 
io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb") 149 | ) as fp: 150 | checkpoint = torch.load(fp, map_location=device) 151 | del checkpoint_file 152 | 153 | dims = ModelDimensions(**checkpoint["dims"]) 154 | model = Whisper(dims) 155 | model.load_state_dict(checkpoint["model_state_dict"]) 156 | 157 | if alignment_heads is not None: 158 | model.set_alignment_heads(alignment_heads) 159 | 160 | return model.to(device) 161 | -------------------------------------------------------------------------------- /whisper/__main__.py: -------------------------------------------------------------------------------- 1 | from .transcribe import cli 2 | 3 | cli() 4 | -------------------------------------------------------------------------------- /whisper/assets/mel_filters.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/whisper/dd985ac4b90cafeef8712f2998d62c59c3e62d22/whisper/assets/mel_filters.npz -------------------------------------------------------------------------------- /whisper/audio.py: -------------------------------------------------------------------------------- 1 | import os 2 | from functools import lru_cache 3 | from subprocess import CalledProcessError, run 4 | from typing import Optional, Union 5 | 6 | import numpy as np 7 | import torch 8 | import torch.nn.functional as F 9 | 10 | from .utils import exact_div 11 | 12 | # hard-coded audio hyperparameters 13 | SAMPLE_RATE = 16000 14 | N_FFT = 400 15 | HOP_LENGTH = 160 16 | CHUNK_LENGTH = 30 17 | N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000 samples in a 30-second chunk 18 | N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH) # 3000 frames in a mel spectrogram input 19 | 20 | N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2 # the initial convolutions has stride 2 21 | FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH) # 10ms per audio frame 22 | TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN) # 20ms per audio 
token 23 | 24 | 25 | def load_audio(file: str, sr: int = SAMPLE_RATE): 26 | """ 27 | Open an audio file and read as mono waveform, resampling as necessary 28 | 29 | Parameters 30 | ---------- 31 | file: str 32 | The audio file to open 33 | 34 | sr: int 35 | The sample rate to resample the audio if necessary 36 | 37 | Returns 38 | ------- 39 | A NumPy array containing the audio waveform, in float32 dtype. 40 | """ 41 | 42 | # This launches a subprocess to decode audio while down-mixing 43 | # and resampling as necessary. Requires the ffmpeg CLI in PATH. 44 | # fmt: off 45 | cmd = [ 46 | "ffmpeg", 47 | "-nostdin", 48 | "-threads", "0", 49 | "-i", file, 50 | "-f", "s16le", 51 | "-ac", "1", 52 | "-acodec", "pcm_s16le", 53 | "-ar", str(sr), 54 | "-" 55 | ] 56 | # fmt: on 57 | try: 58 | out = run(cmd, capture_output=True, check=True).stdout 59 | except CalledProcessError as e: 60 | raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e 61 | 62 | return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0 63 | 64 | 65 | def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1): 66 | """ 67 | Pad or trim the audio array to N_SAMPLES, as expected by the encoder. 
68 | """ 69 | if torch.is_tensor(array): 70 | if array.shape[axis] > length: 71 | array = array.index_select( 72 | dim=axis, index=torch.arange(length, device=array.device) 73 | ) 74 | 75 | if array.shape[axis] < length: 76 | pad_widths = [(0, 0)] * array.ndim 77 | pad_widths[axis] = (0, length - array.shape[axis]) 78 | array = F.pad(array, [pad for sizes in pad_widths[::-1] for pad in sizes]) 79 | else: 80 | if array.shape[axis] > length: 81 | array = array.take(indices=range(length), axis=axis) 82 | 83 | if array.shape[axis] < length: 84 | pad_widths = [(0, 0)] * array.ndim 85 | pad_widths[axis] = (0, length - array.shape[axis]) 86 | array = np.pad(array, pad_widths) 87 | 88 | return array 89 | 90 | 91 | @lru_cache(maxsize=None) 92 | def mel_filters(device, n_mels: int) -> torch.Tensor: 93 | """ 94 | load the mel filterbank matrix for projecting STFT into a Mel spectrogram. 95 | Allows decoupling librosa dependency; saved using: 96 | 97 | np.savez_compressed( 98 | "mel_filters.npz", 99 | mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80), 100 | mel_128=librosa.filters.mel(sr=16000, n_fft=400, n_mels=128), 101 | ) 102 | """ 103 | assert n_mels in {80, 128}, f"Unsupported n_mels: {n_mels}" 104 | 105 | filters_path = os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz") 106 | with np.load(filters_path, allow_pickle=False) as f: 107 | return torch.from_numpy(f[f"mel_{n_mels}"]).to(device) 108 | 109 | 110 | def log_mel_spectrogram( 111 | audio: Union[str, np.ndarray, torch.Tensor], 112 | n_mels: int = 80, 113 | padding: int = 0, 114 | device: Optional[Union[str, torch.device]] = None, 115 | ): 116 | """ 117 | Compute the log-Mel spectrogram of 118 | 119 | Parameters 120 | ---------- 121 | audio: Union[str, np.ndarray, torch.Tensor], shape = (*) 122 | The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz 123 | 124 | n_mels: int 125 | The number of Mel-frequency filters, only 80 and 128 are supported 126 | 127 
| padding: int 128 | Number of zero samples to pad to the right 129 | 130 | device: Optional[Union[str, torch.device]] 131 | If given, the audio tensor is moved to this device before STFT 132 | 133 | Returns 134 | ------- 135 | torch.Tensor, shape = (n_mels, n_frames) 136 | A Tensor that contains the Mel spectrogram 137 | """ 138 | if not torch.is_tensor(audio): 139 | if isinstance(audio, str): 140 | audio = load_audio(audio) 141 | audio = torch.from_numpy(audio) 142 | 143 | if device is not None: 144 | audio = audio.to(device) 145 | if padding > 0: 146 | audio = F.pad(audio, (0, padding)) 147 | window = torch.hann_window(N_FFT).to(audio.device) 148 | stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True) 149 | magnitudes = stft[..., :-1].abs() ** 2 150 | 151 | filters = mel_filters(audio.device, n_mels) 152 | mel_spec = filters @ magnitudes 153 | 154 | log_spec = torch.clamp(mel_spec, min=1e-10).log10() 155 | log_spec = torch.maximum(log_spec, log_spec.max() - 8.0) 156 | log_spec = (log_spec + 4.0) / 4.0 157 | return log_spec 158 | -------------------------------------------------------------------------------- /whisper/decoding.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field, replace 2 | from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Tuple, Union 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | from torch import Tensor 8 | from torch.distributions import Categorical 9 | 10 | from .audio import CHUNK_LENGTH 11 | from .tokenizer import Tokenizer, get_tokenizer 12 | from .utils import compression_ratio 13 | 14 | if TYPE_CHECKING: 15 | from .model import Whisper 16 | 17 | 18 | @torch.no_grad() 19 | def detect_language( 20 | model: "Whisper", mel: Tensor, tokenizer: Tokenizer = None 21 | ) -> Tuple[Tensor, List[dict]]: 22 | """ 23 | Detect the spoken language in the audio, and return them as list of 
strings, along with the ids 24 | of the most probable language tokens and the probability distribution over all language tokens. 25 | This is performed outside the main decode loop in order to not interfere with kv-caching. 26 | 27 | Returns 28 | ------- 29 | language_tokens : Tensor, shape = (n_audio,) 30 | ids of the most probable language tokens, which appears after the startoftranscript token. 31 | language_probs : List[Dict[str, float]], length = n_audio 32 | list of dictionaries containing the probability distribution over all languages. 33 | """ 34 | if tokenizer is None: 35 | tokenizer = get_tokenizer( 36 | model.is_multilingual, num_languages=model.num_languages 37 | ) 38 | if ( 39 | tokenizer.language is None 40 | or tokenizer.language_token not in tokenizer.sot_sequence 41 | ): 42 | raise ValueError( 43 | "This model doesn't have language tokens so it can't perform lang id" 44 | ) 45 | 46 | single = mel.ndim == 2 47 | if single: 48 | mel = mel.unsqueeze(0) 49 | 50 | # skip encoder forward pass if already-encoded audio features were given 51 | if mel.shape[-2:] != (model.dims.n_audio_ctx, model.dims.n_audio_state): 52 | mel = model.encoder(mel) 53 | 54 | # forward pass using a single token, startoftranscript 55 | n_audio = mel.shape[0] 56 | x = torch.tensor([[tokenizer.sot]] * n_audio).to(mel.device) # [n_audio, 1] 57 | logits = model.logits(x, mel)[:, 0] 58 | 59 | # collect detected languages; suppress all non-language tokens 60 | mask = torch.ones(logits.shape[-1], dtype=torch.bool) 61 | mask[list(tokenizer.all_language_tokens)] = False 62 | logits[:, mask] = -np.inf 63 | language_tokens = logits.argmax(dim=-1) 64 | language_token_probs = logits.softmax(dim=-1).cpu() 65 | language_probs = [ 66 | { 67 | c: language_token_probs[i, j].item() 68 | for j, c in zip(tokenizer.all_language_tokens, tokenizer.all_language_codes) 69 | } 70 | for i in range(n_audio) 71 | ] 72 | 73 | if single: 74 | language_tokens = language_tokens[0] 75 | language_probs = 
language_probs[0] 76 | 77 | return language_tokens, language_probs 78 | 79 | 80 | @dataclass(frozen=True) 81 | class DecodingOptions: 82 | # whether to perform X->X "transcribe" or X->English "translate" 83 | task: str = "transcribe" 84 | 85 | # language that the audio is in; uses detected language if None 86 | language: Optional[str] = None 87 | 88 | # sampling-related options 89 | temperature: float = 0.0 90 | sample_len: Optional[int] = None # maximum number of tokens to sample 91 | best_of: Optional[int] = None # number of independent sample trajectories, if t > 0 92 | beam_size: Optional[int] = None # number of beams in beam search, if t == 0 93 | patience: Optional[float] = None # patience in beam search (arxiv:2204.05424) 94 | 95 | # "alpha" in Google NMT, or None for length norm, when ranking generations 96 | # to select which to return among the beams or best-of-N samples 97 | length_penalty: Optional[float] = None 98 | 99 | # text or tokens to feed as the prompt or the prefix; for more info: 100 | # https://github.com/openai/whisper/discussions/117#discussioncomment-3727051 101 | prompt: Optional[Union[str, List[int]]] = None # for the previous context 102 | prefix: Optional[Union[str, List[int]]] = None # to prefix the current context 103 | 104 | # list of tokens ids (or comma-separated token ids) to suppress 105 | # "-1" will suppress a set of symbols as defined in `tokenizer.non_speech_tokens()` 106 | suppress_tokens: Optional[Union[str, Iterable[int]]] = "-1" 107 | suppress_blank: bool = True # this will suppress blank outputs 108 | 109 | # timestamp sampling options 110 | without_timestamps: bool = False # use <|notimestamps|> to sample text tokens only 111 | max_initial_timestamp: Optional[float] = 1.0 112 | 113 | # implementation details 114 | fp16: bool = True # use fp16 for most of the calculation 115 | 116 | 117 | @dataclass(frozen=True) 118 | class DecodingResult: 119 | audio_features: Tensor 120 | language: str 121 | language_probs: 
Optional[Dict[str, float]] = None 122 | tokens: List[int] = field(default_factory=list) 123 | text: str = "" 124 | avg_logprob: float = np.nan 125 | no_speech_prob: float = np.nan 126 | temperature: float = np.nan 127 | compression_ratio: float = np.nan 128 | 129 | 130 | class Inference: 131 | def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor: 132 | """Perform a forward pass on the decoder and return per-token logits""" 133 | raise NotImplementedError 134 | 135 | def rearrange_kv_cache(self, source_indices) -> None: 136 | """Update the key-value cache according to the updated beams""" 137 | raise NotImplementedError 138 | 139 | def cleanup_caching(self) -> None: 140 | """Clean up any resources or hooks after decoding is finished""" 141 | pass 142 | 143 | 144 | class PyTorchInference(Inference): 145 | def __init__(self, model: "Whisper", initial_token_length: int): 146 | self.model: "Whisper" = model 147 | self.initial_token_length = initial_token_length 148 | self.kv_cache = {} 149 | self.hooks = [] 150 | 151 | key_modules = [block.attn.key for block in self.model.decoder.blocks] 152 | value_modules = [block.attn.value for block in self.model.decoder.blocks] 153 | self.kv_modules = key_modules + value_modules 154 | 155 | def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor: 156 | if not self.kv_cache: 157 | self.kv_cache, self.hooks = self.model.install_kv_cache_hooks() 158 | 159 | if tokens.shape[-1] > self.initial_token_length: 160 | # only need to use the last token except in the first forward pass 161 | tokens = tokens[:, -1:] 162 | 163 | return self.model.decoder(tokens, audio_features, kv_cache=self.kv_cache) 164 | 165 | def cleanup_caching(self): 166 | for hook in self.hooks: 167 | hook.remove() 168 | 169 | self.kv_cache = {} 170 | self.hooks = [] 171 | 172 | def rearrange_kv_cache(self, source_indices): 173 | if source_indices != list(range(len(source_indices))): 174 | for module in self.kv_modules: 175 | # update the 
key/value cache to contain the selected sequences 176 | self.kv_cache[module] = self.kv_cache[module][source_indices].detach() 177 | 178 | 179 | class SequenceRanker: 180 | def rank( 181 | self, tokens: List[List[Tensor]], sum_logprobs: List[List[float]] 182 | ) -> List[int]: 183 | """ 184 | Given a list of groups of samples and their cumulative log probabilities, 185 | return the indices of the samples in each group to select as the final result 186 | """ 187 | raise NotImplementedError 188 | 189 | 190 | class MaximumLikelihoodRanker(SequenceRanker): 191 | """ 192 | Select the sample with the highest log probabilities, penalized using either 193 | a simple length normalization or Google NMT paper's length penalty 194 | """ 195 | 196 | def __init__(self, length_penalty: Optional[float]): 197 | self.length_penalty = length_penalty 198 | 199 | def rank(self, tokens: List[List[Tensor]], sum_logprobs: List[List[float]]): 200 | def scores(logprobs, lengths): 201 | result = [] 202 | for logprob, length in zip(logprobs, lengths): 203 | if self.length_penalty is None: 204 | penalty = length 205 | else: 206 | # from the Google NMT paper 207 | penalty = ((5 + length) / 6) ** self.length_penalty 208 | result.append(logprob / penalty) 209 | return result 210 | 211 | # get the sequence with the highest score 212 | lengths = [[len(t) for t in s] for s in tokens] 213 | return [np.argmax(scores(p, l)) for p, l in zip(sum_logprobs, lengths)] 214 | 215 | 216 | class TokenDecoder: 217 | def reset(self): 218 | """Initialize any stateful variables for decoding a new sequence""" 219 | 220 | def update( 221 | self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor 222 | ) -> Tuple[Tensor, bool]: 223 | """Specify how to select the next token, based on the current trace and logits 224 | 225 | Parameters 226 | ---------- 227 | tokens : Tensor, shape = (n_batch, current_sequence_length) 228 | all tokens in the context so far, including the prefix and sot_sequence tokens 229 | 230 | logits 
: Tensor, shape = (n_batch, vocab_size) 231 | per-token logits of the probability distribution at the current step 232 | 233 | sum_logprobs : Tensor, shape = (n_batch) 234 | cumulative log probabilities for each sequence 235 | 236 | Returns 237 | ------- 238 | tokens : Tensor, shape = (n_batch, current_sequence_length + 1) 239 | the tokens, appended with the selected next token 240 | 241 | completed : bool 242 | True if all sequences has reached the end of text 243 | 244 | """ 245 | raise NotImplementedError 246 | 247 | def finalize( 248 | self, tokens: Tensor, sum_logprobs: Tensor 249 | ) -> Tuple[Sequence[Sequence[Tensor]], List[List[float]]]: 250 | """Finalize search and return the final candidate sequences 251 | 252 | Parameters 253 | ---------- 254 | tokens : Tensor, shape = (n_audio, n_group, current_sequence_length) 255 | all tokens in the context so far, including the prefix and sot_sequence 256 | 257 | sum_logprobs : Tensor, shape = (n_audio, n_group) 258 | cumulative log probabilities for each sequence 259 | 260 | Returns 261 | ------- 262 | tokens : Sequence[Sequence[Tensor]], length = n_audio 263 | sequence of Tensors containing candidate token sequences, for each audio input 264 | 265 | sum_logprobs : List[List[float]], length = n_audio 266 | sequence of cumulative log probabilities corresponding to the above 267 | 268 | """ 269 | raise NotImplementedError 270 | 271 | 272 | class GreedyDecoder(TokenDecoder): 273 | def __init__(self, temperature: float, eot: int): 274 | self.temperature = temperature 275 | self.eot = eot 276 | 277 | def update( 278 | self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor 279 | ) -> Tuple[Tensor, bool]: 280 | if self.temperature == 0: 281 | next_tokens = logits.argmax(dim=-1) 282 | else: 283 | next_tokens = Categorical(logits=logits / self.temperature).sample() 284 | 285 | logprobs = F.log_softmax(logits.float(), dim=-1) 286 | current_logprobs = logprobs[torch.arange(logprobs.shape[0]), next_tokens] 287 | 
sum_logprobs += current_logprobs * (tokens[:, -1] != self.eot) 288 | 289 | next_tokens[tokens[:, -1] == self.eot] = self.eot 290 | tokens = torch.cat([tokens, next_tokens[:, None]], dim=-1) 291 | 292 | completed = (tokens[:, -1] == self.eot).all() 293 | return tokens, completed 294 | 295 | def finalize(self, tokens: Tensor, sum_logprobs: Tensor): 296 | # make sure each sequence has at least one EOT token at the end 297 | tokens = F.pad(tokens, (0, 1), value=self.eot) 298 | return tokens, sum_logprobs.tolist() 299 | 300 | 301 | class BeamSearchDecoder(TokenDecoder): 302 | def __init__( 303 | self, 304 | beam_size: int, 305 | eot: int, 306 | inference: Inference, 307 | patience: Optional[float] = None, 308 | ): 309 | self.beam_size = beam_size 310 | self.eot = eot 311 | self.inference = inference 312 | self.patience = patience or 1.0 313 | self.max_candidates: int = round(beam_size * self.patience) 314 | self.finished_sequences = None 315 | 316 | assert ( 317 | self.max_candidates > 0 318 | ), f"Invalid beam size ({beam_size}) or patience ({patience})" 319 | 320 | def reset(self): 321 | self.finished_sequences = None 322 | 323 | def update( 324 | self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor 325 | ) -> Tuple[Tensor, bool]: 326 | if tokens.shape[0] % self.beam_size != 0: 327 | raise ValueError(f"{tokens.shape}[0] % {self.beam_size} != 0") 328 | 329 | n_audio = tokens.shape[0] // self.beam_size 330 | if self.finished_sequences is None: # for the first update 331 | self.finished_sequences = [{} for _ in range(n_audio)] 332 | 333 | logprobs = F.log_softmax(logits.float(), dim=-1) 334 | next_tokens, source_indices, finished_sequences = [], [], [] 335 | for i in range(n_audio): 336 | scores, sources, finished = {}, {}, {} 337 | 338 | # STEP 1: calculate the cumulative log probabilities for possible candidates 339 | for j in range(self.beam_size): 340 | idx = i * self.beam_size + j 341 | prefix = tokens[idx].tolist() 342 | for logprob, token in 
zip(*logprobs[idx].topk(self.beam_size + 1)): 343 | new_logprob = (sum_logprobs[idx] + logprob).item() 344 | sequence = tuple(prefix + [token.item()]) 345 | scores[sequence] = new_logprob 346 | sources[sequence] = idx 347 | 348 | # STEP 2: rank the candidates and keep the top beam_size sequences for each audio 349 | saved = 0 350 | for sequence in sorted(scores, key=scores.get, reverse=True): 351 | if sequence[-1] == self.eot: 352 | finished[sequence] = scores[sequence] 353 | else: 354 | sum_logprobs[len(next_tokens)] = scores[sequence] 355 | next_tokens.append(sequence) 356 | source_indices.append(sources[sequence]) 357 | 358 | saved += 1 359 | if saved == self.beam_size: 360 | break 361 | 362 | finished_sequences.append(finished) 363 | 364 | tokens = torch.tensor(next_tokens, device=tokens.device) 365 | self.inference.rearrange_kv_cache(source_indices) 366 | 367 | # add newly finished sequences to self.finished_sequences 368 | assert len(self.finished_sequences) == len(finished_sequences) 369 | for previously_finished, newly_finished in zip( 370 | self.finished_sequences, finished_sequences 371 | ): 372 | for seq in sorted(newly_finished, key=newly_finished.get, reverse=True): 373 | if len(previously_finished) >= self.max_candidates: 374 | break # the candidate list is full 375 | previously_finished[seq] = newly_finished[seq] 376 | 377 | # mark as completed if all audio has enough number of samples 378 | completed = all( 379 | len(sequences) >= self.max_candidates 380 | for sequences in self.finished_sequences 381 | ) 382 | return tokens, completed 383 | 384 | def finalize(self, preceding_tokens: Tensor, sum_logprobs: Tensor): 385 | # collect all finished sequences, including patience, and add unfinished ones if not enough 386 | sum_logprobs = sum_logprobs.cpu() 387 | for i, sequences in enumerate(self.finished_sequences): 388 | if ( 389 | len(sequences) < self.beam_size 390 | ): # when not enough sequences are finished 391 | for j in 
list(np.argsort(sum_logprobs[i]))[::-1]: 392 | sequence = preceding_tokens[i, j].tolist() + [self.eot] 393 | sequences[tuple(sequence)] = sum_logprobs[i][j].item() 394 | if len(sequences) >= self.beam_size: 395 | break 396 | 397 | tokens: List[List[Tensor]] = [ 398 | [torch.tensor(seq) for seq in sequences.keys()] 399 | for sequences in self.finished_sequences 400 | ] 401 | sum_logprobs: List[List[float]] = [ 402 | list(sequences.values()) for sequences in self.finished_sequences 403 | ] 404 | return tokens, sum_logprobs 405 | 406 | 407 | class LogitFilter: 408 | def apply(self, logits: Tensor, tokens: Tensor) -> None: 409 | """Apply any filtering or masking to logits in-place 410 | 411 | Parameters 412 | ---------- 413 | logits : Tensor, shape = (n_batch, vocab_size) 414 | per-token logits of the probability distribution at the current step 415 | 416 | tokens : Tensor, shape = (n_batch, current_sequence_length) 417 | all tokens in the context so far, including the prefix and sot_sequence tokens 418 | 419 | """ 420 | raise NotImplementedError 421 | 422 | 423 | class SuppressBlank(LogitFilter): 424 | def __init__(self, tokenizer: Tokenizer, sample_begin: int): 425 | self.tokenizer = tokenizer 426 | self.sample_begin = sample_begin 427 | 428 | def apply(self, logits: Tensor, tokens: Tensor): 429 | if tokens.shape[1] == self.sample_begin: 430 | logits[:, self.tokenizer.encode(" ") + [self.tokenizer.eot]] = -np.inf 431 | 432 | 433 | class SuppressTokens(LogitFilter): 434 | def __init__(self, suppress_tokens: Sequence[int]): 435 | self.suppress_tokens = list(suppress_tokens) 436 | 437 | def apply(self, logits: Tensor, tokens: Tensor): 438 | logits[:, self.suppress_tokens] = -np.inf 439 | 440 | 441 | class ApplyTimestampRules(LogitFilter): 442 | def __init__( 443 | self, 444 | tokenizer: Tokenizer, 445 | sample_begin: int, 446 | max_initial_timestamp_index: Optional[int], 447 | ): 448 | self.tokenizer = tokenizer 449 | self.sample_begin = sample_begin 450 | 
self.max_initial_timestamp_index = max_initial_timestamp_index 451 | 452 | def apply(self, logits: Tensor, tokens: Tensor): 453 | # suppress <|notimestamps|> which is handled by without_timestamps 454 | if self.tokenizer.no_timestamps is not None: 455 | logits[:, self.tokenizer.no_timestamps] = -np.inf 456 | 457 | # timestamps have to appear in pairs, except directly before EOT; mask logits accordingly 458 | for k in range(tokens.shape[0]): 459 | sampled_tokens = tokens[k, self.sample_begin :] 460 | seq = [t for t in sampled_tokens.tolist()] 461 | last_was_timestamp = ( 462 | len(seq) >= 1 and seq[-1] >= self.tokenizer.timestamp_begin 463 | ) 464 | penultimate_was_timestamp = ( 465 | len(seq) < 2 or seq[-2] >= self.tokenizer.timestamp_begin 466 | ) 467 | 468 | if last_was_timestamp: 469 | if penultimate_was_timestamp: # has to be non-timestamp 470 | logits[k, self.tokenizer.timestamp_begin :] = -np.inf 471 | else: # cannot be normal text tokens 472 | logits[k, : self.tokenizer.eot] = -np.inf 473 | 474 | timestamps = sampled_tokens[ 475 | sampled_tokens.ge(self.tokenizer.timestamp_begin) 476 | ] 477 | if timestamps.numel() > 0: 478 | # timestamps shouldn't decrease; forbid timestamp tokens smaller than the last 479 | # also force each segment to have a nonzero length, to prevent infinite looping 480 | if last_was_timestamp and not penultimate_was_timestamp: 481 | timestamp_last = timestamps[-1] 482 | else: 483 | timestamp_last = timestamps[-1] + 1 484 | logits[k, self.tokenizer.timestamp_begin : timestamp_last] = -np.inf 485 | 486 | if tokens.shape[1] == self.sample_begin: 487 | # suppress generating non-timestamp tokens at the beginning 488 | logits[:, : self.tokenizer.timestamp_begin] = -np.inf 489 | 490 | # apply the `max_initial_timestamp` option 491 | if self.max_initial_timestamp_index is not None: 492 | last_allowed = ( 493 | self.tokenizer.timestamp_begin + self.max_initial_timestamp_index 494 | ) 495 | logits[:, last_allowed + 1 :] = -np.inf 496 | 497 | # 
if sum of probability over timestamps is above any other token, sample timestamp 498 | logprobs = F.log_softmax(logits.float(), dim=-1) 499 | for k in range(tokens.shape[0]): 500 | timestamp_logprob = logprobs[k, self.tokenizer.timestamp_begin :].logsumexp( 501 | dim=-1 502 | ) 503 | max_text_token_logprob = logprobs[k, : self.tokenizer.timestamp_begin].max() 504 | if timestamp_logprob > max_text_token_logprob: 505 | logits[k, : self.tokenizer.timestamp_begin] = -np.inf 506 | 507 | 508 | class DecodingTask: 509 | inference: Inference 510 | sequence_ranker: SequenceRanker 511 | decoder: TokenDecoder 512 | logit_filters: List[LogitFilter] 513 | 514 | def __init__(self, model: "Whisper", options: DecodingOptions): 515 | self.model = model 516 | 517 | language = options.language or "en" 518 | tokenizer = get_tokenizer( 519 | model.is_multilingual, 520 | num_languages=model.num_languages, 521 | language=language, 522 | task=options.task, 523 | ) 524 | self.tokenizer: Tokenizer = tokenizer 525 | self.options: DecodingOptions = self._verify_options(options) 526 | 527 | self.n_group: int = options.beam_size or options.best_of or 1 528 | self.n_ctx: int = model.dims.n_text_ctx 529 | self.sample_len: int = options.sample_len or model.dims.n_text_ctx // 2 530 | 531 | self.sot_sequence: Tuple[int] = tokenizer.sot_sequence 532 | if self.options.without_timestamps: 533 | self.sot_sequence = tokenizer.sot_sequence_including_notimestamps 534 | 535 | self.initial_tokens: Tuple[int] = self._get_initial_tokens() 536 | self.sample_begin: int = len(self.initial_tokens) 537 | self.sot_index: int = self.initial_tokens.index(tokenizer.sot) 538 | 539 | # inference: implements the forward pass through the decoder, including kv caching 540 | self.inference = PyTorchInference(model, len(self.initial_tokens)) 541 | 542 | # sequence ranker: implements how to rank a group of sampled sequences 543 | self.sequence_ranker = MaximumLikelihoodRanker(options.length_penalty) 544 | 545 | # decoder: 
implements how to select the next tokens, given the autoregressive distribution 546 | if options.beam_size is not None: 547 | self.decoder = BeamSearchDecoder( 548 | options.beam_size, tokenizer.eot, self.inference, options.patience 549 | ) 550 | else: 551 | self.decoder = GreedyDecoder(options.temperature, tokenizer.eot) 552 | 553 | # logit filters: applies various rules to suppress or penalize certain tokens 554 | self.logit_filters = [] 555 | if self.options.suppress_blank: 556 | self.logit_filters.append(SuppressBlank(self.tokenizer, self.sample_begin)) 557 | if self.options.suppress_tokens: 558 | self.logit_filters.append(SuppressTokens(self._get_suppress_tokens())) 559 | if not options.without_timestamps: 560 | precision = CHUNK_LENGTH / model.dims.n_audio_ctx # usually 0.02 seconds 561 | max_initial_timestamp_index = None 562 | if options.max_initial_timestamp: 563 | max_initial_timestamp_index = round( 564 | self.options.max_initial_timestamp / precision 565 | ) 566 | self.logit_filters.append( 567 | ApplyTimestampRules( 568 | tokenizer, self.sample_begin, max_initial_timestamp_index 569 | ) 570 | ) 571 | 572 | def _verify_options(self, options: DecodingOptions) -> DecodingOptions: 573 | if options.beam_size is not None and options.best_of is not None: 574 | raise ValueError("beam_size and best_of can't be given together") 575 | if options.temperature == 0: 576 | if options.best_of is not None: 577 | raise ValueError("best_of with greedy sampling (T=0) is not compatible") 578 | if options.patience is not None and options.beam_size is None: 579 | raise ValueError("patience requires beam_size to be given") 580 | if options.length_penalty is not None and not ( 581 | 0 <= options.length_penalty <= 1 582 | ): 583 | raise ValueError("length_penalty (alpha) should be a value between 0 and 1") 584 | 585 | return options 586 | 587 | def _get_initial_tokens(self) -> Tuple[int]: 588 | tokens = list(self.sot_sequence) 589 | 590 | if prefix := self.options.prefix: 591 | 
prefix_tokens = ( 592 | self.tokenizer.encode(" " + prefix.strip()) 593 | if isinstance(prefix, str) 594 | else prefix 595 | ) 596 | if self.sample_len is not None: 597 | max_prefix_len = self.n_ctx // 2 - self.sample_len 598 | prefix_tokens = prefix_tokens[-max_prefix_len:] 599 | tokens = tokens + prefix_tokens 600 | 601 | if prompt := self.options.prompt: 602 | prompt_tokens = ( 603 | self.tokenizer.encode(" " + prompt.strip()) 604 | if isinstance(prompt, str) 605 | else prompt 606 | ) 607 | tokens = ( 608 | [self.tokenizer.sot_prev] 609 | + prompt_tokens[-(self.n_ctx // 2 - 1) :] 610 | + tokens 611 | ) 612 | 613 | return tuple(tokens) 614 | 615 | def _get_suppress_tokens(self) -> Tuple[int]: 616 | suppress_tokens = self.options.suppress_tokens 617 | 618 | if isinstance(suppress_tokens, str): 619 | suppress_tokens = [int(t) for t in suppress_tokens.split(",")] 620 | 621 | if -1 in suppress_tokens: 622 | suppress_tokens = [t for t in suppress_tokens if t >= 0] 623 | suppress_tokens.extend(self.tokenizer.non_speech_tokens) 624 | elif suppress_tokens is None or len(suppress_tokens) == 0: 625 | suppress_tokens = [] # interpret empty string as an empty list 626 | else: 627 | assert isinstance(suppress_tokens, list), "suppress_tokens must be a list" 628 | 629 | suppress_tokens.extend( 630 | [ 631 | self.tokenizer.transcribe, 632 | self.tokenizer.translate, 633 | self.tokenizer.sot, 634 | self.tokenizer.sot_prev, 635 | self.tokenizer.sot_lm, 636 | ] 637 | ) 638 | if self.tokenizer.no_speech is not None: 639 | # no-speech probability is collected separately 640 | suppress_tokens.append(self.tokenizer.no_speech) 641 | 642 | return tuple(sorted(set(suppress_tokens))) 643 | 644 | def _get_audio_features(self, mel: Tensor): 645 | if self.options.fp16: 646 | mel = mel.half() 647 | 648 | if mel.shape[-2:] == ( 649 | self.model.dims.n_audio_ctx, 650 | self.model.dims.n_audio_state, 651 | ): 652 | # encoded audio features are given; skip audio encoding 653 | audio_features = mel 
654 | else: 655 | audio_features = self.model.encoder(mel) 656 | 657 | if audio_features.dtype != ( 658 | torch.float16 if self.options.fp16 else torch.float32 659 | ): 660 | return TypeError( 661 | f"audio_features has an incorrect dtype: {audio_features.dtype}" 662 | ) 663 | 664 | return audio_features 665 | 666 | def _detect_language(self, audio_features: Tensor, tokens: Tensor): 667 | languages = [self.options.language] * audio_features.shape[0] 668 | lang_probs = None 669 | 670 | if self.options.language is None or self.options.task == "lang_id": 671 | lang_tokens, lang_probs = self.model.detect_language( 672 | audio_features, self.tokenizer 673 | ) 674 | languages = [max(probs, key=probs.get) for probs in lang_probs] 675 | if self.options.language is None: 676 | tokens[:, self.sot_index + 1] = lang_tokens # write language tokens 677 | 678 | return languages, lang_probs 679 | 680 | def _main_loop(self, audio_features: Tensor, tokens: Tensor): 681 | n_batch = tokens.shape[0] 682 | sum_logprobs: Tensor = torch.zeros(n_batch, device=audio_features.device) 683 | no_speech_probs = [np.nan] * n_batch 684 | 685 | try: 686 | for i in range(self.sample_len): 687 | logits = self.inference.logits(tokens, audio_features) 688 | 689 | if ( 690 | i == 0 and self.tokenizer.no_speech is not None 691 | ): # save no_speech_probs 692 | probs_at_sot = logits[:, self.sot_index].float().softmax(dim=-1) 693 | no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist() 694 | 695 | # now we need to consider the logits at the last token only 696 | logits = logits[:, -1] 697 | 698 | # apply the logit filters, e.g. 
for suppressing or applying penalty to 699 | for logit_filter in self.logit_filters: 700 | logit_filter.apply(logits, tokens) 701 | 702 | # expand the tokens tensor with the selected next tokens 703 | tokens, completed = self.decoder.update(tokens, logits, sum_logprobs) 704 | 705 | if completed or tokens.shape[-1] > self.n_ctx: 706 | break 707 | finally: 708 | self.inference.cleanup_caching() 709 | 710 | return tokens, sum_logprobs, no_speech_probs 711 | 712 | @torch.no_grad() 713 | def run(self, mel: Tensor) -> List[DecodingResult]: 714 | self.decoder.reset() 715 | tokenizer: Tokenizer = self.tokenizer 716 | n_audio: int = mel.shape[0] 717 | 718 | audio_features: Tensor = self._get_audio_features(mel) # encoder forward pass 719 | tokens: Tensor = torch.tensor([self.initial_tokens]).repeat(n_audio, 1) 720 | 721 | # detect language if requested, overwriting the language token 722 | languages, language_probs = self._detect_language(audio_features, tokens) 723 | if self.options.task == "lang_id": 724 | return [ 725 | DecodingResult( 726 | audio_features=features, language=language, language_probs=probs 727 | ) 728 | for features, language, probs in zip( 729 | audio_features, languages, language_probs 730 | ) 731 | ] 732 | 733 | # repeat text tensors by the group size, for beam search or best-of-n sampling 734 | tokens = tokens.repeat_interleave(self.n_group, dim=0).to(audio_features.device) 735 | 736 | # call the main sampling loop 737 | tokens, sum_logprobs, no_speech_probs = self._main_loop(audio_features, tokens) 738 | 739 | # reshape the tensors to have (n_audio, n_group) as the first two dimensions 740 | audio_features = audio_features[:: self.n_group] 741 | no_speech_probs = no_speech_probs[:: self.n_group] 742 | assert audio_features.shape[0] == len(no_speech_probs) == n_audio 743 | 744 | tokens = tokens.reshape(n_audio, self.n_group, -1) 745 | sum_logprobs = sum_logprobs.reshape(n_audio, self.n_group) 746 | 747 | # get the final candidates for each group, and 
slice between the first sampled token and EOT 748 | tokens, sum_logprobs = self.decoder.finalize(tokens, sum_logprobs) 749 | tokens: List[List[Tensor]] = [ 750 | [t[self.sample_begin : (t == tokenizer.eot).nonzero()[0, 0]] for t in s] 751 | for s in tokens 752 | ] 753 | 754 | # select the top-ranked sample in each group 755 | selected = self.sequence_ranker.rank(tokens, sum_logprobs) 756 | tokens: List[List[int]] = [t[i].tolist() for i, t in zip(selected, tokens)] 757 | texts: List[str] = [tokenizer.decode(t).strip() for t in tokens] 758 | 759 | sum_logprobs: List[float] = [lp[i] for i, lp in zip(selected, sum_logprobs)] 760 | avg_logprobs: List[float] = [ 761 | lp / (len(t) + 1) for t, lp in zip(tokens, sum_logprobs) 762 | ] 763 | 764 | fields = ( 765 | texts, 766 | languages, 767 | tokens, 768 | audio_features, 769 | avg_logprobs, 770 | no_speech_probs, 771 | ) 772 | if len(set(map(len, fields))) != 1: 773 | raise RuntimeError(f"inconsistent result lengths: {list(map(len, fields))}") 774 | 775 | return [ 776 | DecodingResult( 777 | audio_features=features, 778 | language=language, 779 | tokens=tokens, 780 | text=text, 781 | avg_logprob=avg_logprob, 782 | no_speech_prob=no_speech_prob, 783 | temperature=self.options.temperature, 784 | compression_ratio=compression_ratio(text), 785 | ) 786 | for text, language, tokens, features, avg_logprob, no_speech_prob in zip( 787 | *fields 788 | ) 789 | ] 790 | 791 | 792 | @torch.no_grad() 793 | def decode( 794 | model: "Whisper", 795 | mel: Tensor, 796 | options: DecodingOptions = DecodingOptions(), 797 | **kwargs, 798 | ) -> Union[DecodingResult, List[DecodingResult]]: 799 | """ 800 | Performs decoding of 30-second audio segment(s), provided as Mel spectrogram(s). 
801 | 802 | Parameters 803 | ---------- 804 | model: Whisper 805 | the Whisper model instance 806 | 807 | mel: torch.Tensor, shape = (80, 3000) or (*, 80, 3000) 808 | A tensor containing the Mel spectrogram(s) 809 | 810 | options: DecodingOptions 811 | A dataclass that contains all necessary options for decoding 30-second segments 812 | 813 | Returns 814 | ------- 815 | result: Union[DecodingResult, List[DecodingResult]] 816 | The result(s) of decoding contained in `DecodingResult` dataclass instance(s) 817 | """ 818 | if single := mel.ndim == 2: 819 | mel = mel.unsqueeze(0) 820 | 821 | if kwargs: 822 | options = replace(options, **kwargs) 823 | 824 | result = DecodingTask(model, options).run(mel) 825 | 826 | return result[0] if single else result 827 | -------------------------------------------------------------------------------- /whisper/model.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import gzip 3 | from contextlib import contextmanager 4 | from dataclasses import dataclass 5 | from typing import Dict, Iterable, Optional, Tuple 6 | 7 | import numpy as np 8 | import torch 9 | import torch.nn.functional as F 10 | from torch import Tensor, nn 11 | 12 | from .decoding import decode as decode_function 13 | from .decoding import detect_language as detect_language_function 14 | from .transcribe import transcribe as transcribe_function 15 | 16 | try: 17 | from torch.nn.functional import scaled_dot_product_attention 18 | 19 | SDPA_AVAILABLE = True 20 | except (ImportError, RuntimeError, OSError): 21 | scaled_dot_product_attention = None 22 | SDPA_AVAILABLE = False 23 | 24 | 25 | @dataclass 26 | class ModelDimensions: 27 | n_mels: int 28 | n_audio_ctx: int 29 | n_audio_state: int 30 | n_audio_head: int 31 | n_audio_layer: int 32 | n_vocab: int 33 | n_text_ctx: int 34 | n_text_state: int 35 | n_text_head: int 36 | n_text_layer: int 37 | 38 | 39 | class LayerNorm(nn.LayerNorm): 40 | def forward(self, x: 
Tensor) -> Tensor: 41 | return super().forward(x.float()).type(x.dtype) 42 | 43 | 44 | class Linear(nn.Linear): 45 | def forward(self, x: Tensor) -> Tensor: 46 | return F.linear( 47 | x, 48 | self.weight.to(x.dtype), 49 | None if self.bias is None else self.bias.to(x.dtype), 50 | ) 51 | 52 | 53 | class Conv1d(nn.Conv1d): 54 | def _conv_forward( 55 | self, x: Tensor, weight: Tensor, bias: Optional[Tensor] 56 | ) -> Tensor: 57 | return super()._conv_forward( 58 | x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype) 59 | ) 60 | 61 | 62 | def sinusoids(length, channels, max_timescale=10000): 63 | """Returns sinusoids for positional embedding""" 64 | assert channels % 2 == 0 65 | log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1) 66 | inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2)) 67 | scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :] 68 | return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1) 69 | 70 | 71 | @contextmanager 72 | def disable_sdpa(): 73 | prev_state = MultiHeadAttention.use_sdpa 74 | try: 75 | MultiHeadAttention.use_sdpa = False 76 | yield 77 | finally: 78 | MultiHeadAttention.use_sdpa = prev_state 79 | 80 | 81 | class MultiHeadAttention(nn.Module): 82 | use_sdpa = True 83 | 84 | def __init__(self, n_state: int, n_head: int): 85 | super().__init__() 86 | self.n_head = n_head 87 | self.query = Linear(n_state, n_state) 88 | self.key = Linear(n_state, n_state, bias=False) 89 | self.value = Linear(n_state, n_state) 90 | self.out = Linear(n_state, n_state) 91 | 92 | def forward( 93 | self, 94 | x: Tensor, 95 | xa: Optional[Tensor] = None, 96 | mask: Optional[Tensor] = None, 97 | kv_cache: Optional[dict] = None, 98 | ): 99 | q = self.query(x) 100 | 101 | if kv_cache is None or xa is None or self.key not in kv_cache: 102 | # hooks, if installed (i.e. 
kv_cache is not None), will prepend the cached kv tensors; 103 | # otherwise, perform key/value projections for self- or cross-attention as usual. 104 | k = self.key(x if xa is None else xa) 105 | v = self.value(x if xa is None else xa) 106 | else: 107 | # for cross-attention, calculate keys and values once and reuse in subsequent calls. 108 | k = kv_cache[self.key] 109 | v = kv_cache[self.value] 110 | 111 | wv, qk = self.qkv_attention(q, k, v, mask) 112 | return self.out(wv), qk 113 | 114 | def qkv_attention( 115 | self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None 116 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: 117 | n_batch, n_ctx, n_state = q.shape 118 | scale = (n_state // self.n_head) ** -0.25 119 | q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) 120 | k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) 121 | v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) 122 | 123 | if SDPA_AVAILABLE and MultiHeadAttention.use_sdpa: 124 | a = scaled_dot_product_attention( 125 | q, k, v, is_causal=mask is not None and n_ctx > 1 126 | ) 127 | out = a.permute(0, 2, 1, 3).flatten(start_dim=2) 128 | qk = None 129 | else: 130 | qk = (q * scale) @ (k * scale).transpose(-1, -2) 131 | if mask is not None: 132 | qk = qk + mask[:n_ctx, :n_ctx] 133 | qk = qk.float() 134 | 135 | w = F.softmax(qk, dim=-1).to(q.dtype) 136 | out = (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2) 137 | qk = qk.detach() 138 | 139 | return out, qk 140 | 141 | 142 | class ResidualAttentionBlock(nn.Module): 143 | def __init__(self, n_state: int, n_head: int, cross_attention: bool = False): 144 | super().__init__() 145 | 146 | self.attn = MultiHeadAttention(n_state, n_head) 147 | self.attn_ln = LayerNorm(n_state) 148 | 149 | self.cross_attn = ( 150 | MultiHeadAttention(n_state, n_head) if cross_attention else None 151 | ) 152 | self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None 153 | 154 | n_mlp = n_state * 4 155 | self.mlp = 
nn.Sequential( 156 | Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state) 157 | ) 158 | self.mlp_ln = LayerNorm(n_state) 159 | 160 | def forward( 161 | self, 162 | x: Tensor, 163 | xa: Optional[Tensor] = None, 164 | mask: Optional[Tensor] = None, 165 | kv_cache: Optional[dict] = None, 166 | ): 167 | x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0] 168 | if self.cross_attn: 169 | x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0] 170 | x = x + self.mlp(self.mlp_ln(x)) 171 | return x 172 | 173 | 174 | class AudioEncoder(nn.Module): 175 | def __init__( 176 | self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int 177 | ): 178 | super().__init__() 179 | self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1) 180 | self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1) 181 | self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state)) 182 | 183 | self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList( 184 | [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)] 185 | ) 186 | self.ln_post = LayerNorm(n_state) 187 | 188 | def forward(self, x: Tensor): 189 | """ 190 | x : torch.Tensor, shape = (batch_size, n_mels, n_ctx) 191 | the mel spectrogram of the audio 192 | """ 193 | x = F.gelu(self.conv1(x)) 194 | x = F.gelu(self.conv2(x)) 195 | x = x.permute(0, 2, 1) 196 | 197 | assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape" 198 | x = (x + self.positional_embedding).to(x.dtype) 199 | 200 | for block in self.blocks: 201 | x = block(x) 202 | 203 | x = self.ln_post(x) 204 | return x 205 | 206 | 207 | class TextDecoder(nn.Module): 208 | def __init__( 209 | self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int 210 | ): 211 | super().__init__() 212 | 213 | self.token_embedding = nn.Embedding(n_vocab, n_state) 214 | self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state)) 215 | 216 | self.blocks: 
Iterable[ResidualAttentionBlock] = nn.ModuleList( 217 | [ 218 | ResidualAttentionBlock(n_state, n_head, cross_attention=True) 219 | for _ in range(n_layer) 220 | ] 221 | ) 222 | self.ln = LayerNorm(n_state) 223 | 224 | mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1) 225 | self.register_buffer("mask", mask, persistent=False) 226 | 227 | def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None): 228 | """ 229 | x : torch.LongTensor, shape = (batch_size, <= n_ctx) 230 | the text tokens 231 | xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state) 232 | the encoded audio features to be attended on 233 | """ 234 | offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0 235 | x = ( 236 | self.token_embedding(x) 237 | + self.positional_embedding[offset : offset + x.shape[-1]] 238 | ) 239 | x = x.to(xa.dtype) 240 | 241 | for block in self.blocks: 242 | x = block(x, xa, mask=self.mask, kv_cache=kv_cache) 243 | 244 | x = self.ln(x) 245 | logits = ( 246 | x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1) 247 | ).float() 248 | 249 | return logits 250 | 251 | 252 | class Whisper(nn.Module): 253 | def __init__(self, dims: ModelDimensions): 254 | super().__init__() 255 | self.dims = dims 256 | self.encoder = AudioEncoder( 257 | self.dims.n_mels, 258 | self.dims.n_audio_ctx, 259 | self.dims.n_audio_state, 260 | self.dims.n_audio_head, 261 | self.dims.n_audio_layer, 262 | ) 263 | self.decoder = TextDecoder( 264 | self.dims.n_vocab, 265 | self.dims.n_text_ctx, 266 | self.dims.n_text_state, 267 | self.dims.n_text_head, 268 | self.dims.n_text_layer, 269 | ) 270 | # use the last half among the decoder layers for time alignment by default; 271 | # to use a specific set of heads, see `set_alignment_heads()` below. 
272 | all_heads = torch.zeros( 273 | self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool 274 | ) 275 | all_heads[self.dims.n_text_layer // 2 :] = True 276 | self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False) 277 | 278 | def set_alignment_heads(self, dump: bytes): 279 | array = np.frombuffer( 280 | gzip.decompress(base64.b85decode(dump)), dtype=bool 281 | ).copy() 282 | mask = torch.from_numpy(array).reshape( 283 | self.dims.n_text_layer, self.dims.n_text_head 284 | ) 285 | self.register_buffer("alignment_heads", mask.to_sparse(), persistent=False) 286 | 287 | def embed_audio(self, mel: torch.Tensor): 288 | return self.encoder(mel) 289 | 290 | def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor): 291 | return self.decoder(tokens, audio_features) 292 | 293 | def forward( 294 | self, mel: torch.Tensor, tokens: torch.Tensor 295 | ) -> Dict[str, torch.Tensor]: 296 | return self.decoder(tokens, self.encoder(mel)) 297 | 298 | @property 299 | def device(self): 300 | return next(self.parameters()).device 301 | 302 | @property 303 | def is_multilingual(self): 304 | return self.dims.n_vocab >= 51865 305 | 306 | @property 307 | def num_languages(self): 308 | return self.dims.n_vocab - 51765 - int(self.is_multilingual) 309 | 310 | def install_kv_cache_hooks(self, cache: Optional[dict] = None): 311 | """ 312 | The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value 313 | tensors calculated for the previous positions. This method returns a dictionary that stores 314 | all caches, and the necessary hooks for the key and value projection modules that save the 315 | intermediate tensors to be reused during later calculations. 
316 | 317 | Returns 318 | ------- 319 | cache : Dict[nn.Module, torch.Tensor] 320 | A dictionary object mapping the key/value projection modules to its cache 321 | hooks : List[RemovableHandle] 322 | List of PyTorch RemovableHandle objects to stop the hooks to be called 323 | """ 324 | cache = {**cache} if cache is not None else {} 325 | hooks = [] 326 | 327 | def save_to_cache(module, _, output): 328 | if module not in cache or output.shape[1] > self.dims.n_text_ctx: 329 | # save as-is, for the first token or cross attention 330 | cache[module] = output 331 | else: 332 | cache[module] = torch.cat([cache[module], output], dim=1).detach() 333 | return cache[module] 334 | 335 | def install_hooks(layer: nn.Module): 336 | if isinstance(layer, MultiHeadAttention): 337 | hooks.append(layer.key.register_forward_hook(save_to_cache)) 338 | hooks.append(layer.value.register_forward_hook(save_to_cache)) 339 | 340 | self.decoder.apply(install_hooks) 341 | return cache, hooks 342 | 343 | detect_language = detect_language_function 344 | transcribe = transcribe_function 345 | decode = decode_function 346 | -------------------------------------------------------------------------------- /whisper/normalizers/__init__.py: -------------------------------------------------------------------------------- 1 | from .basic import BasicTextNormalizer as BasicTextNormalizer 2 | from .english import EnglishTextNormalizer as EnglishTextNormalizer 3 | -------------------------------------------------------------------------------- /whisper/normalizers/basic.py: -------------------------------------------------------------------------------- 1 | import re 2 | import unicodedata 3 | 4 | import regex 5 | 6 | # non-ASCII letters that are not separated by "NFKD" normalization 7 | ADDITIONAL_DIACRITICS = { 8 | "œ": "oe", 9 | "Œ": "OE", 10 | "ø": "o", 11 | "Ø": "O", 12 | "æ": "ae", 13 | "Æ": "AE", 14 | "ß": "ss", 15 | "ẞ": "SS", 16 | "đ": "d", 17 | "Đ": "D", 18 | "ð": "d", 19 | "Ð": "D", 20 | "þ": 
"th", 21 | "Þ": "th", 22 | "ł": "l", 23 | "Ł": "L", 24 | } 25 | 26 | 27 | def remove_symbols_and_diacritics(s: str, keep=""): 28 | """ 29 | Replace any other markers, symbols, and punctuations with a space, 30 | and drop any diacritics (category 'Mn' and some manual mappings) 31 | """ 32 | return "".join( 33 | ( 34 | c 35 | if c in keep 36 | else ( 37 | ADDITIONAL_DIACRITICS[c] 38 | if c in ADDITIONAL_DIACRITICS 39 | else ( 40 | "" 41 | if unicodedata.category(c) == "Mn" 42 | else " " if unicodedata.category(c)[0] in "MSP" else c 43 | ) 44 | ) 45 | ) 46 | for c in unicodedata.normalize("NFKD", s) 47 | ) 48 | 49 | 50 | def remove_symbols(s: str): 51 | """ 52 | Replace any other markers, symbols, punctuations with a space, keeping diacritics 53 | """ 54 | return "".join( 55 | " " if unicodedata.category(c)[0] in "MSP" else c 56 | for c in unicodedata.normalize("NFKC", s) 57 | ) 58 | 59 | 60 | class BasicTextNormalizer: 61 | def __init__(self, remove_diacritics: bool = False, split_letters: bool = False): 62 | self.clean = ( 63 | remove_symbols_and_diacritics if remove_diacritics else remove_symbols 64 | ) 65 | self.split_letters = split_letters 66 | 67 | def __call__(self, s: str): 68 | s = s.lower() 69 | s = re.sub(r"[<\[][^>\]]*[>\]]", "", s) # remove words between brackets 70 | s = re.sub(r"\(([^)]+?)\)", "", s) # remove words between parenthesis 71 | s = self.clean(s).lower() 72 | 73 | if self.split_letters: 74 | s = " ".join(regex.findall(r"\X", s, regex.U)) 75 | 76 | s = re.sub( 77 | r"\s+", " ", s 78 | ) # replace any successive whitespace characters with a space 79 | 80 | return s 81 | -------------------------------------------------------------------------------- /whisper/normalizers/english.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import re 4 | from fractions import Fraction 5 | from typing import Iterator, List, Match, Optional, Union 6 | 7 | from more_itertools import windowed 
8 | 9 | from .basic import remove_symbols_and_diacritics 10 | 11 | 12 | class EnglishNumberNormalizer: 13 | """ 14 | Convert any spelled-out numbers into arabic numbers, while handling: 15 | 16 | - remove any commas 17 | - keep the suffixes such as: `1960s`, `274th`, `32nd`, etc. 18 | - spell out currency symbols after the number. e.g. `$20 million` -> `20000000 dollars` 19 | - spell out `one` and `ones` 20 | - interpret successive single-digit numbers as nominal: `one oh one` -> `101` 21 | """ 22 | 23 | def __init__(self): 24 | super().__init__() 25 | 26 | self.zeros = {"o", "oh", "zero"} 27 | self.ones = { 28 | name: i 29 | for i, name in enumerate( 30 | [ 31 | "one", 32 | "two", 33 | "three", 34 | "four", 35 | "five", 36 | "six", 37 | "seven", 38 | "eight", 39 | "nine", 40 | "ten", 41 | "eleven", 42 | "twelve", 43 | "thirteen", 44 | "fourteen", 45 | "fifteen", 46 | "sixteen", 47 | "seventeen", 48 | "eighteen", 49 | "nineteen", 50 | ], 51 | start=1, 52 | ) 53 | } 54 | self.ones_plural = { 55 | "sixes" if name == "six" else name + "s": (value, "s") 56 | for name, value in self.ones.items() 57 | } 58 | self.ones_ordinal = { 59 | "zeroth": (0, "th"), 60 | "first": (1, "st"), 61 | "second": (2, "nd"), 62 | "third": (3, "rd"), 63 | "fifth": (5, "th"), 64 | "twelfth": (12, "th"), 65 | **{ 66 | name + ("h" if name.endswith("t") else "th"): (value, "th") 67 | for name, value in self.ones.items() 68 | if value > 3 and value != 5 and value != 12 69 | }, 70 | } 71 | self.ones_suffixed = {**self.ones_plural, **self.ones_ordinal} 72 | 73 | self.tens = { 74 | "twenty": 20, 75 | "thirty": 30, 76 | "forty": 40, 77 | "fifty": 50, 78 | "sixty": 60, 79 | "seventy": 70, 80 | "eighty": 80, 81 | "ninety": 90, 82 | } 83 | self.tens_plural = { 84 | name.replace("y", "ies"): (value, "s") for name, value in self.tens.items() 85 | } 86 | self.tens_ordinal = { 87 | name.replace("y", "ieth"): (value, "th") 88 | for name, value in self.tens.items() 89 | } 90 | self.tens_suffixed = 
{**self.tens_plural, **self.tens_ordinal} 91 | 92 | self.multipliers = { 93 | "hundred": 100, 94 | "thousand": 1_000, 95 | "million": 1_000_000, 96 | "billion": 1_000_000_000, 97 | "trillion": 1_000_000_000_000, 98 | "quadrillion": 1_000_000_000_000_000, 99 | "quintillion": 1_000_000_000_000_000_000, 100 | "sextillion": 1_000_000_000_000_000_000_000, 101 | "septillion": 1_000_000_000_000_000_000_000_000, 102 | "octillion": 1_000_000_000_000_000_000_000_000_000, 103 | "nonillion": 1_000_000_000_000_000_000_000_000_000_000, 104 | "decillion": 1_000_000_000_000_000_000_000_000_000_000_000, 105 | } 106 | self.multipliers_plural = { 107 | name + "s": (value, "s") for name, value in self.multipliers.items() 108 | } 109 | self.multipliers_ordinal = { 110 | name + "th": (value, "th") for name, value in self.multipliers.items() 111 | } 112 | self.multipliers_suffixed = { 113 | **self.multipliers_plural, 114 | **self.multipliers_ordinal, 115 | } 116 | self.decimals = {*self.ones, *self.tens, *self.zeros} 117 | 118 | self.preceding_prefixers = { 119 | "minus": "-", 120 | "negative": "-", 121 | "plus": "+", 122 | "positive": "+", 123 | } 124 | self.following_prefixers = { 125 | "pound": "£", 126 | "pounds": "£", 127 | "euro": "€", 128 | "euros": "€", 129 | "dollar": "$", 130 | "dollars": "$", 131 | "cent": "¢", 132 | "cents": "¢", 133 | } 134 | self.prefixes = set( 135 | list(self.preceding_prefixers.values()) 136 | + list(self.following_prefixers.values()) 137 | ) 138 | self.suffixers = { 139 | "per": {"cent": "%"}, 140 | "percent": "%", 141 | } 142 | self.specials = {"and", "double", "triple", "point"} 143 | 144 | self.words = set( 145 | [ 146 | key 147 | for mapping in [ 148 | self.zeros, 149 | self.ones, 150 | self.ones_suffixed, 151 | self.tens, 152 | self.tens_suffixed, 153 | self.multipliers, 154 | self.multipliers_suffixed, 155 | self.preceding_prefixers, 156 | self.following_prefixers, 157 | self.suffixers, 158 | self.specials, 159 | ] 160 | for key in mapping 161 | ] 
162 | ) 163 | self.literal_words = {"one", "ones"} 164 | 165 | def process_words(self, words: List[str]) -> Iterator[str]: 166 | prefix: Optional[str] = None 167 | value: Optional[Union[str, int]] = None 168 | skip = False 169 | 170 | def to_fraction(s: str): 171 | try: 172 | return Fraction(s) 173 | except ValueError: 174 | return None 175 | 176 | def output(result: Union[str, int]): 177 | nonlocal prefix, value 178 | result = str(result) 179 | if prefix is not None: 180 | result = prefix + result 181 | value = None 182 | prefix = None 183 | return result 184 | 185 | if len(words) == 0: 186 | return 187 | 188 | for prev, current, next in windowed([None] + words + [None], 3): 189 | if skip: 190 | skip = False 191 | continue 192 | 193 | next_is_numeric = next is not None and re.match(r"^\d+(\.\d+)?$", next) 194 | has_prefix = current[0] in self.prefixes 195 | current_without_prefix = current[1:] if has_prefix else current 196 | if re.match(r"^\d+(\.\d+)?$", current_without_prefix): 197 | # arabic numbers (potentially with signs and fractions) 198 | f = to_fraction(current_without_prefix) 199 | assert f is not None 200 | if value is not None: 201 | if isinstance(value, str) and value.endswith("."): 202 | # concatenate decimals / ip address components 203 | value = str(value) + str(current) 204 | continue 205 | else: 206 | yield output(value) 207 | 208 | prefix = current[0] if has_prefix else prefix 209 | if f.denominator == 1: 210 | value = f.numerator # store integers as int 211 | else: 212 | value = current_without_prefix 213 | elif current not in self.words: 214 | # non-numeric words 215 | if value is not None: 216 | yield output(value) 217 | yield output(current) 218 | elif current in self.zeros: 219 | value = str(value or "") + "0" 220 | elif current in self.ones: 221 | ones = self.ones[current] 222 | 223 | if value is None: 224 | value = ones 225 | elif isinstance(value, str) or prev in self.ones: 226 | if ( 227 | prev in self.tens and ones < 10 228 | ): # 
replace the last zero with the digit 229 | assert value[-1] == "0" 230 | value = value[:-1] + str(ones) 231 | else: 232 | value = str(value) + str(ones) 233 | elif ones < 10: 234 | if value % 10 == 0: 235 | value += ones 236 | else: 237 | value = str(value) + str(ones) 238 | else: # eleven to nineteen 239 | if value % 100 == 0: 240 | value += ones 241 | else: 242 | value = str(value) + str(ones) 243 | elif current in self.ones_suffixed: 244 | # ordinal or cardinal; yield the number right away 245 | ones, suffix = self.ones_suffixed[current] 246 | if value is None: 247 | yield output(str(ones) + suffix) 248 | elif isinstance(value, str) or prev in self.ones: 249 | if prev in self.tens and ones < 10: 250 | assert value[-1] == "0" 251 | yield output(value[:-1] + str(ones) + suffix) 252 | else: 253 | yield output(str(value) + str(ones) + suffix) 254 | elif ones < 10: 255 | if value % 10 == 0: 256 | yield output(str(value + ones) + suffix) 257 | else: 258 | yield output(str(value) + str(ones) + suffix) 259 | else: # eleven to nineteen 260 | if value % 100 == 0: 261 | yield output(str(value + ones) + suffix) 262 | else: 263 | yield output(str(value) + str(ones) + suffix) 264 | value = None 265 | elif current in self.tens: 266 | tens = self.tens[current] 267 | if value is None: 268 | value = tens 269 | elif isinstance(value, str): 270 | value = str(value) + str(tens) 271 | else: 272 | if value % 100 == 0: 273 | value += tens 274 | else: 275 | value = str(value) + str(tens) 276 | elif current in self.tens_suffixed: 277 | # ordinal or cardinal; yield the number right away 278 | tens, suffix = self.tens_suffixed[current] 279 | if value is None: 280 | yield output(str(tens) + suffix) 281 | elif isinstance(value, str): 282 | yield output(str(value) + str(tens) + suffix) 283 | else: 284 | if value % 100 == 0: 285 | yield output(str(value + tens) + suffix) 286 | else: 287 | yield output(str(value) + str(tens) + suffix) 288 | elif current in self.multipliers: 289 | multiplier = 
self.multipliers[current] 290 | if value is None: 291 | value = multiplier 292 | elif isinstance(value, str) or value == 0: 293 | f = to_fraction(value) 294 | p = f * multiplier if f is not None else None 295 | if f is not None and p.denominator == 1: 296 | value = p.numerator 297 | else: 298 | yield output(value) 299 | value = multiplier 300 | else: 301 | before = value // 1000 * 1000 302 | residual = value % 1000 303 | value = before + residual * multiplier 304 | elif current in self.multipliers_suffixed: 305 | multiplier, suffix = self.multipliers_suffixed[current] 306 | if value is None: 307 | yield output(str(multiplier) + suffix) 308 | elif isinstance(value, str): 309 | f = to_fraction(value) 310 | p = f * multiplier if f is not None else None 311 | if f is not None and p.denominator == 1: 312 | yield output(str(p.numerator) + suffix) 313 | else: 314 | yield output(value) 315 | yield output(str(multiplier) + suffix) 316 | else: # int 317 | before = value // 1000 * 1000 318 | residual = value % 1000 319 | value = before + residual * multiplier 320 | yield output(str(value) + suffix) 321 | value = None 322 | elif current in self.preceding_prefixers: 323 | # apply prefix (positive, minus, etc.) if it precedes a number 324 | if value is not None: 325 | yield output(value) 326 | 327 | if next in self.words or next_is_numeric: 328 | prefix = self.preceding_prefixers[current] 329 | else: 330 | yield output(current) 331 | elif current in self.following_prefixers: 332 | # apply prefix (dollars, cents, etc.) 
only after a number 333 | if value is not None: 334 | prefix = self.following_prefixers[current] 335 | yield output(value) 336 | else: 337 | yield output(current) 338 | elif current in self.suffixers: 339 | # apply suffix symbols (percent -> '%') 340 | if value is not None: 341 | suffix = self.suffixers[current] 342 | if isinstance(suffix, dict): 343 | if next in suffix: 344 | yield output(str(value) + suffix[next]) 345 | skip = True 346 | else: 347 | yield output(value) 348 | yield output(current) 349 | else: 350 | yield output(str(value) + suffix) 351 | else: 352 | yield output(current) 353 | elif current in self.specials: 354 | if next not in self.words and not next_is_numeric: 355 | # apply special handling only if the next word can be numeric 356 | if value is not None: 357 | yield output(value) 358 | yield output(current) 359 | elif current == "and": 360 | # ignore "and" after hundreds, thousands, etc. 361 | if prev not in self.multipliers: 362 | if value is not None: 363 | yield output(value) 364 | yield output(current) 365 | elif current == "double" or current == "triple": 366 | if next in self.ones or next in self.zeros: 367 | repeats = 2 if current == "double" else 3 368 | ones = self.ones.get(next, 0) 369 | value = str(value or "") + str(ones) * repeats 370 | skip = True 371 | else: 372 | if value is not None: 373 | yield output(value) 374 | yield output(current) 375 | elif current == "point": 376 | if next in self.decimals or next_is_numeric: 377 | value = str(value or "") + "." 
378 | else: 379 | # should all have been covered at this point 380 | raise ValueError(f"Unexpected token: {current}") 381 | else: 382 | # all should have been covered at this point 383 | raise ValueError(f"Unexpected token: {current}") 384 | 385 | if value is not None: 386 | yield output(value) 387 | 388 | def preprocess(self, s: str): 389 | # replace " and a half" with " point five" 390 | results = [] 391 | 392 | segments = re.split(r"\band\s+a\s+half\b", s) 393 | for i, segment in enumerate(segments): 394 | if len(segment.strip()) == 0: 395 | continue 396 | if i == len(segments) - 1: 397 | results.append(segment) 398 | else: 399 | results.append(segment) 400 | last_word = segment.rsplit(maxsplit=2)[-1] 401 | if last_word in self.decimals or last_word in self.multipliers: 402 | results.append("point five") 403 | else: 404 | results.append("and a half") 405 | 406 | s = " ".join(results) 407 | 408 | # put a space at number/letter boundary 409 | s = re.sub(r"([a-z])([0-9])", r"\1 \2", s) 410 | s = re.sub(r"([0-9])([a-z])", r"\1 \2", s) 411 | 412 | # but remove spaces which could be a suffix 413 | s = re.sub(r"([0-9])\s+(st|nd|rd|th|s)\b", r"\1\2", s) 414 | 415 | return s 416 | 417 | def postprocess(self, s: str): 418 | def combine_cents(m: Match): 419 | try: 420 | currency = m.group(1) 421 | integer = m.group(2) 422 | cents = int(m.group(3)) 423 | return f"{currency}{integer}.{cents:02d}" 424 | except ValueError: 425 | return m.string 426 | 427 | def extract_cents(m: Match): 428 | try: 429 | return f"¢{int(m.group(1))}" 430 | except ValueError: 431 | return m.string 432 | 433 | # apply currency postprocessing; "$2 and ¢7" -> "$2.07" 434 | s = re.sub(r"([€£$])([0-9]+) (?:and )?¢([0-9]{1,2})\b", combine_cents, s) 435 | s = re.sub(r"[€£$]0.([0-9]{1,2})\b", extract_cents, s) 436 | 437 | # write "one(s)" instead of "1(s)", just for the readability 438 | s = re.sub(r"\b1(s?)\b", r"one\1", s) 439 | 440 | return s 441 | 442 | def __call__(self, s: str): 443 | s = 
self.preprocess(s) 444 | s = " ".join(word for word in self.process_words(s.split()) if word is not None) 445 | s = self.postprocess(s) 446 | 447 | return s 448 | 449 | 450 | class EnglishSpellingNormalizer: 451 | """ 452 | Applies British-American spelling mappings as listed in [1]. 453 | 454 | [1] https://www.tysto.com/uk-us-spelling-list.html 455 | """ 456 | 457 | def __init__(self): 458 | mapping_path = os.path.join(os.path.dirname(__file__), "english.json") 459 | self.mapping = json.load(open(mapping_path)) 460 | 461 | def __call__(self, s: str): 462 | return " ".join(self.mapping.get(word, word) for word in s.split()) 463 | 464 | 465 | class EnglishTextNormalizer: 466 | def __init__(self): 467 | self.ignore_patterns = r"\b(hmm|mm|mhm|mmm|uh|um)\b" 468 | self.replacers = { 469 | # common contractions 470 | r"\bwon't\b": "will not", 471 | r"\bcan't\b": "can not", 472 | r"\blet's\b": "let us", 473 | r"\bain't\b": "aint", 474 | r"\by'all\b": "you all", 475 | r"\bwanna\b": "want to", 476 | r"\bgotta\b": "got to", 477 | r"\bgonna\b": "going to", 478 | r"\bi'ma\b": "i am going to", 479 | r"\bimma\b": "i am going to", 480 | r"\bwoulda\b": "would have", 481 | r"\bcoulda\b": "could have", 482 | r"\bshoulda\b": "should have", 483 | r"\bma'am\b": "madam", 484 | # contractions in titles/prefixes 485 | r"\bmr\b": "mister ", 486 | r"\bmrs\b": "missus ", 487 | r"\bst\b": "saint ", 488 | r"\bdr\b": "doctor ", 489 | r"\bprof\b": "professor ", 490 | r"\bcapt\b": "captain ", 491 | r"\bgov\b": "governor ", 492 | r"\bald\b": "alderman ", 493 | r"\bgen\b": "general ", 494 | r"\bsen\b": "senator ", 495 | r"\brep\b": "representative ", 496 | r"\bpres\b": "president ", 497 | r"\brev\b": "reverend ", 498 | r"\bhon\b": "honorable ", 499 | r"\basst\b": "assistant ", 500 | r"\bassoc\b": "associate ", 501 | r"\blt\b": "lieutenant ", 502 | r"\bcol\b": "colonel ", 503 | r"\bjr\b": "junior ", 504 | r"\bsr\b": "senior ", 505 | r"\besq\b": "esquire ", 506 | # prefect tenses, ideally it should 
be any past participles, but it's harder.. 507 | r"'d been\b": " had been", 508 | r"'s been\b": " has been", 509 | r"'d gone\b": " had gone", 510 | r"'s gone\b": " has gone", 511 | r"'d done\b": " had done", # "'s done" is ambiguous 512 | r"'s got\b": " has got", 513 | # general contractions 514 | r"n't\b": " not", 515 | r"'re\b": " are", 516 | r"'s\b": " is", 517 | r"'d\b": " would", 518 | r"'ll\b": " will", 519 | r"'t\b": " not", 520 | r"'ve\b": " have", 521 | r"'m\b": " am", 522 | } 523 | self.standardize_numbers = EnglishNumberNormalizer() 524 | self.standardize_spellings = EnglishSpellingNormalizer() 525 | 526 | def __call__(self, s: str): 527 | s = s.lower() 528 | 529 | s = re.sub(r"[<\[][^>\]]*[>\]]", "", s) # remove words between brackets 530 | s = re.sub(r"\(([^)]+?)\)", "", s) # remove words between parenthesis 531 | s = re.sub(self.ignore_patterns, "", s) 532 | s = re.sub(r"\s+'", "'", s) # when there's a space before an apostrophe 533 | 534 | for pattern, replacement in self.replacers.items(): 535 | s = re.sub(pattern, replacement, s) 536 | 537 | s = re.sub(r"(\d),(\d)", r"\1\2", s) # remove commas between digits 538 | s = re.sub(r"\.([^0-9]|$)", r" \1", s) # remove periods not followed by numbers 539 | s = remove_symbols_and_diacritics(s, keep=".%$¢€£") # keep numeric symbols 540 | 541 | s = self.standardize_numbers(s) 542 | s = self.standardize_spellings(s) 543 | 544 | # now remove prefix/suffix symbols that are not preceded/followed by numbers 545 | s = re.sub(r"[.$¢€£]([^0-9])", r" \1", s) 546 | s = re.sub(r"([^0-9])%", r"\1 ", s) 547 | 548 | s = re.sub(r"\s+", " ", s) # replace any successive whitespaces with a space 549 | 550 | return s 551 | -------------------------------------------------------------------------------- /whisper/timing.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import subprocess 3 | import warnings 4 | from dataclasses import dataclass 5 | from typing import 
TYPE_CHECKING, List 6 | 7 | import numba 8 | import numpy as np 9 | import torch 10 | import torch.nn.functional as F 11 | 12 | from .audio import HOP_LENGTH, SAMPLE_RATE, TOKENS_PER_SECOND 13 | from .tokenizer import Tokenizer 14 | 15 | if TYPE_CHECKING: 16 | from .model import Whisper 17 | 18 | 19 | def median_filter(x: torch.Tensor, filter_width: int): 20 | """Apply a median filter of width `filter_width` along the last dimension of `x`""" 21 | pad_width = filter_width // 2 22 | if x.shape[-1] <= pad_width: 23 | # F.pad requires the padding width to be smaller than the input dimension 24 | return x 25 | 26 | if (ndim := x.ndim) <= 2: 27 | # `F.pad` does not support 1D or 2D inputs for reflect padding but supports 3D and 4D 28 | x = x[None, None, :] 29 | 30 | assert ( 31 | filter_width > 0 and filter_width % 2 == 1 32 | ), "`filter_width` should be an odd number" 33 | 34 | result = None 35 | x = F.pad(x, (filter_width // 2, filter_width // 2, 0, 0), mode="reflect") 36 | if x.is_cuda: 37 | try: 38 | from .triton_ops import median_filter_cuda 39 | 40 | result = median_filter_cuda(x, filter_width) 41 | except (RuntimeError, subprocess.CalledProcessError): 42 | warnings.warn( 43 | "Failed to launch Triton kernels, likely due to missing CUDA toolkit; " 44 | "falling back to a slower median kernel implementation..." 
45 | ) 46 | 47 | if result is None: 48 | # sort() is faster than torch.median (https://github.com/pytorch/pytorch/issues/51450) 49 | result = x.unfold(-1, filter_width, 1).sort()[0][..., filter_width // 2] 50 | 51 | if ndim <= 2: 52 | result = result[0, 0] 53 | 54 | return result 55 | 56 | 57 | @numba.jit(nopython=True) 58 | def backtrace(trace: np.ndarray): 59 | i = trace.shape[0] - 1 60 | j = trace.shape[1] - 1 61 | trace[0, :] = 2 62 | trace[:, 0] = 1 63 | 64 | result = [] 65 | while i > 0 or j > 0: 66 | result.append((i - 1, j - 1)) 67 | 68 | if trace[i, j] == 0: 69 | i -= 1 70 | j -= 1 71 | elif trace[i, j] == 1: 72 | i -= 1 73 | elif trace[i, j] == 2: 74 | j -= 1 75 | else: 76 | raise ValueError("Unexpected trace[i, j]") 77 | 78 | result = np.array(result) 79 | return result[::-1, :].T 80 | 81 | 82 | @numba.jit(nopython=True, parallel=True) 83 | def dtw_cpu(x: np.ndarray): 84 | N, M = x.shape 85 | cost = np.ones((N + 1, M + 1), dtype=np.float32) * np.inf 86 | trace = -np.ones((N + 1, M + 1), dtype=np.float32) 87 | 88 | cost[0, 0] = 0 89 | for j in range(1, M + 1): 90 | for i in range(1, N + 1): 91 | c0 = cost[i - 1, j - 1] 92 | c1 = cost[i - 1, j] 93 | c2 = cost[i, j - 1] 94 | 95 | if c0 < c1 and c0 < c2: 96 | c, t = c0, 0 97 | elif c1 < c0 and c1 < c2: 98 | c, t = c1, 1 99 | else: 100 | c, t = c2, 2 101 | 102 | cost[i, j] = x[i - 1, j - 1] + c 103 | trace[i, j] = t 104 | 105 | return backtrace(trace) 106 | 107 | 108 | def dtw_cuda(x, BLOCK_SIZE=1024): 109 | from .triton_ops import dtw_kernel 110 | 111 | M, N = x.shape 112 | assert M < BLOCK_SIZE, f"M should be smaller than {BLOCK_SIZE=}" 113 | 114 | x_skew = ( 115 | F.pad(x, (0, M + 1), value=np.inf).flatten()[: M * (N + M)].reshape(M, N + M) 116 | ) 117 | x_skew = x_skew.T.contiguous() 118 | cost = torch.ones(N + M + 2, M + 2) * np.inf 119 | cost[0, 0] = 0 120 | cost = cost.cuda() 121 | trace = torch.zeros_like(cost, dtype=torch.int32) 122 | 123 | dtw_kernel[(1,)]( 124 | cost, 125 | trace, 126 | x_skew, 127 
| x_skew.stride(0), 128 | cost.stride(0), 129 | trace.stride(0), 130 | N, 131 | M, 132 | BLOCK_SIZE=BLOCK_SIZE, 133 | ) 134 | 135 | trace = trace.T.flatten()[: (M + 1) * (M + N + 3)].reshape(M + 1, M + N + 3)[ 136 | :, : N + 1 137 | ] 138 | return backtrace(trace.cpu().numpy()) 139 | 140 | 141 | def dtw(x: torch.Tensor) -> np.ndarray: 142 | if x.is_cuda: 143 | try: 144 | return dtw_cuda(x) 145 | except (RuntimeError, subprocess.CalledProcessError): 146 | warnings.warn( 147 | "Failed to launch Triton kernels, likely due to missing CUDA toolkit; " 148 | "falling back to a slower DTW implementation..." 149 | ) 150 | 151 | return dtw_cpu(x.double().cpu().numpy()) 152 | 153 | 154 | @dataclass 155 | class WordTiming: 156 | word: str 157 | tokens: List[int] 158 | start: float 159 | end: float 160 | probability: float 161 | 162 | 163 | def find_alignment( 164 | model: "Whisper", 165 | tokenizer: Tokenizer, 166 | text_tokens: List[int], 167 | mel: torch.Tensor, 168 | num_frames: int, 169 | *, 170 | medfilt_width: int = 7, 171 | qk_scale: float = 1.0, 172 | ) -> List[WordTiming]: 173 | if len(text_tokens) == 0: 174 | return [] 175 | 176 | tokens = torch.tensor( 177 | [ 178 | *tokenizer.sot_sequence, 179 | tokenizer.no_timestamps, 180 | *text_tokens, 181 | tokenizer.eot, 182 | ] 183 | ).to(model.device) 184 | 185 | # install hooks on the cross attention layers to retrieve the attention weights 186 | QKs = [None] * model.dims.n_text_layer 187 | hooks = [ 188 | block.cross_attn.register_forward_hook( 189 | lambda _, ins, outs, index=i: QKs.__setitem__(index, outs[-1][0]) 190 | ) 191 | for i, block in enumerate(model.decoder.blocks) 192 | ] 193 | 194 | from .model import disable_sdpa 195 | 196 | with torch.no_grad(), disable_sdpa(): 197 | logits = model(mel.unsqueeze(0), tokens.unsqueeze(0))[0] 198 | sampled_logits = logits[len(tokenizer.sot_sequence) :, : tokenizer.eot] 199 | token_probs = sampled_logits.softmax(dim=-1) 200 | text_token_probs = 
token_probs[np.arange(len(text_tokens)), text_tokens] 201 | text_token_probs = text_token_probs.tolist() 202 | 203 | for hook in hooks: 204 | hook.remove() 205 | 206 | # heads * tokens * frames 207 | weights = torch.stack([QKs[_l][_h] for _l, _h in model.alignment_heads.indices().T]) 208 | weights = weights[:, :, : num_frames // 2] 209 | weights = (weights * qk_scale).softmax(dim=-1) 210 | std, mean = torch.std_mean(weights, dim=-2, keepdim=True, unbiased=False) 211 | weights = (weights - mean) / std 212 | weights = median_filter(weights, medfilt_width) 213 | 214 | matrix = weights.mean(axis=0) 215 | matrix = matrix[len(tokenizer.sot_sequence) : -1] 216 | text_indices, time_indices = dtw(-matrix) 217 | 218 | words, word_tokens = tokenizer.split_to_word_tokens(text_tokens + [tokenizer.eot]) 219 | if len(word_tokens) <= 1: 220 | # return on eot only 221 | # >>> np.pad([], (1, 0)) 222 | # array([0.]) 223 | # This results in crashes when we lookup jump_times with float, like 224 | # IndexError: arrays used as indices must be of integer (or boolean) type 225 | return [] 226 | word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens[:-1]]), (1, 0)) 227 | 228 | jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool) 229 | jump_times = time_indices[jumps] / TOKENS_PER_SECOND 230 | start_times = jump_times[word_boundaries[:-1]] 231 | end_times = jump_times[word_boundaries[1:]] 232 | word_probabilities = [ 233 | np.mean(text_token_probs[i:j]) 234 | for i, j in zip(word_boundaries[:-1], word_boundaries[1:]) 235 | ] 236 | 237 | return [ 238 | WordTiming(word, tokens, start, end, probability) 239 | for word, tokens, start, end, probability in zip( 240 | words, word_tokens, start_times, end_times, word_probabilities 241 | ) 242 | ] 243 | 244 | 245 | def merge_punctuations(alignment: List[WordTiming], prepended: str, appended: str): 246 | # merge prepended punctuations 247 | i = len(alignment) - 2 248 | j = len(alignment) - 1 249 | while i >= 0: 250 | 
previous = alignment[i] 251 | following = alignment[j] 252 | if previous.word.startswith(" ") and previous.word.strip() in prepended: 253 | # prepend it to the following word 254 | following.word = previous.word + following.word 255 | following.tokens = previous.tokens + following.tokens 256 | previous.word = "" 257 | previous.tokens = [] 258 | else: 259 | j = i 260 | i -= 1 261 | 262 | # merge appended punctuations 263 | i = 0 264 | j = 1 265 | while j < len(alignment): 266 | previous = alignment[i] 267 | following = alignment[j] 268 | if not previous.word.endswith(" ") and following.word in appended: 269 | # append it to the previous word 270 | previous.word = previous.word + following.word 271 | previous.tokens = previous.tokens + following.tokens 272 | following.word = "" 273 | following.tokens = [] 274 | else: 275 | i = j 276 | j += 1 277 | 278 | 279 | def add_word_timestamps( 280 | *, 281 | segments: List[dict], 282 | model: "Whisper", 283 | tokenizer: Tokenizer, 284 | mel: torch.Tensor, 285 | num_frames: int, 286 | prepend_punctuations: str = "\"'“¿([{-", 287 | append_punctuations: str = "\"'.。,,!!??::”)]}、", 288 | last_speech_timestamp: float, 289 | **kwargs, 290 | ): 291 | if len(segments) == 0: 292 | return 293 | 294 | text_tokens_per_segment = [ 295 | [token for token in segment["tokens"] if token < tokenizer.eot] 296 | for segment in segments 297 | ] 298 | 299 | text_tokens = list(itertools.chain.from_iterable(text_tokens_per_segment)) 300 | alignment = find_alignment(model, tokenizer, text_tokens, mel, num_frames, **kwargs) 301 | word_durations = np.array([t.end - t.start for t in alignment]) 302 | word_durations = word_durations[word_durations.nonzero()] 303 | median_duration = np.median(word_durations) if len(word_durations) > 0 else 0.0 304 | median_duration = min(0.7, float(median_duration)) 305 | max_duration = median_duration * 2 306 | 307 | # hack: truncate long words at sentence boundaries. 
308 | # a better segmentation algorithm based on VAD should be able to replace this. 309 | if len(word_durations) > 0: 310 | sentence_end_marks = ".。!!??" 311 | # ensure words at sentence boundaries are not longer than twice the median word duration. 312 | for i in range(1, len(alignment)): 313 | if alignment[i].end - alignment[i].start > max_duration: 314 | if alignment[i].word in sentence_end_marks: 315 | alignment[i].end = alignment[i].start + max_duration 316 | elif alignment[i - 1].word in sentence_end_marks: 317 | alignment[i].start = alignment[i].end - max_duration 318 | 319 | merge_punctuations(alignment, prepend_punctuations, append_punctuations) 320 | 321 | time_offset = segments[0]["seek"] * HOP_LENGTH / SAMPLE_RATE 322 | word_index = 0 323 | 324 | for segment, text_tokens in zip(segments, text_tokens_per_segment): 325 | saved_tokens = 0 326 | words = [] 327 | 328 | while word_index < len(alignment) and saved_tokens < len(text_tokens): 329 | timing = alignment[word_index] 330 | 331 | if timing.word: 332 | words.append( 333 | dict( 334 | word=timing.word, 335 | start=round(time_offset + timing.start, 2), 336 | end=round(time_offset + timing.end, 2), 337 | probability=timing.probability, 338 | ) 339 | ) 340 | 341 | saved_tokens += len(timing.tokens) 342 | word_index += 1 343 | 344 | # hack: truncate long words at segment boundaries. 345 | # a better segmentation algorithm based on VAD should be able to replace this. 346 | if len(words) > 0: 347 | # ensure the first and second word after a pause is not longer than 348 | # twice the median word duration. 
349 | if words[0]["end"] - last_speech_timestamp > median_duration * 4 and ( 350 | words[0]["end"] - words[0]["start"] > max_duration 351 | or ( 352 | len(words) > 1 353 | and words[1]["end"] - words[0]["start"] > max_duration * 2 354 | ) 355 | ): 356 | if ( 357 | len(words) > 1 358 | and words[1]["end"] - words[1]["start"] > max_duration 359 | ): 360 | boundary = max(words[1]["end"] / 2, words[1]["end"] - max_duration) 361 | words[0]["end"] = words[1]["start"] = boundary 362 | words[0]["start"] = max(0, words[0]["end"] - max_duration) 363 | 364 | # prefer the segment-level start timestamp if the first word is too long. 365 | if ( 366 | segment["start"] < words[0]["end"] 367 | and segment["start"] - 0.5 > words[0]["start"] 368 | ): 369 | words[0]["start"] = max( 370 | 0, min(words[0]["end"] - median_duration, segment["start"]) 371 | ) 372 | else: 373 | segment["start"] = words[0]["start"] 374 | 375 | # prefer the segment-level end timestamp if the last word is too long. 376 | if ( 377 | segment["end"] > words[-1]["start"] 378 | and segment["end"] + 0.5 < words[-1]["end"] 379 | ): 380 | words[-1]["end"] = max( 381 | words[-1]["start"] + median_duration, segment["end"] 382 | ) 383 | else: 384 | segment["end"] = words[-1]["end"] 385 | 386 | last_speech_timestamp = segment["end"] 387 | 388 | segment["words"] = words 389 | -------------------------------------------------------------------------------- /whisper/tokenizer.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import os 3 | import string 4 | from dataclasses import dataclass, field 5 | from functools import cached_property, lru_cache 6 | from typing import Dict, List, Optional, Tuple 7 | 8 | import tiktoken 9 | 10 | LANGUAGES = { 11 | "en": "english", 12 | "zh": "chinese", 13 | "de": "german", 14 | "es": "spanish", 15 | "ru": "russian", 16 | "ko": "korean", 17 | "fr": "french", 18 | "ja": "japanese", 19 | "pt": "portuguese", 20 | "tr": "turkish", 21 | 
"pl": "polish", 22 | "ca": "catalan", 23 | "nl": "dutch", 24 | "ar": "arabic", 25 | "sv": "swedish", 26 | "it": "italian", 27 | "id": "indonesian", 28 | "hi": "hindi", 29 | "fi": "finnish", 30 | "vi": "vietnamese", 31 | "he": "hebrew", 32 | "uk": "ukrainian", 33 | "el": "greek", 34 | "ms": "malay", 35 | "cs": "czech", 36 | "ro": "romanian", 37 | "da": "danish", 38 | "hu": "hungarian", 39 | "ta": "tamil", 40 | "no": "norwegian", 41 | "th": "thai", 42 | "ur": "urdu", 43 | "hr": "croatian", 44 | "bg": "bulgarian", 45 | "lt": "lithuanian", 46 | "la": "latin", 47 | "mi": "maori", 48 | "ml": "malayalam", 49 | "cy": "welsh", 50 | "sk": "slovak", 51 | "te": "telugu", 52 | "fa": "persian", 53 | "lv": "latvian", 54 | "bn": "bengali", 55 | "sr": "serbian", 56 | "az": "azerbaijani", 57 | "sl": "slovenian", 58 | "kn": "kannada", 59 | "et": "estonian", 60 | "mk": "macedonian", 61 | "br": "breton", 62 | "eu": "basque", 63 | "is": "icelandic", 64 | "hy": "armenian", 65 | "ne": "nepali", 66 | "mn": "mongolian", 67 | "bs": "bosnian", 68 | "kk": "kazakh", 69 | "sq": "albanian", 70 | "sw": "swahili", 71 | "gl": "galician", 72 | "mr": "marathi", 73 | "pa": "punjabi", 74 | "si": "sinhala", 75 | "km": "khmer", 76 | "sn": "shona", 77 | "yo": "yoruba", 78 | "so": "somali", 79 | "af": "afrikaans", 80 | "oc": "occitan", 81 | "ka": "georgian", 82 | "be": "belarusian", 83 | "tg": "tajik", 84 | "sd": "sindhi", 85 | "gu": "gujarati", 86 | "am": "amharic", 87 | "yi": "yiddish", 88 | "lo": "lao", 89 | "uz": "uzbek", 90 | "fo": "faroese", 91 | "ht": "haitian creole", 92 | "ps": "pashto", 93 | "tk": "turkmen", 94 | "nn": "nynorsk", 95 | "mt": "maltese", 96 | "sa": "sanskrit", 97 | "lb": "luxembourgish", 98 | "my": "myanmar", 99 | "bo": "tibetan", 100 | "tl": "tagalog", 101 | "mg": "malagasy", 102 | "as": "assamese", 103 | "tt": "tatar", 104 | "haw": "hawaiian", 105 | "ln": "lingala", 106 | "ha": "hausa", 107 | "ba": "bashkir", 108 | "jw": "javanese", 109 | "su": "sundanese", 110 | "yue": 
"cantonese", 111 | } 112 | 113 | # language code lookup by name, with a few language aliases 114 | TO_LANGUAGE_CODE = { 115 | **{language: code for code, language in LANGUAGES.items()}, 116 | "burmese": "my", 117 | "valencian": "ca", 118 | "flemish": "nl", 119 | "haitian": "ht", 120 | "letzeburgesch": "lb", 121 | "pushto": "ps", 122 | "panjabi": "pa", 123 | "moldavian": "ro", 124 | "moldovan": "ro", 125 | "sinhalese": "si", 126 | "castilian": "es", 127 | "mandarin": "zh", 128 | } 129 | 130 | 131 | @dataclass 132 | class Tokenizer: 133 | """A thin wrapper around `tiktoken` providing quick access to special tokens""" 134 | 135 | encoding: tiktoken.Encoding 136 | num_languages: int 137 | language: Optional[str] = None 138 | task: Optional[str] = None 139 | sot_sequence: Tuple[int] = () 140 | special_tokens: Dict[str, int] = field(default_factory=dict) 141 | 142 | def __post_init__(self): 143 | for special in self.encoding.special_tokens_set: 144 | special_token = self.encoding.encode_single_token(special) 145 | self.special_tokens[special] = special_token 146 | 147 | sot: int = self.special_tokens["<|startoftranscript|>"] 148 | translate: int = self.special_tokens["<|translate|>"] 149 | transcribe: int = self.special_tokens["<|transcribe|>"] 150 | 151 | langs = tuple(LANGUAGES.keys())[: self.num_languages] 152 | sot_sequence = [sot] 153 | if self.language is not None: 154 | sot_sequence.append(sot + 1 + langs.index(self.language)) 155 | if self.task is not None: 156 | task_token: int = transcribe if self.task == "transcribe" else translate 157 | sot_sequence.append(task_token) 158 | 159 | self.sot_sequence = tuple(sot_sequence) 160 | 161 | def encode(self, text, **kwargs): 162 | return self.encoding.encode(text, **kwargs) 163 | 164 | def decode(self, token_ids: List[int], **kwargs) -> str: 165 | token_ids = [t for t in token_ids if t < self.timestamp_begin] 166 | return self.encoding.decode(token_ids, **kwargs) 167 | 168 | def decode_with_timestamps(self, token_ids: 
List[int], **kwargs) -> str: 169 | """ 170 | Timestamp tokens are above other special tokens' id range and are ignored by `decode()`. 171 | This method decodes given tokens with timestamps tokens annotated, e.g. "<|1.08|>". 172 | """ 173 | return self.encoding.decode(token_ids, **kwargs) 174 | 175 | @cached_property 176 | def eot(self) -> int: 177 | return self.encoding.eot_token 178 | 179 | @cached_property 180 | def transcribe(self) -> int: 181 | return self.special_tokens["<|transcribe|>"] 182 | 183 | @cached_property 184 | def translate(self) -> int: 185 | return self.special_tokens["<|translate|>"] 186 | 187 | @cached_property 188 | def sot(self) -> int: 189 | return self.special_tokens["<|startoftranscript|>"] 190 | 191 | @cached_property 192 | def sot_lm(self) -> int: 193 | return self.special_tokens["<|startoflm|>"] 194 | 195 | @cached_property 196 | def sot_prev(self) -> int: 197 | return self.special_tokens["<|startofprev|>"] 198 | 199 | @cached_property 200 | def no_speech(self) -> int: 201 | return self.special_tokens["<|nospeech|>"] 202 | 203 | @cached_property 204 | def no_timestamps(self) -> int: 205 | return self.special_tokens["<|notimestamps|>"] 206 | 207 | @cached_property 208 | def timestamp_begin(self) -> int: 209 | return self.special_tokens["<|0.00|>"] 210 | 211 | @cached_property 212 | def language_token(self) -> int: 213 | """Returns the token id corresponding to the value of the `language` field""" 214 | if self.language is None: 215 | raise ValueError("This tokenizer does not have language token configured") 216 | 217 | return self.to_language_token(self.language) 218 | 219 | def to_language_token(self, language): 220 | if token := self.special_tokens.get(f"<|{language}|>", None): 221 | return token 222 | 223 | raise KeyError(f"Language {language} not found in tokenizer.") 224 | 225 | @cached_property 226 | def all_language_tokens(self) -> Tuple[int]: 227 | result = [] 228 | for token, token_id in self.special_tokens.items(): 229 | if 
token.strip("<|>") in LANGUAGES: 230 | result.append(token_id) 231 | return tuple(result)[: self.num_languages] 232 | 233 | @cached_property 234 | def all_language_codes(self) -> Tuple[str]: 235 | return tuple(self.decode([_l]).strip("<|>") for _l in self.all_language_tokens) 236 | 237 | @cached_property 238 | def sot_sequence_including_notimestamps(self) -> Tuple[int]: 239 | return tuple(list(self.sot_sequence) + [self.no_timestamps]) 240 | 241 | @cached_property 242 | def non_speech_tokens(self) -> Tuple[int]: 243 | """ 244 | Returns the list of tokens to suppress in order to avoid any speaker tags or non-speech 245 | annotations, to prevent sampling texts that are not actually spoken in the audio, e.g. 246 | 247 | - ♪♪♪ 248 | - ( SPEAKING FOREIGN LANGUAGE ) 249 | - [DAVID] Hey there, 250 | 251 | keeping basic punctuations like commas, periods, question marks, exclamation points, etc. 252 | """ 253 | symbols = list('"#()*+/:;<=>@[\\]^_`{|}~「」『』') 254 | symbols += ( 255 | "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split() 256 | ) 257 | 258 | # symbols that may be a single token or multiple tokens depending on the tokenizer. 259 | # In case they're multiple tokens, suppress the first token, which is safe because: 260 | # These are between U+2640 and U+267F miscellaneous symbols that are okay to suppress 261 | # in generations, and in the 3-byte UTF-8 representation they share the first two bytes. 
262 | miscellaneous = set("♩♪♫♬♭♮♯") 263 | assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous) 264 | 265 | # allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word 266 | result = {self.encoding.encode(" -")[0], self.encoding.encode(" '")[0]} 267 | for symbol in symbols + list(miscellaneous): 268 | for tokens in [ 269 | self.encoding.encode(symbol), 270 | self.encoding.encode(" " + symbol), 271 | ]: 272 | if len(tokens) == 1 or symbol in miscellaneous: 273 | result.add(tokens[0]) 274 | 275 | return tuple(sorted(result)) 276 | 277 | def split_to_word_tokens(self, tokens: List[int]): 278 | if self.language in {"zh", "ja", "th", "lo", "my", "yue"}: 279 | # These languages don't typically use spaces, so it is difficult to split words 280 | # without morpheme analysis. Here, we instead split words at any 281 | # position where the tokens are decoded as valid unicode points 282 | return self.split_tokens_on_unicode(tokens) 283 | 284 | return self.split_tokens_on_spaces(tokens) 285 | 286 | def split_tokens_on_unicode(self, tokens: List[int]): 287 | decoded_full = self.decode_with_timestamps(tokens) 288 | replacement_char = "\ufffd" 289 | 290 | words = [] 291 | word_tokens = [] 292 | current_tokens = [] 293 | unicode_offset = 0 294 | 295 | for token in tokens: 296 | current_tokens.append(token) 297 | decoded = self.decode_with_timestamps(current_tokens) 298 | 299 | if ( 300 | replacement_char not in decoded 301 | or decoded_full[unicode_offset + decoded.index(replacement_char)] 302 | == replacement_char 303 | ): 304 | words.append(decoded) 305 | word_tokens.append(current_tokens) 306 | current_tokens = [] 307 | unicode_offset += len(decoded) 308 | 309 | return words, word_tokens 310 | 311 | def split_tokens_on_spaces(self, tokens: List[int]): 312 | subwords, subword_tokens_list = self.split_tokens_on_unicode(tokens) 313 | words = [] 314 | word_tokens = [] 315 | 316 | for subword, subword_tokens in zip(subwords, 
subword_tokens_list): 317 | special = subword_tokens[0] >= self.eot 318 | with_space = subword.startswith(" ") 319 | punctuation = subword.strip() in string.punctuation 320 | if special or with_space or punctuation or len(words) == 0: 321 | words.append(subword) 322 | word_tokens.append(subword_tokens) 323 | else: 324 | words[-1] = words[-1] + subword 325 | word_tokens[-1].extend(subword_tokens) 326 | 327 | return words, word_tokens 328 | 329 | 330 | @lru_cache(maxsize=None) 331 | def get_encoding(name: str = "gpt2", num_languages: int = 99): 332 | vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{name}.tiktoken") 333 | ranks = { 334 | base64.b64decode(token): int(rank) 335 | for token, rank in (line.split() for line in open(vocab_path) if line) 336 | } 337 | n_vocab = len(ranks) 338 | special_tokens = {} 339 | 340 | specials = [ 341 | "<|endoftext|>", 342 | "<|startoftranscript|>", 343 | *[f"<|{lang}|>" for lang in list(LANGUAGES.keys())[:num_languages]], 344 | "<|translate|>", 345 | "<|transcribe|>", 346 | "<|startoflm|>", 347 | "<|startofprev|>", 348 | "<|nospeech|>", 349 | "<|notimestamps|>", 350 | *[f"<|{i * 0.02:.2f}|>" for i in range(1501)], 351 | ] 352 | 353 | for token in specials: 354 | special_tokens[token] = n_vocab 355 | n_vocab += 1 356 | 357 | return tiktoken.Encoding( 358 | name=os.path.basename(vocab_path), 359 | explicit_n_vocab=n_vocab, 360 | pat_str=r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""", 361 | mergeable_ranks=ranks, 362 | special_tokens=special_tokens, 363 | ) 364 | 365 | 366 | @lru_cache(maxsize=None) 367 | def get_tokenizer( 368 | multilingual: bool, 369 | *, 370 | num_languages: int = 99, 371 | language: Optional[str] = None, 372 | task: Optional[str] = None, # Literal["transcribe", "translate", None] 373 | ) -> Tokenizer: 374 | if language is not None: 375 | language = language.lower() 376 | if language not in LANGUAGES: 377 | if language in TO_LANGUAGE_CODE: 378 | language = 
TO_LANGUAGE_CODE[language] 379 | else: 380 | raise ValueError(f"Unsupported language: {language}") 381 | 382 | if multilingual: 383 | encoding_name = "multilingual" 384 | language = language or "en" 385 | task = task or "transcribe" 386 | else: 387 | encoding_name = "gpt2" 388 | language = None 389 | task = None 390 | 391 | encoding = get_encoding(name=encoding_name, num_languages=num_languages) 392 | 393 | return Tokenizer( 394 | encoding=encoding, num_languages=num_languages, language=language, task=task 395 | ) 396 | -------------------------------------------------------------------------------- /whisper/transcribe.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import traceback 4 | import warnings 5 | from typing import TYPE_CHECKING, List, Optional, Tuple, Union 6 | 7 | import numpy as np 8 | import torch 9 | import tqdm 10 | 11 | from .audio import ( 12 | FRAMES_PER_SECOND, 13 | HOP_LENGTH, 14 | N_FRAMES, 15 | N_SAMPLES, 16 | SAMPLE_RATE, 17 | log_mel_spectrogram, 18 | pad_or_trim, 19 | ) 20 | from .decoding import DecodingOptions, DecodingResult 21 | from .timing import add_word_timestamps 22 | from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer 23 | from .utils import ( 24 | exact_div, 25 | format_timestamp, 26 | get_end, 27 | get_writer, 28 | make_safe, 29 | optional_float, 30 | optional_int, 31 | str2bool, 32 | ) 33 | 34 | if TYPE_CHECKING: 35 | from .model import Whisper 36 | 37 | 38 | def transcribe( 39 | model: "Whisper", 40 | audio: Union[str, np.ndarray, torch.Tensor], 41 | *, 42 | verbose: Optional[bool] = None, 43 | temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0), 44 | compression_ratio_threshold: Optional[float] = 2.4, 45 | logprob_threshold: Optional[float] = -1.0, 46 | no_speech_threshold: Optional[float] = 0.6, 47 | condition_on_previous_text: bool = True, 48 | initial_prompt: Optional[str] = None, 49 | carry_initial_prompt: 
bool = False, 50 | word_timestamps: bool = False, 51 | prepend_punctuations: str = "\"'“¿([{-", 52 | append_punctuations: str = "\"'.。,,!!??::”)]}、", 53 | clip_timestamps: Union[str, List[float]] = "0", 54 | hallucination_silence_threshold: Optional[float] = None, 55 | **decode_options, 56 | ): 57 | """ 58 | Transcribe an audio file using Whisper 59 | 60 | Parameters 61 | ---------- 62 | model: Whisper 63 | The Whisper model instance 64 | 65 | audio: Union[str, np.ndarray, torch.Tensor] 66 | The path to the audio file to open, or the audio waveform 67 | 68 | verbose: bool 69 | Whether to display the text being decoded to the console. If True, displays all the details, 70 | If False, displays minimal details. If None, does not display anything 71 | 72 | temperature: Union[float, Tuple[float, ...]] 73 | Temperature for sampling. It can be a tuple of temperatures, which will be successively used 74 | upon failures according to either `compression_ratio_threshold` or `logprob_threshold`. 75 | 76 | compression_ratio_threshold: float 77 | If the gzip compression ratio is above this value, treat as failed 78 | 79 | logprob_threshold: float 80 | If the average log probability over sampled tokens is below this value, treat as failed 81 | 82 | no_speech_threshold: float 83 | If the no_speech probability is higher than this value AND the average log probability 84 | over sampled tokens is below `logprob_threshold`, consider the segment as silent 85 | 86 | condition_on_previous_text: bool 87 | if True, the previous output of the model is provided as a prompt for the next window; 88 | disabling may make the text inconsistent across windows, but the model becomes less prone to 89 | getting stuck in a failure loop, such as repetition looping or timestamps going out of sync. 90 | 91 | word_timestamps: bool 92 | Extract word-level timestamps using the cross-attention pattern and dynamic time warping, 93 | and include the timestamps for each word in each segment. 
94 | 95 | prepend_punctuations: str 96 | If word_timestamps is True, merge these punctuation symbols with the next word 97 | 98 | append_punctuations: str 99 | If word_timestamps is True, merge these punctuation symbols with the previous word 100 | 101 | initial_prompt: Optional[str] 102 | Optional text to provide as a prompt for the first window. This can be used to provide, or 103 | "prompt-engineer", a context for transcription, e.g. custom vocabularies or proper nouns 104 | to make it more likely to predict those words correctly. 105 | 106 | carry_initial_prompt: bool 107 | If carry_initial_prompt is True, `initial_prompt` is prepended to the prompt of each internal 108 | `decode()` call. If there is not enough context space at the start of the prompt, it is 109 | left-sliced to make space. 110 | 111 | decode_options: dict 112 | Keyword arguments to construct `DecodingOptions` instances 113 | 114 | clip_timestamps: Union[str, List[float]] 115 | Comma-separated list of start,end,start,end,... timestamps (in seconds) of clips to process. 116 | The last end timestamp defaults to the end of the file. 117 | 118 | hallucination_silence_threshold: Optional[float] 119 | When word_timestamps is True, skip silent periods longer than this threshold (in seconds) 120 | when a possible hallucination is detected 121 | 122 | Returns 123 | ------- 124 | A dictionary containing the resulting text ("text") and segment-level details ("segments"), and 125 | the spoken language ("language"), which is detected when `decode_options["language"]` is None.
126 | """ 127 | dtype = torch.float16 if decode_options.get("fp16", True) else torch.float32 128 | if model.device == torch.device("cpu"): 129 | if torch.cuda.is_available(): 130 | warnings.warn("Performing inference on CPU when CUDA is available") 131 | if dtype == torch.float16: 132 | warnings.warn("FP16 is not supported on CPU; using FP32 instead") 133 | dtype = torch.float32 134 | 135 | if dtype == torch.float32: 136 | decode_options["fp16"] = False 137 | 138 | # Pad 30-seconds of silence to the input audio, for slicing 139 | mel = log_mel_spectrogram(audio, model.dims.n_mels, padding=N_SAMPLES) 140 | content_frames = mel.shape[-1] - N_FRAMES 141 | content_duration = float(content_frames * HOP_LENGTH / SAMPLE_RATE) 142 | 143 | if decode_options.get("language", None) is None: 144 | if not model.is_multilingual: 145 | decode_options["language"] = "en" 146 | else: 147 | if verbose: 148 | print( 149 | "Detecting language using up to the first 30 seconds. Use `--language` to specify the language" 150 | ) 151 | mel_segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype) 152 | _, probs = model.detect_language(mel_segment) 153 | decode_options["language"] = max(probs, key=probs.get) 154 | if verbose is not None: 155 | print( 156 | f"Detected language: {LANGUAGES[decode_options['language']].title()}" 157 | ) 158 | 159 | language: str = decode_options["language"] 160 | task: str = decode_options.get("task", "transcribe") 161 | tokenizer = get_tokenizer( 162 | model.is_multilingual, 163 | num_languages=model.num_languages, 164 | language=language, 165 | task=task, 166 | ) 167 | 168 | if isinstance(clip_timestamps, str): 169 | clip_timestamps = [ 170 | float(ts) for ts in (clip_timestamps.split(",") if clip_timestamps else []) 171 | ] 172 | seek_points: List[int] = [round(ts * FRAMES_PER_SECOND) for ts in clip_timestamps] 173 | if len(seek_points) == 0: 174 | seek_points.append(0) 175 | if len(seek_points) % 2 == 1: 176 | seek_points.append(content_frames) 177 | 
seek_clips: List[Tuple[int, int]] = list(zip(seek_points[::2], seek_points[1::2])) 178 | 179 | punctuation = "\"'“¿([{-\"'.。,,!!??::”)]}、" 180 | 181 | if word_timestamps and task == "translate": 182 | warnings.warn("Word-level timestamps on translations may not be reliable.") 183 | 184 | def decode_with_fallback(segment: torch.Tensor) -> DecodingResult: 185 | temperatures = ( 186 | [temperature] if isinstance(temperature, (int, float)) else temperature 187 | ) 188 | decode_result = None 189 | 190 | for t in temperatures: 191 | kwargs = {**decode_options} 192 | if t > 0: 193 | # disable beam_size and patience when t > 0 194 | kwargs.pop("beam_size", None) 195 | kwargs.pop("patience", None) 196 | else: 197 | # disable best_of when t == 0 198 | kwargs.pop("best_of", None) 199 | 200 | options = DecodingOptions(**kwargs, temperature=t) 201 | decode_result = model.decode(segment, options) 202 | 203 | needs_fallback = False 204 | if ( 205 | compression_ratio_threshold is not None 206 | and decode_result.compression_ratio > compression_ratio_threshold 207 | ): 208 | needs_fallback = True # too repetitive 209 | if ( 210 | logprob_threshold is not None 211 | and decode_result.avg_logprob < logprob_threshold 212 | ): 213 | needs_fallback = True # average log probability is too low 214 | if ( 215 | no_speech_threshold is not None 216 | and decode_result.no_speech_prob > no_speech_threshold 217 | and logprob_threshold is not None 218 | and decode_result.avg_logprob < logprob_threshold 219 | ): 220 | needs_fallback = False # silence 221 | if not needs_fallback: 222 | break 223 | 224 | return decode_result 225 | 226 | clip_idx = 0 227 | seek = seek_clips[clip_idx][0] 228 | input_stride = exact_div( 229 | N_FRAMES, model.dims.n_audio_ctx 230 | ) # mel frames per output token: 2 231 | time_precision = ( 232 | input_stride * HOP_LENGTH / SAMPLE_RATE 233 | ) # time per output token: 0.02 (seconds) 234 | all_tokens = [] 235 | all_segments = [] 236 | prompt_reset_since = 0 237 | 238 | 
remaining_prompt_length = model.dims.n_text_ctx // 2 - 1 239 | if initial_prompt is not None: 240 | initial_prompt_tokens = tokenizer.encode(" " + initial_prompt.strip()) 241 | all_tokens.extend(initial_prompt_tokens) 242 | remaining_prompt_length -= len(initial_prompt_tokens) 243 | else: 244 | initial_prompt_tokens = [] 245 | 246 | def new_segment( 247 | *, start: float, end: float, tokens: torch.Tensor, result: DecodingResult 248 | ): 249 | tokens = tokens.tolist() 250 | text_tokens = [token for token in tokens if token < tokenizer.eot] 251 | return { 252 | "seek": seek, 253 | "start": start, 254 | "end": end, 255 | "text": tokenizer.decode(text_tokens), 256 | "tokens": tokens, 257 | "temperature": result.temperature, 258 | "avg_logprob": result.avg_logprob, 259 | "compression_ratio": result.compression_ratio, 260 | "no_speech_prob": result.no_speech_prob, 261 | } 262 | 263 | # show the progress bar when verbose is False (if True, transcribed text will be printed) 264 | with tqdm.tqdm( 265 | total=content_frames, unit="frames", disable=verbose is not False 266 | ) as pbar: 267 | last_speech_timestamp = 0.0 268 | # NOTE: This loop is obscurely flattened to make the diff readable. 269 | # A later commit should turn this into a simpler nested loop. 
270 | # for seek_clip_start, seek_clip_end in seek_clips: 271 | # while seek < seek_clip_end 272 | while clip_idx < len(seek_clips): 273 | seek_clip_start, seek_clip_end = seek_clips[clip_idx] 274 | if seek < seek_clip_start: 275 | seek = seek_clip_start 276 | if seek >= seek_clip_end: 277 | clip_idx += 1 278 | if clip_idx < len(seek_clips): 279 | seek = seek_clips[clip_idx][0] 280 | continue 281 | time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE) 282 | window_end_time = float((seek + N_FRAMES) * HOP_LENGTH / SAMPLE_RATE) 283 | segment_size = min(N_FRAMES, content_frames - seek, seek_clip_end - seek) 284 | mel_segment = mel[:, seek : seek + segment_size] 285 | segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE 286 | mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype) 287 | 288 | if carry_initial_prompt: 289 | nignored = max(len(initial_prompt_tokens), prompt_reset_since) 290 | remaining_prompt = all_tokens[nignored:][-remaining_prompt_length:] 291 | decode_options["prompt"] = initial_prompt_tokens + remaining_prompt 292 | else: 293 | decode_options["prompt"] = all_tokens[prompt_reset_since:] 294 | 295 | result: DecodingResult = decode_with_fallback(mel_segment) 296 | tokens = torch.tensor(result.tokens) 297 | 298 | if no_speech_threshold is not None: 299 | # no voice activity check 300 | should_skip = result.no_speech_prob > no_speech_threshold 301 | if ( 302 | logprob_threshold is not None 303 | and result.avg_logprob > logprob_threshold 304 | ): 305 | # don't skip if the logprob is high enough, despite the no_speech_prob 306 | should_skip = False 307 | 308 | if should_skip: 309 | seek += segment_size # fast-forward to the next segment boundary 310 | continue 311 | 312 | previous_seek = seek 313 | current_segments = [] 314 | 315 | # anomalous words are very long/short/improbable 316 | def word_anomaly_score(word: dict) -> float: 317 | probability = word.get("probability", 0.0) 318 | duration = word["end"] - word["start"] 319 | 
score = 0.0 320 | if probability < 0.15: 321 | score += 1.0 322 | if duration < 0.133: 323 | score += (0.133 - duration) * 15 324 | if duration > 2.0: 325 | score += duration - 2.0 326 | return score 327 | 328 | def is_segment_anomaly(segment: Optional[dict]) -> bool: 329 | if segment is None or not segment["words"]: 330 | return False 331 | words = [w for w in segment["words"] if w["word"] not in punctuation] 332 | words = words[:8] 333 | score = sum(word_anomaly_score(w) for w in words) 334 | return score >= 3 or score + 0.01 >= len(words) 335 | 336 | def next_words_segment(segments: List[dict]) -> Optional[dict]: 337 | return next((s for s in segments if s["words"]), None) 338 | 339 | timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin) 340 | single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True] 341 | 342 | consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0] 343 | consecutive.add_(1) 344 | if len(consecutive) > 0: 345 | # if the output contains two consecutive timestamp tokens 346 | slices = consecutive.tolist() 347 | if single_timestamp_ending: 348 | slices.append(len(tokens)) 349 | 350 | last_slice = 0 351 | for current_slice in slices: 352 | sliced_tokens = tokens[last_slice:current_slice] 353 | start_timestamp_pos = ( 354 | sliced_tokens[0].item() - tokenizer.timestamp_begin 355 | ) 356 | end_timestamp_pos = ( 357 | sliced_tokens[-1].item() - tokenizer.timestamp_begin 358 | ) 359 | current_segments.append( 360 | new_segment( 361 | start=time_offset + start_timestamp_pos * time_precision, 362 | end=time_offset + end_timestamp_pos * time_precision, 363 | tokens=sliced_tokens, 364 | result=result, 365 | ) 366 | ) 367 | last_slice = current_slice 368 | 369 | if single_timestamp_ending: 370 | # single timestamp at the end means no speech after the last timestamp. 
371 | seek += segment_size 372 | else: 373 | # otherwise, ignore the unfinished segment and seek to the last timestamp 374 | last_timestamp_pos = ( 375 | tokens[last_slice - 1].item() - tokenizer.timestamp_begin 376 | ) 377 | seek += last_timestamp_pos * input_stride 378 | else: 379 | duration = segment_duration 380 | timestamps = tokens[timestamp_tokens.nonzero().flatten()] 381 | if ( 382 | len(timestamps) > 0 383 | and timestamps[-1].item() != tokenizer.timestamp_begin 384 | ): 385 | # no consecutive timestamps but it has a timestamp; use the last one. 386 | last_timestamp_pos = ( 387 | timestamps[-1].item() - tokenizer.timestamp_begin 388 | ) 389 | duration = last_timestamp_pos * time_precision 390 | 391 | current_segments.append( 392 | new_segment( 393 | start=time_offset, 394 | end=time_offset + duration, 395 | tokens=tokens, 396 | result=result, 397 | ) 398 | ) 399 | seek += segment_size 400 | 401 | if word_timestamps: 402 | add_word_timestamps( 403 | segments=current_segments, 404 | model=model, 405 | tokenizer=tokenizer, 406 | mel=mel_segment, 407 | num_frames=segment_size, 408 | prepend_punctuations=prepend_punctuations, 409 | append_punctuations=append_punctuations, 410 | last_speech_timestamp=last_speech_timestamp, 411 | ) 412 | 413 | if not single_timestamp_ending: 414 | last_word_end = get_end(current_segments) 415 | if last_word_end is not None and last_word_end > time_offset: 416 | seek = round(last_word_end * FRAMES_PER_SECOND) 417 | 418 | # skip silence before possible hallucinations 419 | if hallucination_silence_threshold is not None: 420 | threshold = hallucination_silence_threshold 421 | if not single_timestamp_ending: 422 | last_word_end = get_end(current_segments) 423 | if last_word_end is not None and last_word_end > time_offset: 424 | remaining_duration = window_end_time - last_word_end 425 | if remaining_duration > threshold: 426 | seek = round(last_word_end * FRAMES_PER_SECOND) 427 | else: 428 | seek = previous_seek + segment_size 429 | 
430 | # if first segment might be a hallucination, skip leading silence 431 | first_segment = next_words_segment(current_segments) 432 | if first_segment is not None and is_segment_anomaly(first_segment): 433 | gap = first_segment["start"] - time_offset 434 | if gap > threshold: 435 | seek = previous_seek + round(gap * FRAMES_PER_SECOND) 436 | continue 437 | 438 | # skip silence before any possible hallucination that is surrounded 439 | # by silence or more hallucinations 440 | hal_last_end = last_speech_timestamp 441 | for si in range(len(current_segments)): 442 | segment = current_segments[si] 443 | if not segment["words"]: 444 | continue 445 | if is_segment_anomaly(segment): 446 | next_segment = next_words_segment( 447 | current_segments[si + 1 :] 448 | ) 449 | if next_segment is not None: 450 | hal_next_start = next_segment["words"][0]["start"] 451 | else: 452 | hal_next_start = time_offset + segment_duration 453 | silence_before = ( 454 | segment["start"] - hal_last_end > threshold 455 | or segment["start"] < threshold 456 | or segment["start"] - time_offset < 2.0 457 | ) 458 | silence_after = ( 459 | hal_next_start - segment["end"] > threshold 460 | or is_segment_anomaly(next_segment) 461 | or window_end_time - segment["end"] < 2.0 462 | ) 463 | if silence_before and silence_after: 464 | seek = round( 465 | max(time_offset + 1, segment["start"]) 466 | * FRAMES_PER_SECOND 467 | ) 468 | if content_duration - segment["end"] < threshold: 469 | seek = content_frames 470 | current_segments[si:] = [] 471 | break 472 | hal_last_end = segment["end"] 473 | 474 | last_word_end = get_end(current_segments) 475 | if last_word_end is not None: 476 | last_speech_timestamp = last_word_end 477 | 478 | if verbose: 479 | for segment in current_segments: 480 | start, end, text = segment["start"], segment["end"], segment["text"] 481 | line = f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}" 482 | print(make_safe(line)) 483 | 484 | # if a segment is instantaneous or 
does not contain text, clear it 485 | for i, segment in enumerate(current_segments): 486 | if segment["start"] == segment["end"] or segment["text"].strip() == "": 487 | segment["text"] = "" 488 | segment["tokens"] = [] 489 | segment["words"] = [] 490 | 491 | all_segments.extend( 492 | [ 493 | {"id": i, **segment} 494 | for i, segment in enumerate( 495 | current_segments, start=len(all_segments) 496 | ) 497 | ] 498 | ) 499 | all_tokens.extend( 500 | [token for segment in current_segments for token in segment["tokens"]] 501 | ) 502 | 503 | if not condition_on_previous_text or result.temperature > 0.5: 504 | # do not feed the prompt tokens if a high temperature was used 505 | prompt_reset_since = len(all_tokens) 506 | 507 | # update progress bar 508 | pbar.update(min(content_frames, seek) - previous_seek) 509 | 510 | return dict( 511 | text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]), 512 | segments=all_segments, 513 | language=language, 514 | ) 515 | 516 | 517 | def cli(): 518 | from . 
import available_models 519 | 520 | def valid_model_name(name): 521 | if name in available_models() or os.path.exists(name): 522 | return name 523 | raise ValueError( 524 | f"model should be one of {available_models()} or path to a model checkpoint" 525 | ) 526 | 527 | # fmt: off 528 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 529 | parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe") 530 | parser.add_argument("--model", default="turbo", type=valid_model_name, help="name of the Whisper model to use") 531 | parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default") 532 | parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference") 533 | parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs") 534 | parser.add_argument("--output_format", "-f", type=str, default="all", choices=["txt", "vtt", "srt", "tsv", "json", "all"], help="format of the output file; if not specified, all available formats will be produced") 535 | parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages") 536 | 537 | parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')") 538 | parser.add_argument("--language", type=str, default=None, choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]), help="language spoken in the audio, specify None to perform language detection") 539 | 540 | parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling") 541 | parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with 
non-zero temperature") 542 | parser.add_argument("--beam_size", type=optional_int, default=5, help="number of beams in beam search, only applicable when temperature is zero") 543 | parser.add_argument("--patience", type=float, default=None, help="optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search") 544 | parser.add_argument("--length_penalty", type=float, default=None, help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple length normalization by default") 545 | 546 | parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations") 547 | parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.") 548 | parser.add_argument("--carry_initial_prompt", type=str2bool, default=False, help="if True, prepend initial_prompt to every internal decode() call. 
May reduce the effectiveness of condition_on_previous_text") 549 | 550 | parser.add_argument("--condition_on_previous_text", type=str2bool, default=True, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop") 551 | parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default") 552 | 553 | parser.add_argument("--temperature_increment_on_fallback", type=optional_float, default=0.2, help="temperature to increase when falling back when the decoding fails to meet either of the thresholds below") 554 | parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, help="if the gzip compression ratio is higher than this value, treat the decoding as failed") 555 | parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed") 556 | parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence") 557 | parser.add_argument("--word_timestamps", type=str2bool, default=False, help="(experimental) extract word-level timestamps and refine the results based on them") 558 | parser.add_argument("--prepend_punctuations", type=str, default="\"\'“¿([{-", help="if word_timestamps is True, merge these punctuation symbols with the next word") 559 | parser.add_argument("--append_punctuations", type=str, default="\"\'.。,，!！?？:：”)]}、", help="if word_timestamps is True, merge these punctuation symbols with the previous word") 560 | parser.add_argument("--highlight_words", type=str2bool, default=False, help="(requires --word_timestamps True) underline each word as it is
spoken in srt and vtt") 561 | parser.add_argument("--max_line_width", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of characters in a line before breaking the line") 562 | parser.add_argument("--max_line_count", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of lines in a segment") 563 | parser.add_argument("--max_words_per_line", type=optional_int, default=None, help="(requires --word_timestamps True, no effect with --max_line_width) the maximum number of words in a segment") 564 | parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS") 565 | parser.add_argument("--clip_timestamps", type=str, default="0", help="comma-separated list start,end,start,end,... timestamps (in seconds) of clips to process, where the last end timestamp defaults to the end of the file") 566 | parser.add_argument("--hallucination_silence_threshold", type=optional_float, help="(requires --word_timestamps True) skip silent periods longer than this threshold (in seconds) when a possible hallucination is detected") 567 | # fmt: on 568 | 569 | args = parser.parse_args().__dict__ 570 | model_name: str = args.pop("model") 571 | model_dir: str = args.pop("model_dir") 572 | output_dir: str = args.pop("output_dir") 573 | output_format: str = args.pop("output_format") 574 | device: str = args.pop("device") 575 | os.makedirs(output_dir, exist_ok=True) 576 | 577 | if model_name.endswith(".en") and args["language"] not in {"en", "English"}: 578 | if args["language"] is not None: 579 | warnings.warn( 580 | f"{model_name} is an English-only model but receipted '{args['language']}'; using English instead." 
581 | ) 582 | args["language"] = "en" 583 | 584 | temperature = args.pop("temperature") 585 | if (increment := args.pop("temperature_increment_on_fallback")) is not None: 586 | temperature = tuple(np.arange(temperature, 1.0 + 1e-6, increment)) 587 | else: 588 | temperature = [temperature] 589 | 590 | if (threads := args.pop("threads")) > 0: 591 | torch.set_num_threads(threads) 592 | 593 | from . import load_model 594 | 595 | model = load_model(model_name, device=device, download_root=model_dir) 596 | 597 | writer = get_writer(output_format, output_dir) 598 | word_options = [ 599 | "highlight_words", 600 | "max_line_count", 601 | "max_line_width", 602 | "max_words_per_line", 603 | ] 604 | if not args["word_timestamps"]: 605 | for option in word_options: 606 | if args[option]: 607 | parser.error(f"--{option} requires --word_timestamps True") 608 | if args["max_line_count"] and not args["max_line_width"]: 609 | warnings.warn("--max_line_count has no effect without --max_line_width") 610 | if args["max_words_per_line"] and args["max_line_width"]: 611 | warnings.warn("--max_words_per_line has no effect with --max_line_width") 612 | writer_args = {arg: args.pop(arg) for arg in word_options} 613 | for audio_path in args.pop("audio"): 614 | try: 615 | result = transcribe(model, audio_path, temperature=temperature, **args) 616 | writer(result, audio_path, **writer_args) 617 | except Exception as e: 618 | traceback.print_exc() 619 | print(f"Skipping {audio_path} due to {type(e).__name__}: {str(e)}") 620 | 621 | 622 | if __name__ == "__main__": 623 | cli() 624 | -------------------------------------------------------------------------------- /whisper/triton_ops.py: -------------------------------------------------------------------------------- 1 | from functools import lru_cache 2 | 3 | import numpy as np 4 | import torch 5 | 6 | try: 7 | import triton 8 | import triton.language as tl 9 | except ImportError: 10 | raise RuntimeError("triton import failed; try `pip install 
--pre triton`") 11 | 12 | 13 | @triton.jit 14 | def dtw_kernel( 15 | cost, trace, x, x_stride, cost_stride, trace_stride, N, M, BLOCK_SIZE: tl.constexpr 16 | ): 17 | offsets = tl.arange(0, BLOCK_SIZE) 18 | mask = offsets < M 19 | 20 | for k in range(1, N + M + 1): # k = i + j 21 | tl.debug_barrier() 22 | 23 | p0 = cost + (k - 1) * cost_stride 24 | p1 = cost + k * cost_stride 25 | p2 = cost + k * cost_stride + 1 26 | 27 | c0 = tl.load(p0 + offsets, mask=mask) 28 | c1 = tl.load(p1 + offsets, mask=mask) 29 | c2 = tl.load(p2 + offsets, mask=mask) 30 | 31 | x_row = tl.load(x + (k - 1) * x_stride + offsets, mask=mask, other=0) 32 | cost_row = x_row + tl.minimum(tl.minimum(c0, c1), c2) 33 | 34 | cost_ptr = cost + (k + 1) * cost_stride + 1 35 | tl.store(cost_ptr + offsets, cost_row, mask=mask) 36 | 37 | trace_ptr = trace + (k + 1) * trace_stride + 1 38 | tl.store(trace_ptr + offsets, 2, mask=mask & (c2 <= c0) & (c2 <= c1)) 39 | tl.store(trace_ptr + offsets, 1, mask=mask & (c1 <= c0) & (c1 <= c2)) 40 | tl.store(trace_ptr + offsets, 0, mask=mask & (c0 <= c1) & (c0 <= c2)) 41 | 42 | 43 | @lru_cache(maxsize=None) 44 | def median_kernel(filter_width: int): 45 | @triton.jit 46 | def kernel( 47 | y, x, x_stride, y_stride, BLOCK_SIZE: tl.constexpr 48 | ): # x.shape[-1] == filter_width 49 | row_idx = tl.program_id(0) 50 | offsets = tl.arange(0, BLOCK_SIZE) 51 | mask = offsets < y_stride 52 | 53 | x_ptr = x + row_idx * x_stride # noqa: F841 54 | y_ptr = y + row_idx * y_stride 55 | 56 | LOAD_ALL_ROWS_HERE # noqa: F821 57 | 58 | BUBBLESORT_HERE # noqa: F821 59 | 60 | tl.store(y_ptr + offsets, MIDDLE_ROW_HERE, mask=mask) # noqa: F821 61 | 62 | kernel = triton.JITFunction(kernel.fn) 63 | kernel.src = kernel.src.replace( 64 | " LOAD_ALL_ROWS_HERE", 65 | "\n".join( 66 | [ 67 | f" row{i} = tl.load(x_ptr + offsets + {i}, mask=mask)" 68 | for i in range(filter_width) 69 | ] 70 | ), 71 | ) 72 | kernel.src = kernel.src.replace( 73 | " BUBBLESORT_HERE", 74 | "\n\n".join( 75 | [ 76 | 
"\n\n".join( 77 | [ 78 | "\n".join( 79 | [ 80 | f" smaller = tl.where(row{j} < row{j + 1}, row{j}, row{j + 1})", 81 | f" larger = tl.where(row{j} > row{j + 1}, row{j}, row{j + 1})", 82 | f" row{j} = smaller", 83 | f" row{j + 1} = larger", 84 | ] 85 | ) 86 | for j in range(filter_width - i - 1) 87 | ] 88 | ) 89 | for i in range(filter_width // 2 + 1) 90 | ] 91 | ), 92 | ) 93 | kernel.src = kernel.src.replace("MIDDLE_ROW_HERE", f"row{filter_width // 2}") 94 | 95 | return kernel 96 | 97 | 98 | def median_filter_cuda(x: torch.Tensor, filter_width: int): 99 | """Apply a median filter of given width along the last dimension of x""" 100 | slices = x.contiguous().unfold(-1, filter_width, 1) 101 | grid = np.prod(slices.shape[:-2]) 102 | 103 | kernel = median_kernel(filter_width) 104 | y = torch.empty_like(slices[..., 0]) 105 | 106 | BLOCK_SIZE = 1 << (y.stride(-2) - 1).bit_length() 107 | kernel[(grid,)](y, x, x.stride(-2), y.stride(-2), BLOCK_SIZE=BLOCK_SIZE) 108 | 109 | return y 110 | -------------------------------------------------------------------------------- /whisper/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import re 4 | import sys 5 | import zlib 6 | from typing import Callable, List, Optional, TextIO 7 | 8 | system_encoding = sys.getdefaultencoding() 9 | 10 | if system_encoding != "utf-8": 11 | 12 | def make_safe(string): 13 | # replaces any character not representable using the system default encoding with an '?', 14 | # avoiding UnicodeEncodeError (https://github.com/openai/whisper/discussions/729). 
15 | return string.encode(system_encoding, errors="replace").decode(system_encoding) 16 | 17 | else: 18 | 19 | def make_safe(string): 20 | # utf-8 can encode any Unicode code point, so no need to do the round-trip encoding 21 | return string 22 | 23 | 24 | def exact_div(x, y): 25 | assert x % y == 0 26 | return x // y 27 | 28 | 29 | def str2bool(string): 30 | str2val = {"True": True, "False": False} 31 | if string in str2val: 32 | return str2val[string] 33 | else: 34 | raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}") 35 | 36 | 37 | def optional_int(string): 38 | return None if string == "None" else int(string) 39 | 40 | 41 | def optional_float(string): 42 | return None if string == "None" else float(string) 43 | 44 | 45 | def compression_ratio(text) -> float: 46 | text_bytes = text.encode("utf-8") 47 | return len(text_bytes) / len(zlib.compress(text_bytes)) 48 | 49 | 50 | def format_timestamp( 51 | seconds: float, always_include_hours: bool = False, decimal_marker: str = "." 
52 | ): 53 | assert seconds >= 0, "non-negative timestamp expected" 54 | milliseconds = round(seconds * 1000.0) 55 | 56 | hours = milliseconds // 3_600_000 57 | milliseconds -= hours * 3_600_000 58 | 59 | minutes = milliseconds // 60_000 60 | milliseconds -= minutes * 60_000 61 | 62 | seconds = milliseconds // 1_000 63 | milliseconds -= seconds * 1_000 64 | 65 | hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else "" 66 | return ( 67 | f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}" 68 | ) 69 | 70 | 71 | def get_start(segments: List[dict]) -> Optional[float]: 72 | return next( 73 | (w["start"] for s in segments for w in s["words"]), 74 | segments[0]["start"] if segments else None, 75 | ) 76 | 77 | 78 | def get_end(segments: List[dict]) -> Optional[float]: 79 | return next( 80 | (w["end"] for s in reversed(segments) for w in reversed(s["words"])), 81 | segments[-1]["end"] if segments else None, 82 | ) 83 | 84 | 85 | class ResultWriter: 86 | extension: str 87 | 88 | def __init__(self, output_dir: str): 89 | self.output_dir = output_dir 90 | 91 | def __call__( 92 | self, result: dict, audio_path: str, options: Optional[dict] = None, **kwargs 93 | ): 94 | audio_basename = os.path.basename(audio_path) 95 | audio_basename = os.path.splitext(audio_basename)[0] 96 | output_path = os.path.join( 97 | self.output_dir, audio_basename + "." 
+ self.extension 98 | ) 99 | 100 | with open(output_path, "w", encoding="utf-8") as f: 101 | self.write_result(result, file=f, options=options, **kwargs) 102 | 103 | def write_result( 104 | self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs 105 | ): 106 | raise NotImplementedError 107 | 108 | 109 | class WriteTXT(ResultWriter): 110 | extension: str = "txt" 111 | 112 | def write_result( 113 | self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs 114 | ): 115 | for segment in result["segments"]: 116 | print(segment["text"].strip(), file=file, flush=True) 117 | 118 | 119 | class SubtitlesWriter(ResultWriter): 120 | always_include_hours: bool 121 | decimal_marker: str 122 | 123 | def iterate_result( 124 | self, 125 | result: dict, 126 | options: Optional[dict] = None, 127 | *, 128 | max_line_width: Optional[int] = None, 129 | max_line_count: Optional[int] = None, 130 | highlight_words: bool = False, 131 | max_words_per_line: Optional[int] = None, 132 | ): 133 | options = options or {} 134 | max_line_width = max_line_width or options.get("max_line_width") 135 | max_line_count = max_line_count or options.get("max_line_count") 136 | highlight_words = highlight_words or options.get("highlight_words", False) 137 | max_words_per_line = max_words_per_line or options.get("max_words_per_line") 138 | preserve_segments = max_line_count is None or max_line_width is None 139 | max_line_width = max_line_width or 1000 140 | max_words_per_line = max_words_per_line or 1000 141 | 142 | def iterate_subtitles(): 143 | line_len = 0 144 | line_count = 1 145 | # the next subtitle to yield (a list of word timings with whitespace) 146 | subtitle: List[dict] = [] 147 | last: float = get_start(result["segments"]) or 0.0 148 | for segment in result["segments"]: 149 | chunk_index = 0 150 | words_count = max_words_per_line 151 | while chunk_index < len(segment["words"]): 152 | remaining_words = len(segment["words"]) - chunk_index 153 | if 
max_words_per_line > len(segment["words"]) - chunk_index: 154 | words_count = remaining_words 155 | for i, original_timing in enumerate( 156 | segment["words"][chunk_index : chunk_index + words_count] 157 | ): 158 | timing = original_timing.copy() 159 | long_pause = ( 160 | not preserve_segments and timing["start"] - last > 3.0 161 | ) 162 | has_room = line_len + len(timing["word"]) <= max_line_width 163 | seg_break = i == 0 and len(subtitle) > 0 and preserve_segments 164 | if ( 165 | line_len > 0 166 | and has_room 167 | and not long_pause 168 | and not seg_break 169 | ): 170 | # line continuation 171 | line_len += len(timing["word"]) 172 | else: 173 | # new line 174 | timing["word"] = timing["word"].strip() 175 | if ( 176 | len(subtitle) > 0 177 | and max_line_count is not None 178 | and (long_pause or line_count >= max_line_count) 179 | or seg_break 180 | ): 181 | # subtitle break 182 | yield subtitle 183 | subtitle = [] 184 | line_count = 1 185 | elif line_len > 0: 186 | # line break 187 | line_count += 1 188 | timing["word"] = "\n" + timing["word"] 189 | line_len = len(timing["word"].strip()) 190 | subtitle.append(timing) 191 | last = timing["start"] 192 | chunk_index += max_words_per_line 193 | if len(subtitle) > 0: 194 | yield subtitle 195 | 196 | if len(result["segments"]) > 0 and "words" in result["segments"][0]: 197 | for subtitle in iterate_subtitles(): 198 | subtitle_start = self.format_timestamp(subtitle[0]["start"]) 199 | subtitle_end = self.format_timestamp(subtitle[-1]["end"]) 200 | subtitle_text = "".join([word["word"] for word in subtitle]) 201 | if highlight_words: 202 | last = subtitle_start 203 | all_words = [timing["word"] for timing in subtitle] 204 | for i, this_word in enumerate(subtitle): 205 | start = self.format_timestamp(this_word["start"]) 206 | end = self.format_timestamp(this_word["end"]) 207 | if last != start: 208 | yield last, start, subtitle_text 209 | 210 | yield start, end, "".join( 211 | [ 212 | ( 213 | re.sub(r"^(\s*)(.*)$", 
r"\1<u>\2</u>", word) 214 | if j == i 215 | else word 216 | ) 217 | for j, word in enumerate(all_words) 218 | ] 219 | ) 220 | last = end 221 | else: 222 | yield subtitle_start, subtitle_end, subtitle_text 223 | else: 224 | for segment in result["segments"]: 225 | segment_start = self.format_timestamp(segment["start"]) 226 | segment_end = self.format_timestamp(segment["end"]) 227 | segment_text = segment["text"].strip().replace("-->", "->") 228 | yield segment_start, segment_end, segment_text 229 | 230 | def format_timestamp(self, seconds: float): 231 | return format_timestamp( 232 | seconds=seconds, 233 | always_include_hours=self.always_include_hours, 234 | decimal_marker=self.decimal_marker, 235 | ) 236 | 237 | 238 | class WriteVTT(SubtitlesWriter): 239 | extension: str = "vtt" 240 | always_include_hours: bool = False 241 | decimal_marker: str = "." 242 | 243 | def write_result( 244 | self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs 245 | ): 246 | print("WEBVTT\n", file=file) 247 | for start, end, text in self.iterate_result(result, options, **kwargs): 248 | print(f"{start} --> {end}\n{text}\n", file=file, flush=True) 249 | 250 | 251 | class WriteSRT(SubtitlesWriter): 252 | extension: str = "srt" 253 | always_include_hours: bool = True 254 | decimal_marker: str = "," 255 | 256 | def write_result( 257 | self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs 258 | ): 259 | for i, (start, end, text) in enumerate( 260 | self.iterate_result(result, options, **kwargs), start=1 261 | ): 262 | print(f"{i}\n{start} --> {end}\n{text}\n", file=file, flush=True) 263 | 264 | 265 | class WriteTSV(ResultWriter): 266 | """ 267 | Write a transcript to a file in TSV (tab-separated values) format containing lines like: 268 | <start time in integer milliseconds>\t<end time in integer milliseconds>\t<transcript text> 269 | 270 | Using integer milliseconds as start and end times means there's no chance of interference from 271 | an environment setting a language 
encoding that causes the decimal in a floating point number 272 | to
appear as a comma; also is faster and more efficient to parse & store, e.g., in C++. 273 | """ 274 | 275 | extension: str = "tsv" 276 | 277 | def write_result( 278 | self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs 279 | ): 280 | print("start", "end", "text", sep="\t", file=file) 281 | for segment in result["segments"]: 282 | print(round(1000 * segment["start"]), file=file, end="\t") 283 | print(round(1000 * segment["end"]), file=file, end="\t") 284 | print(segment["text"].strip().replace("\t", " "), file=file, flush=True) 285 | 286 | 287 | class WriteJSON(ResultWriter): 288 | extension: str = "json" 289 | 290 | def write_result( 291 | self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs 292 | ): 293 | json.dump(result, file) 294 | 295 | 296 | def get_writer( 297 | output_format: str, output_dir: str 298 | ) -> Callable[[dict, TextIO, dict], None]: 299 | writers = { 300 | "txt": WriteTXT, 301 | "vtt": WriteVTT, 302 | "srt": WriteSRT, 303 | "tsv": WriteTSV, 304 | "json": WriteJSON, 305 | } 306 | 307 | if output_format == "all": 308 | all_writers = [writer(output_dir) for writer in writers.values()] 309 | 310 | def write_all( 311 | result: dict, file: TextIO, options: Optional[dict] = None, **kwargs 312 | ): 313 | for writer in all_writers: 314 | writer(result, file, options, **kwargs) 315 | 316 | return write_all 317 | 318 | return writers[output_format](output_dir) 319 | -------------------------------------------------------------------------------- /whisper/version.py: -------------------------------------------------------------------------------- 1 | __version__ = "20240930" 2 | --------------------------------------------------------------------------------