├── .flake8
├── .gitattributes
├── .github
│   └── workflows
│       ├── python-publish.yml
│       └── test.yml
├── .gitignore
├── CHANGELOG.md
├── LICENSE
├── MANIFEST.in
├── README.md
├── approach.png
├── data
│   ├── README.md
│   └── meanwhile.json
├── language-breakdown.svg
├── model-card.md
├── notebooks
│   ├── LibriSpeech.ipynb
│   └── Multilingual_ASR.ipynb
├── notes.txt
├── pyproject.toml
├── requirements.txt
├── setup.py
├── tests
│   ├── conftest.py
│   ├── jfk.flac
│   ├── test_audio.py
│   ├── test_normalizer.py
│   ├── test_timing.py
│   ├── test_tokenizer.py
│   └── test_transcribe.py
└── whisper
    ├── __init__.py
    ├── __main__.py
    ├── assets
    │   ├── gpt2.tiktoken
    │   ├── mel_filters.npz
    │   └── multilingual.tiktoken
    ├── audio.py
    ├── decoding.py
    ├── model.py
    ├── normalizers
    │   ├── __init__.py
    │   ├── basic.py
    │   ├── english.json
    │   └── english.py
    ├── timing.py
    ├── tokenizer.py
    ├── transcribe.py
    ├── triton_ops.py
    ├── utils.py
    └── version.py
/.flake8:
--------------------------------------------------------------------------------
1 | [flake8]
2 | per-file-ignores =
3 | */__init__.py: F401
4 |
5 |
--------------------------------------------------------------------------------
/.gitattributes:
--------------------------------------------------------------------------------
1 | # Override jupyter in Github language stats for more accurate estimate of repo code languages
2 | # reference: https://github.com/github/linguist/blob/master/docs/overrides.md#generated-code
3 | *.ipynb linguist-generated
4 |
--------------------------------------------------------------------------------
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 | name: Release
2 |
3 | on:
4 | push:
5 | branches:
6 | - main
7 | jobs:
8 | deploy:
9 | runs-on: ubuntu-latest
10 | steps:
11 | - uses: actions/checkout@v2
12 | - uses: actions-ecosystem/action-regex-match@v2
13 | id: regex-match
14 | with:
15 | text: ${{ github.event.head_commit.message }}
16 | regex: '^Release ([^ ]+)'
17 | - name: Set up Python
18 | uses: actions/setup-python@v2
19 | with:
20 | python-version: '3.8'
21 | - name: Install dependencies
22 | run: |
23 | python -m pip install --upgrade pip
24 | pip install setuptools wheel twine
25 | - name: Release
26 | if: ${{ steps.regex-match.outputs.match != '' }}
27 | uses: softprops/action-gh-release@v1
28 | with:
29 | tag_name: v${{ steps.regex-match.outputs.group1 }}
30 | - name: Build and publish
31 | if: ${{ steps.regex-match.outputs.match != '' }}
32 | env:
33 | TWINE_USERNAME: __token__
34 | TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
35 | run: |
36 | python setup.py sdist
37 | twine upload dist/*
38 |
--------------------------------------------------------------------------------
/.github/workflows/test.yml:
--------------------------------------------------------------------------------
1 | name: test
2 | on:
3 | push:
4 | branches:
5 | - main
6 | pull_request:
7 | branches:
8 | - main
9 | jobs:
10 | whisper-test:
11 | runs-on: ubuntu-latest
12 | strategy:
13 | matrix:
14 | python-version: ['3.8', '3.9', '3.10']
15 | pytorch-version: [1.10.2, 1.13.1]
16 | exclude:
17 | - python-version: '3.10'
18 | pytorch-version: 1.10.2
19 | steps:
20 | - uses: conda-incubator/setup-miniconda@v2
21 | - run: conda install -n test ffmpeg python=${{ matrix.python-version }} pytorch=${{ matrix.pytorch-version }} cpuonly -c pytorch
22 | - uses: actions/checkout@v2
23 | - run: echo "$CONDA/envs/test/bin" >> $GITHUB_PATH
24 | - run: pip install .["dev"]
25 | - run: black --check --diff -t py38 --include '(\.pyi?)$' .
26 | - run: isort --check --diff .
27 | - run: flake8 --ignore E203,W503,W504,E501,E731,E741 .
28 | - run: pytest --durations=0 -vv -k 'not test_transcribe or test_transcribe[tiny] or test_transcribe[tiny.en]' -m 'not requires_cuda'
29 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | *.py[cod]
3 | *$py.class
4 | *.egg-info
5 | .pytest_cache
6 | .ipynb_checkpoints
7 |
8 | thumbs.db
9 | .DS_Store
10 | .idea
11 |
12 |
--------------------------------------------------------------------------------
/CHANGELOG.md:
--------------------------------------------------------------------------------
1 | # CHANGELOG
2 |
3 | ## [v20230308](https://github.com/openai/whisper/releases/tag/v20230308)
4 |
5 | * kwargs in decode() for convenience ([#1061](https://github.com/openai/whisper/pull/1061))
6 | * fix all_tokens handling that caused more repetitions and discrepancy in JSON ([#1060](https://github.com/openai/whisper/pull/1060))
7 | * fix typo in CHANGELOG.md
8 |
9 | ## [v20230307](https://github.com/openai/whisper/releases/tag/v20230307)
10 |
11 | * Fix the repetition/hallucination issue identified in #1046 ([#1052](https://github.com/openai/whisper/pull/1052))
12 | * Use triton==2.0.0 ([#1053](https://github.com/openai/whisper/pull/1053))
13 | * Install triton in x86_64 linux only ([#1051](https://github.com/openai/whisper/pull/1051))
14 | * update setup.py to specify python >= 3.8 requirement
15 |
16 | ## [v20230306](https://github.com/openai/whisper/releases/tag/v20230306)
17 |
18 | * remove auxiliary audio extension ([#1021](https://github.com/openai/whisper/pull/1021))
19 | * apply formatting with `black`, `isort`, and `flake8` ([#1038](https://github.com/openai/whisper/pull/1038))
20 | * word-level timestamps in `transcribe()` ([#869](https://github.com/openai/whisper/pull/869))
21 | * Decoding improvements ([#1033](https://github.com/openai/whisper/pull/1033))
22 | * Update README.md ([#894](https://github.com/openai/whisper/pull/894))
23 | * Fix infinite loop caused by incorrect timestamp tokens prediction ([#914](https://github.com/openai/whisper/pull/914))
24 | * drop python 3.7 support ([#889](https://github.com/openai/whisper/pull/889))
25 |
26 | ## [v20230124](https://github.com/openai/whisper/releases/tag/v20230124)
27 |
28 | * handle printing even if sys.stdout.buffer is not available ([#887](https://github.com/openai/whisper/pull/887))
29 | * Add TSV formatted output in transcript, using integer start/end time in milliseconds ([#228](https://github.com/openai/whisper/pull/228))
30 | * Added `--output_format` option ([#333](https://github.com/openai/whisper/pull/333))
31 | * Handle `XDG_CACHE_HOME` properly for `download_root` ([#864](https://github.com/openai/whisper/pull/864))
32 | * use stdout for printing transcription progress ([#867](https://github.com/openai/whisper/pull/867))
33 | * Fix bug where mm is mistakenly replaced with hmm in e.g. 20mm ([#659](https://github.com/openai/whisper/pull/659))
34 | * print '?' if a letter can't be encoded using the system default encoding ([#859](https://github.com/openai/whisper/pull/859))
35 |
36 | ## [v20230117](https://github.com/openai/whisper/releases/tag/v20230117)
37 |
38 | The first versioned release available on [PyPI](https://pypi.org/project/openai-whisper/)
39 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 OpenAI
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include requirements.txt
2 | include README.md
3 | include LICENSE
4 | include whisper/assets/*
5 | include whisper/normalizers/english.json
6 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Whisper
2 |
3 | [[Blog]](https://openai.com/blog/whisper)
4 | [[Paper]](https://arxiv.org/abs/2212.04356)
5 | [[Model card]](https://github.com/openai/whisper/blob/main/model-card.md)
6 | [[Colab example]](https://colab.research.google.com/github/openai/whisper/blob/master/notebooks/LibriSpeech.ipynb)
7 |
8 | Whisper is a general-purpose speech recognition model. It is trained on a large dataset of diverse audio and is also a multitasking model that can perform multilingual speech recognition, speech translation, and language identification.
9 |
10 |
11 | ## Approach
12 |
13 | ![Approach](approach.png)
14 |
15 | A Transformer sequence-to-sequence model is trained on various speech processing tasks, including multilingual speech recognition, speech translation, spoken language identification, and voice activity detection. These tasks are jointly represented as a sequence of tokens to be predicted by the decoder, allowing a single model to replace many stages of a traditional speech-processing pipeline. The multitask training format uses a set of special tokens that serve as task specifiers or classification targets.
16 |
17 |
18 | ## Setup
19 |
20 | We used Python 3.9.9 and [PyTorch](https://pytorch.org/) 1.10.1 to train and test our models, but the codebase is expected to be compatible with Python 3.8-3.10 and recent PyTorch versions. The codebase also depends on a few Python packages, most notably [HuggingFace Transformers](https://huggingface.co/docs/transformers/index) for their fast tokenizer implementation and [ffmpeg-python](https://github.com/kkroening/ffmpeg-python) for reading audio files. You can download and install (or update to) the latest release of Whisper with the following command:
21 |
22 | pip install -U openai-whisper
23 |
24 | Alternatively, the following command will pull and install the latest commit from this repository, along with its Python dependencies:
25 |
26 | pip install git+https://github.com/ProjectEGU/whisper-for-low-vram.git
27 |
28 | To update the package to the latest version of this repository, please run:
29 |
30 | pip install --upgrade --no-deps --force-reinstall git+https://github.com/ProjectEGU/whisper-for-low-vram.git
31 |
36 | It also requires the command-line tool [`ffmpeg`](https://ffmpeg.org/) to be installed on your system, which is available from most package managers:
37 |
38 | ```bash
39 | # on Ubuntu or Debian
40 | sudo apt update && sudo apt install ffmpeg
41 |
42 | # on Arch Linux
43 | sudo pacman -S ffmpeg
44 |
45 | # on MacOS using Homebrew (https://brew.sh/)
46 | brew install ffmpeg
47 |
48 | # on Windows using Chocolatey (https://chocolatey.org/)
49 | choco install ffmpeg
50 |
51 | # on Windows using Scoop (https://scoop.sh/)
52 | scoop install ffmpeg
53 | ```
54 |
55 | You may need [`rust`](http://rust-lang.org) installed as well, in case [tokenizers](https://pypi.org/project/tokenizers/) does not provide a pre-built wheel for your platform. If you see installation errors during the `pip install` command above, please follow the [Getting started page](https://www.rust-lang.org/learn/get-started) to install the Rust development environment. Additionally, you may need to configure the `PATH` environment variable, e.g. `export PATH="$HOME/.cargo/bin:$PATH"`. If the installation fails with `No module named 'setuptools_rust'`, you need to install `setuptools_rust`, e.g. by running:
56 |
57 | ```bash
58 | pip install setuptools-rust
59 | ```
60 |
61 |
62 | ## Available models and languages
63 |
64 | There are five model sizes, four with English-only versions, offering speed and accuracy tradeoffs. Below are the names of the available models and their approximate memory requirements and relative speed.
65 |
66 |
67 | | Size | Parameters | English-only model | Multilingual model | Required VRAM | Relative speed |
68 | |:------:|:----------:|:------------------:|:------------------:|:-------------:|:--------------:|
69 | | tiny | 39 M | `tiny.en` | `tiny` | ~1 GB | ~32x |
70 | | base | 74 M | `base.en` | `base` | ~1 GB | ~16x |
71 | | small | 244 M | `small.en` | `small` | ~2 GB | ~6x |
72 | | medium | 769 M | `medium.en` | `medium` | ~5 GB | ~2x |
73 | | large | 1550 M | N/A | `large` | ~10 GB | 1x |
74 |
75 | The `.en` models for English-only applications tend to perform better, especially for the `tiny.en` and `base.en` models. We observed that the difference becomes less significant for the `small.en` and `medium.en` models.
76 |
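In Python, a model from the table can be selected by passing its name to `whisper.load_model()`; the sketch below is illustrative and simply checks whether the loaded checkpoint is one of the English-only variants.

```python
import whisper

# load an English-only checkpoint by name (see the table above)
model = whisper.load_model("base.en")

# the .en checkpoints report themselves as not multilingual
print(model.is_multilingual)  # False
```
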
77 | Whisper's performance varies widely depending on the language. The figure below shows a WER (Word Error Rate) breakdown by language on the Fleurs dataset using the `large-v2` model (the smaller the WER, the better). More WER and BLEU scores corresponding to the other models and datasets can be found in Appendix D of [the paper](https://arxiv.org/abs/2212.04356).
78 |
79 | ![WER breakdown by language](language-breakdown.svg)
80 |
81 |
82 |
83 | ## Command-line usage
84 |
85 | The following command will transcribe speech in audio files, using the `medium` model:
86 |
87 | whisper audio.flac audio.mp3 audio.wav --model medium
88 |
89 | The default setting (which selects the `small` model) works well for transcribing English. To transcribe an audio file containing non-English speech, you can specify the language using the `--language` option:
90 |
91 | whisper japanese.wav --language Japanese
92 |
93 | Adding `--task translate` will translate the speech into English:
94 |
95 | whisper japanese.wav --language Japanese --task translate
96 |
97 | Run the following to view all available options:
98 |
99 | whisper --help
100 |
101 | See [tokenizer.py](https://github.com/openai/whisper/blob/main/whisper/tokenizer.py) for the list of all available languages.
102 |
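The same list can also be inspected from Python via the `LANGUAGES` mapping defined in `whisper/tokenizer.py`; a minimal sketch:

```python
from whisper.tokenizer import LANGUAGES

# LANGUAGES maps language codes to names, e.g. "en" -> "english"
for code, name in sorted(LANGUAGES.items()):
    print(f"{code}: {name}")
```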
103 |
104 | ## Python usage
105 |
106 | Transcription can also be performed within Python:
107 |
108 | ```python
109 | import whisper
110 |
111 | model = whisper.load_model("base")
112 | result = model.transcribe("audio.mp3")
113 | print(result["text"])
114 | ```
115 |
116 | Internally, the `transcribe()` method reads the entire file and processes the audio with a sliding 30-second window, performing autoregressive sequence-to-sequence predictions on each window.
117 |
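Options such as the spoken language or word-level timestamps can be passed to `transcribe()` as keyword arguments; the snippet below is only a sketch, and the exact set of supported options may differ between versions (see `whisper/transcribe.py`).

```python
import whisper

model = whisper.load_model("base")

# keyword arguments are forwarded to the decoding loop
result = model.transcribe("audio.mp3", language="en", word_timestamps=True)

# each segment carries start/end times (in seconds) and the decoded text
for segment in result["segments"]:
    print(f"[{segment['start']:7.2f} --> {segment['end']:7.2f}] {segment['text']}")
```
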
118 | Below is an example usage of `whisper.detect_language()` and `whisper.decode()` which provide lower-level access to the model.
119 |
120 | ```python
121 | import whisper
122 |
123 | model = whisper.load_model("base")
124 |
125 | # load audio and pad/trim it to fit 30 seconds
126 | audio = whisper.load_audio("audio.mp3")
127 | audio = whisper.pad_or_trim(audio)
128 |
129 | # make log-Mel spectrogram and move to the same device as the model
130 | mel = whisper.log_mel_spectrogram(audio).to(model.device)
131 |
132 | # detect the spoken language
133 | _, probs = model.detect_language(mel)
134 | print(f"Detected language: {max(probs, key=probs.get)}")
135 |
136 | # decode the audio
137 | options = whisper.DecodingOptions()
138 | result = whisper.decode(model, mel, options)
139 |
140 | # print the recognized text
141 | print(result.text)
142 | ```
143 |
144 | ## More examples
145 |
146 | Please use the [🙌 Show and tell](https://github.com/openai/whisper/discussions/categories/show-and-tell) category in Discussions for sharing more example usages of Whisper and third-party extensions such as web demos, integrations with other tools, ports for different platforms, etc.
147 |
148 |
149 | ## License
150 |
151 | Whisper's code and model weights are released under the MIT License. See [LICENSE](https://github.com/openai/whisper/blob/main/LICENSE) for further details.
152 |
--------------------------------------------------------------------------------
/approach.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ProjectEGU/whisper-for-low-vram/6175c21f63450a971ee75428e1dc4aeb5d1953b4/approach.png
--------------------------------------------------------------------------------
/data/README.md:
--------------------------------------------------------------------------------
1 | This directory supplements the paper with more details on how we prepared the data for evaluation, to help replicate our experiments.
2 |
3 | ## Short-form English-only datasets
4 |
5 | ### LibriSpeech
6 |
7 | We used the test-clean and test-other splits from the [LibriSpeech ASR corpus](https://www.openslr.org/12).
8 |
9 | ### TED-LIUM 3
10 |
11 | We used the test split of [TED-LIUM Release 3](https://www.openslr.org/51/), using the segmented manual transcripts included in the release.
12 |
13 | ### Common Voice 5.1
14 |
15 | We downloaded the English subset of Common Voice Corpus 5.1 from [the official website](https://commonvoice.mozilla.org/en/datasets).
16 |
17 | ### Artie
18 |
19 | We used the [Artie bias corpus](https://github.com/artie-inc/artie-bias-corpus). This is a subset of the Common Voice dataset.
20 |
21 | ### CallHome & Switchboard
22 |
23 | We used the two corpora from [LDC2002S09](https://catalog.ldc.upenn.edu/LDC2002S09) and [LDC2002T43](https://catalog.ldc.upenn.edu/LDC2002T43) and followed the [eval2000_data_prep.sh](https://github.com/kaldi-asr/kaldi/blob/master/egs/fisher_swbd/s5/local/eval2000_data_prep.sh) script for preprocessing. The `wav.scp` files can be converted to WAV files with the following bash commands:
24 |
25 | ```bash
26 | mkdir -p wav
27 | while read name cmd; do
28 | echo $name
29 | echo ${cmd/\|/} wav/$name.wav | bash
30 | done < wav.scp
31 | ```
32 |
33 |
34 | ### WSJ
35 |
36 | We used [LDC93S6B](https://catalog.ldc.upenn.edu/LDC93S6B) and [LDC94S13B](https://catalog.ldc.upenn.edu/LDC94S13B) and followed the [s5 recipe](https://github.com/kaldi-asr/kaldi/tree/master/egs/wsj/s5) to preprocess the dataset.
37 |
38 | ### CORAAL
39 |
40 | We used the 231 interviews from [CORAAL (v. 2021.07)](https://oraal.uoregon.edu/coraal) and used the segmentations from [the FairSpeech project](https://github.com/stanford-policylab/asr-disparities/blob/master/input/CORAAL_transcripts.csv).
41 |
42 | ### CHiME-6
43 |
44 | We downloaded the [CHiME-5 dataset](https://spandh.dcs.shef.ac.uk//chime_challenge/CHiME5/download.html) and followed stage 0 of the [s5_track1 recipe](https://github.com/kaldi-asr/kaldi/tree/master/egs/chime6/s5_track1) to create the CHiME-6 dataset, which fixes synchronization. We then used the binaural recordings (`*_P??.wav`) and the corresponding transcripts.
45 |
46 | ### AMI-IHM, AMI-SDM1
47 |
48 | We preprocessed the [AMI Corpus](https://groups.inf.ed.ac.uk/ami/corpus/overview.shtml) by following stages 0 and 2 of the [s5b recipe](https://github.com/kaldi-asr/kaldi/tree/master/egs/ami/s5b).
49 |
50 |
51 | ## Long-form English-only datasets
52 |
53 | ### TED-LIUM 3
54 |
55 | To create a long-form transcription dataset from the [TED-LIUM3](https://www.openslr.org/51/) dataset, we sliced the audio between the beginning of the first labeled segment and the end of the last labeled segment of each talk, and we used the concatenated text as the label. Below are the timestamps used for slicing each of the 11 TED talks in the test split.
56 |
57 | | Filename | Begin time (s) | End time (s) |
58 | |---------------------|----------------|--------------|
59 | | DanBarber_2010 | 16.09 | 1116.24 |
60 | | JaneMcGonigal_2010 | 15.476 | 1187.61 |
61 | | BillGates_2010 | 15.861 | 1656.94 |
62 | | TomWujec_2010U | 16.26 | 402.17 |
63 | | GaryFlake_2010 | 16.06 | 367.14 |
64 | | EricMead_2009P | 18.434 | 536.44 |
65 | | MichaelSpecter_2010 | 16.11 | 979.312 |
66 | | DanielKahneman_2010 | 15.8 | 1199.44 |
67 | | AimeeMullins_2009P | 17.82 | 1296.59 |
68 | | JamesCameron_2010 | 16.75 | 1010.65 |
69 | | RobertGupta_2010U | 16.8 | 387.03 |
70 |
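As a rough illustration of the slicing described above, the labeled region of one talk can be extracted after loading the audio with `whisper.load_audio()`, which resamples to 16 kHz; the filename below is hypothetical, and the snippet is only a sketch of the slicing, not the exact preprocessing pipeline.

```python
import whisper

SAMPLE_RATE = 16000  # whisper.load_audio() resamples to 16 kHz mono

# timestamps for one talk, taken from the table above
begin, end = 16.09, 1116.24

audio = whisper.load_audio("DanBarber_2010.wav")  # hypothetical filename
sliced = audio[int(begin * SAMPLE_RATE): int(end * SAMPLE_RATE)]
```
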
71 | ### Meanwhile
72 |
73 | This dataset consists of 64 segments from The Late Show with Stephen Colbert. The YouTube video ID, start and end timestamps, and the labels can be found in [meanwhile.json](meanwhile.json). The labels are collected from the closed-caption data for each video and corrected with manual inspection.
74 |
75 | ### Rev16
76 |
77 | We used a subset of 16 files from the 30 podcast episodes in [Rev.AI's Podcast Transcription Benchmark](https://www.rev.ai/blog/podcast-transcription-benchmark-part-1/), after finding that there are multiple cases where a significant portion of the audio and the labels did not match, mostly on the parts introducing the sponsors. We selected 16 episodes that do not have this error, whose "file numbers" are:
78 |
79 | 3 4 9 10 11 14 17 18 20 21 23 24 26 27 29 32
80 |
81 | ### Kincaid46
82 |
83 | This dataset consists of 46 audio files and the corresponding transcripts compiled in the blog article [Which automatic transcription service is the most accurate - 2018](https://medium.com/descript/which-automatic-transcription-service-is-the-most-accurate-2018-2e859b23ed19) by Jason Kincaid. We used the 46 audio files and reference transcripts from the Airtable widget in the article.
84 |
85 | For the human transcription benchmark in the paper, we used a subset of 25 examples from this data, whose "Ref ID" values are:
86 |
87 | 2 4 5 8 9 10 12 13 14 16 19 21 23 25 26 28 29 30 33 35 36 37 42 43 45
88 |
89 | ### Earnings-21, Earnings-22
90 |
91 | For these datasets, we used the files available in [the speech-datasets repository](https://github.com/revdotcom/speech-datasets), as of their `202206` version.
92 |
93 | ### CORAAL
94 |
95 | We used the 231 interviews from [CORAAL (v. 2021.07)](https://oraal.uoregon.edu/coraal) and used the full-length interview files and transcripts.
96 |
97 |
98 | ## Multilingual datasets
99 |
100 | ### Multilingual LibriSpeech
101 |
102 | We used the test splits from each language in [the Multilingual LibriSpeech (MLS) corpus](https://www.openslr.org/94/).
103 |
104 | ### Fleurs
105 |
106 | We collected audio files and transcripts using the implementation available as [HuggingFace datasets](https://huggingface.co/datasets/google/fleurs/blob/main/fleurs.py). To use Fleurs as a translation dataset, we matched the numerical utterance IDs to find the corresponding transcripts in English.
107 |
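A rough sketch of this matching using the `datasets` library is below; the `id` and `transcription` field names are assumptions based on the Fleurs loader rather than something defined in this repository.

```python
from datasets import load_dataset

# load the test split for a source language and for English
fleurs_fr = load_dataset("google/fleurs", "fr_fr", split="test")
fleurs_en = load_dataset("google/fleurs", "en_us", split="test")

# map numerical utterance IDs to the English transcripts
english_by_id = {row["id"]: row["transcription"] for row in fleurs_en}

# pair each source-language utterance with its English reference, when present
pairs = [
    (row["transcription"], english_by_id[row["id"]])
    for row in fleurs_fr
    if row["id"] in english_by_id
]
```
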
108 | ### VoxPopuli
109 |
110 | We used the `get_asr_data.py` script from [the official repository](https://github.com/facebookresearch/voxpopuli) to collect the ASR data in 14 languages.
111 |
112 | ### Common Voice 9
113 |
114 | We downloaded the Common Voice Corpus 9 from [the official website](https://commonvoice.mozilla.org/en/datasets).
115 |
116 | ### CoVoST 2
117 |
118 | We collected the `X into English` data using [the official repository](https://github.com/facebookresearch/covost).
119 |
--------------------------------------------------------------------------------
/model-card.md:
--------------------------------------------------------------------------------
1 | # Model Card: Whisper
2 |
3 | This is the official codebase for running the automatic speech recognition (ASR) models (Whisper models) trained and released by OpenAI.
4 |
5 | Following [Model Cards for Model Reporting (Mitchell et al.)](https://arxiv.org/abs/1810.03993), we're providing some information about the automatic speech recognition model. More information on how these models were trained and evaluated can be found [in the paper](https://arxiv.org/abs/2212.04356).
6 |
7 |
8 | ## Model Details
9 |
10 | The Whisper models are trained for speech recognition and translation tasks, capable of transcribing speech audio into text in the language it is spoken (ASR) as well as translating it into English (speech translation). Researchers at OpenAI developed the models to study the robustness of speech processing systems trained under large-scale weak supervision. There are 9 models of different sizes and capabilities, summarized in the following table.
11 |
12 | | Size | Parameters | English-only model | Multilingual model |
13 | |:------:|:----------:|:------------------:|:------------------:|
14 | | tiny | 39 M | ✓ | ✓ |
15 | | base | 74 M | ✓ | ✓ |
16 | | small | 244 M | ✓ | ✓ |
17 | | medium | 769 M | ✓ | ✓ |
18 | | large | 1550 M | | ✓ |
19 |
20 | In December 2022, we [released an improved large model named `large-v2`](https://github.com/openai/whisper/discussions/661).
21 |
22 |
23 | ### Release date
24 |
25 | September 2022 (original series) and December 2022 (`large-v2`)
26 |
27 | ### Model type
28 |
29 | Sequence-to-sequence ASR (automatic speech recognition) and speech translation model
30 |
31 | ### Paper & samples
32 |
33 | [Paper](https://arxiv.org/abs/2212.04356) / [Blog](https://openai.com/blog/whisper)
34 |
35 |
36 | ## Model Use
37 |
38 | ### Evaluated Use
39 |
40 | The primary intended users of these models are AI researchers studying robustness, generalization, capabilities, biases, and constraints of the current model. However, Whisper is also potentially quite useful as an ASR solution for developers, especially for English speech recognition. We recognize that once models are released, it is impossible to restrict access to only “intended” uses or to draw reasonable guidelines around what is or is not research.
41 |
42 | The models are primarily trained and evaluated on ASR and speech translation to English tasks. They show strong ASR results in ~10 languages. They may exhibit additional capabilities, particularly if fine-tuned on certain tasks like voice activity detection, speaker classification, or speaker diarization but have not been robustly evaluated in these areas. We strongly recommend that users perform robust evaluations of the models in a particular context and domain before deploying them.
43 |
44 | In particular, we caution against using Whisper models to transcribe recordings of individuals taken without their consent or purporting to use these models for any kind of subjective classification. We recommend against use in high-risk domains like decision-making contexts, where flaws in accuracy can lead to pronounced flaws in outcomes. The models are intended to transcribe and translate speech; use of the model for classification is not only unevaluated but also not appropriate, particularly to infer human attributes.
45 |
46 |
47 | ## Training Data
48 |
49 | The models are trained on 680,000 hours of audio and the corresponding transcripts collected from the internet. 65% of this data (or 438,000 hours) represents English-language audio and matched English transcripts, roughly 18% (or 126,000 hours) represents non-English audio and English transcripts, while the final 17% (or 117,000 hours) represents non-English audio and the corresponding transcript. This non-English data represents 98 different languages.
50 |
51 | As discussed in [the accompanying paper](https://arxiv.org/abs/2212.04356), we see that performance on transcription in a given language is directly correlated with the amount of training data we employ in that language.
52 |
53 |
54 | ## Performance and Limitations
55 |
56 | Our studies show that, over many existing ASR systems, the models exhibit improved robustness to accents, background noise, technical language, as well as zero shot translation from multiple languages into English; and that accuracy on speech recognition and translation is near the state-of-the-art level.
57 |
58 | However, because the models are trained in a weakly supervised manner using large-scale noisy data, the predictions may include texts that are not actually spoken in the audio input (i.e. hallucination). We hypothesize that this happens because, given their general knowledge of language, the models combine trying to predict the next word in audio with trying to transcribe the audio itself.
59 |
60 | Our models perform unevenly across languages, and we observe lower accuracy on low-resource and/or low-discoverability languages or languages where we have less training data. The models also exhibit disparate performance on different accents and dialects of particular languages, which may include higher word error rate across speakers of different genders, races, ages, or other demographic criteria. Our full evaluation results are presented in [the paper accompanying this release](https://arxiv.org/abs/2212.04356).
61 |
62 | In addition, the sequence-to-sequence architecture of the model makes it prone to generating repetitive texts, which can be mitigated to some degree by beam search and temperature scheduling but not perfectly. Further analysis of these limitations is provided in [the paper](https://arxiv.org/abs/2212.04356). It is likely that this behavior and hallucinations may be worse in lower-resource and/or lower-discoverability languages.
63 |
64 |
65 | ## Broader Implications
66 |
67 | We anticipate that Whisper models’ transcription capabilities may be used for improving accessibility tools. While Whisper models cannot be used for real-time transcription out of the box, their speed and size suggest that others may be able to build applications on top of them that allow for near-real-time speech recognition and translation. The real value of beneficial applications built on top of Whisper models suggests that the disparate performance of these models may have real economic implications.
68 |
69 | There are also potential dual use concerns that come with releasing Whisper. While we hope the technology will be used primarily for beneficial purposes, making ASR technology more accessible could enable more actors to build capable surveillance technologies or scale up existing surveillance efforts, as the speed and accuracy allow for affordable automatic transcription and translation of large volumes of audio communication. Moreover, these models may have some capabilities to recognize specific individuals out of the box, which in turn presents safety concerns related both to dual use and disparate performance. In practice, we expect that the cost of transcription is not the limiting factor of scaling up surveillance projects.
70 |
--------------------------------------------------------------------------------
/notebooks/LibriSpeech.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {
6 | "id": "v5hvo8QWN-a9"
7 | },
8 | "source": [
9 | "# Installing Whisper\n",
10 | "\n",
11 | "The commands below will install the Python packages needed to use Whisper models and evaluate the transcription results."
12 | ]
13 | },
14 | {
15 | "cell_type": "code",
16 | "execution_count": 1,
17 | "metadata": {
18 | "id": "ZsJUxc0aRsAf"
19 | },
20 | "outputs": [],
21 | "source": [
22 | "! pip install git+https://github.com/openai/whisper.git\n",
23 | "! pip install jiwer"
24 | ]
25 | },
26 | {
27 | "cell_type": "markdown",
28 | "metadata": {
29 | "id": "1IMEkgyagYto"
30 | },
31 | "source": [
32 | "# Loading the LibriSpeech dataset\n",
33 | "\n",
34 | "The following will load the test-clean split of the LibriSpeech corpus using torchaudio."
35 | ]
36 | },
37 | {
38 | "cell_type": "code",
39 | "execution_count": 2,
40 | "metadata": {
41 | "id": "3CqtR2Fi5-vP"
42 | },
43 | "outputs": [],
44 | "source": [
45 | "import os\n",
46 | "import numpy as np\n",
47 | "\n",
48 | "try:\n",
49 | " import tensorflow # required in Colab to avoid protobuf compatibility issues\n",
50 | "except ImportError:\n",
51 | " pass\n",
52 | "\n",
53 | "import torch\n",
54 | "import pandas as pd\n",
55 | "import whisper\n",
56 | "import torchaudio\n",
57 | "\n",
58 | "from tqdm.notebook import tqdm\n",
59 | "\n",
60 | "\n",
61 | "DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\""
62 | ]
63 | },
64 | {
65 | "cell_type": "code",
66 | "execution_count": 3,
67 | "metadata": {
68 | "id": "GuCCB2KYOJCE"
69 | },
70 | "outputs": [],
71 | "source": [
72 | "class LibriSpeech(torch.utils.data.Dataset):\n",
73 | " \"\"\"\n",
74 | " A simple class to wrap LibriSpeech and trim/pad the audio to 30 seconds.\n",
75 | " It will drop the last few seconds of a very small portion of the utterances.\n",
76 | " \"\"\"\n",
77 | " def __init__(self, split=\"test-clean\", device=DEVICE):\n",
78 | " self.dataset = torchaudio.datasets.LIBRISPEECH(\n",
79 | " root=os.path.expanduser(\"~/.cache\"),\n",
80 | " url=split,\n",
81 | " download=True,\n",
82 | " )\n",
83 | " self.device = device\n",
84 | "\n",
85 | " def __len__(self):\n",
86 | " return len(self.dataset)\n",
87 | "\n",
88 | " def __getitem__(self, item):\n",
89 | " audio, sample_rate, text, _, _, _ = self.dataset[item]\n",
90 | " assert sample_rate == 16000\n",
91 | " audio = whisper.pad_or_trim(audio.flatten()).to(self.device)\n",
92 | " mel = whisper.log_mel_spectrogram(audio)\n",
93 | " \n",
94 | " return (mel, text)"
95 | ]
96 | },
97 | {
98 | "cell_type": "code",
99 | "execution_count": 4,
100 | "metadata": {
101 | "id": "-YcRU5jqNqo2"
102 | },
103 | "outputs": [],
104 | "source": [
105 | "dataset = LibriSpeech(\"test-clean\")\n",
106 | "loader = torch.utils.data.DataLoader(dataset, batch_size=16)"
107 | ]
108 | },
109 | {
110 | "cell_type": "markdown",
111 | "metadata": {
112 | "id": "0ljocCNuUAde"
113 | },
114 | "source": [
115 | "# Running inference on the dataset using a base Whisper model\n",
116 | "\n",
117 | "The following will take a few minutes to transcribe all utterances in the dataset."
118 | ]
119 | },
120 | {
121 | "cell_type": "code",
122 | "execution_count": 5,
123 | "metadata": {
124 | "colab": {
125 | "base_uri": "https://localhost:8080/"
126 | },
127 | "id": "_PokfNJtOYNu",
128 | "outputId": "2c53ec44-bc93-4107-b4fa-214e3f71fe8e"
129 | },
130 | "outputs": [
131 | {
132 | "name": "stdout",
133 | "output_type": "stream",
134 | "text": [
135 | "Model is English-only and has 71,825,408 parameters.\n"
136 | ]
137 | }
138 | ],
139 | "source": [
140 | "model = whisper.load_model(\"base.en\")\n",
141 | "print(\n",
142 | " f\"Model is {'multilingual' if model.is_multilingual else 'English-only'} \"\n",
143 | " f\"and has {sum(np.prod(p.shape) for p in model.parameters()):,} parameters.\"\n",
144 | ")"
145 | ]
146 | },
147 | {
148 | "cell_type": "code",
149 | "execution_count": 6,
150 | "metadata": {},
151 | "outputs": [],
152 | "source": [
153 | "# predict without timestamps for short-form transcription\n",
154 | "options = whisper.DecodingOptions(language=\"en\", without_timestamps=True)"
155 | ]
156 | },
157 | {
158 | "cell_type": "code",
159 | "execution_count": 7,
160 | "metadata": {
161 | "colab": {
162 | "base_uri": "https://localhost:8080/",
163 | "height": 49,
164 | "referenced_widgets": [
165 | "09a29a91f58d4462942505a3cc415801",
166 | "83391f98a240490987c397048fc1a0d4",
167 | "06b9aa5f49fa44ba8c93b647dc7db224",
168 | "da9c231ee67047fb89073c95326b72a5",
169 | "48da931ebe7f4fd299f8c98c7d2460ff",
170 | "7a901f447c1d477bb49f954e0feacedd",
171 | "39f5a6ae8ba74c8598f9c6d5b8ad2d65",
172 | "a0d10a42c753453283e5219c22239337",
173 | "09f4cb79ff86465aaf48b0de24869af9",
174 | "1b9cecf5b3584fba8258a81d4279a25b",
175 | "039b53f2702c4179af7e0548018d0588"
176 | ]
177 | },
178 | "id": "7OWTn_KvNk59",
179 | "outputId": "a813a792-3c91-4144-f11f-054fd6778023"
180 | },
181 | "outputs": [
182 | {
183 | "data": {
184 | "application/vnd.jupyter.widget-view+json": {
185 | "model_id": "9df048b46f764cf68cbe0045b8ff73a8",
186 | "version_major": 2,
187 | "version_minor": 0
188 | },
189 | "text/plain": [
190 | " 0%| | 0/164 [00:00, ?it/s]"
191 | ]
192 | },
193 | "metadata": {},
194 | "output_type": "display_data"
195 | }
196 | ],
197 | "source": [
198 | "hypotheses = []\n",
199 | "references = []\n",
200 | "\n",
201 | "for mels, texts in tqdm(loader):\n",
202 | " results = model.decode(mels, options)\n",
203 | " hypotheses.extend([result.text for result in results])\n",
204 | " references.extend(texts)"
205 | ]
206 | },
207 | {
208 | "cell_type": "code",
209 | "execution_count": 8,
210 | "metadata": {
211 | "colab": {
212 | "base_uri": "https://localhost:8080/",
213 | "height": 424
214 | },
215 | "id": "4nTyynELQ42j",
216 | "outputId": "1c72d25a-3e87-4c60-a8d1-1da9d2f73bd7"
217 | },
218 | "outputs": [
219 | {
220 | "data": {
221 | "text/html": [
222 | "
\n",
223 | "\n",
236 | "
\n",
237 | " \n",
238 | " \n",
239 | " | \n",
240 | " hypothesis | \n",
241 | " reference | \n",
242 | "
\n",
243 | " \n",
244 | " \n",
245 | " \n",
246 | " 0 | \n",
247 | " He hoped there would be stew for dinner, turni... | \n",
248 | " HE HOPED THERE WOULD BE STEW FOR DINNER TURNIP... | \n",
249 | "
\n",
250 | " \n",
251 | " 1 | \n",
252 | " Stuffered into you, his belly counseled him. | \n",
253 | " STUFF IT INTO YOU HIS BELLY COUNSELLED HIM | \n",
254 | "
\n",
255 | " \n",
256 | " 2 | \n",
257 | " After early nightfall the yellow lamps would l... | \n",
258 | " AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD L... | \n",
259 | "
\n",
260 | " \n",
261 | " 3 | \n",
262 | " Hello Bertie, any good in your mind? | \n",
263 | " HELLO BERTIE ANY GOOD IN YOUR MIND | \n",
264 | "
\n",
265 | " \n",
266 | " 4 | \n",
267 | " Number 10. Fresh Nelly is waiting on you. Good... | \n",
268 | " NUMBER TEN FRESH NELLY IS WAITING ON YOU GOOD ... | \n",
269 | "
\n",
270 | " \n",
271 | " ... | \n",
272 | " ... | \n",
273 | " ... | \n",
274 | "
\n",
275 | " \n",
276 | " 2615 | \n",
277 | " Oh, to shoot my soul's full meaning into futur... | \n",
278 | " OH TO SHOOT MY SOUL'S FULL MEANING INTO FUTURE... | \n",
279 | "
\n",
280 | " \n",
281 | " 2616 | \n",
282 | " Then I, long tried by natural ills, received t... | \n",
283 | " THEN I LONG TRIED BY NATURAL ILLS RECEIVED THE... | \n",
284 | "
\n",
285 | " \n",
286 | " 2617 | \n",
287 | " I love thee freely as men strive for right. I ... | \n",
288 | " I LOVE THEE FREELY AS MEN STRIVE FOR RIGHT I L... | \n",
289 | "
\n",
290 | " \n",
291 | " 2618 | \n",
292 | " I love thee with the passion put to use, in my... | \n",
293 | " I LOVE THEE WITH THE PASSION PUT TO USE IN MY ... | \n",
294 | "
\n",
295 | " \n",
296 | " 2619 | \n",
297 | " I love thee with the love I seemed to lose wit... | \n",
298 | " I LOVE THEE WITH A LOVE I SEEMED TO LOSE WITH ... | \n",
299 | "
\n",
300 | " \n",
301 | "
\n",
302 | "
2620 rows × 2 columns
\n",
303 | "
"
304 | ],
305 | "text/plain": [
306 | " hypothesis \\\n",
307 | "0 He hoped there would be stew for dinner, turni... \n",
308 | "1 Stuffered into you, his belly counseled him. \n",
309 | "2 After early nightfall the yellow lamps would l... \n",
310 | "3 Hello Bertie, any good in your mind? \n",
311 | "4 Number 10. Fresh Nelly is waiting on you. Good... \n",
312 | "... ... \n",
313 | "2615 Oh, to shoot my soul's full meaning into futur... \n",
314 | "2616 Then I, long tried by natural ills, received t... \n",
315 | "2617 I love thee freely as men strive for right. I ... \n",
316 | "2618 I love thee with the passion put to use, in my... \n",
317 | "2619 I love thee with the love I seemed to lose wit... \n",
318 | "\n",
319 | " reference \n",
320 | "0 HE HOPED THERE WOULD BE STEW FOR DINNER TURNIP... \n",
321 | "1 STUFF IT INTO YOU HIS BELLY COUNSELLED HIM \n",
322 | "2 AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD L... \n",
323 | "3 HELLO BERTIE ANY GOOD IN YOUR MIND \n",
324 | "4 NUMBER TEN FRESH NELLY IS WAITING ON YOU GOOD ... \n",
325 | "... ... \n",
326 | "2615 OH TO SHOOT MY SOUL'S FULL MEANING INTO FUTURE... \n",
327 | "2616 THEN I LONG TRIED BY NATURAL ILLS RECEIVED THE... \n",
328 | "2617 I LOVE THEE FREELY AS MEN STRIVE FOR RIGHT I L... \n",
329 | "2618 I LOVE THEE WITH THE PASSION PUT TO USE IN MY ... \n",
330 | "2619 I LOVE THEE WITH A LOVE I SEEMED TO LOSE WITH ... \n",
331 | "\n",
332 | "[2620 rows x 2 columns]"
333 | ]
334 | },
335 | "execution_count": 8,
336 | "metadata": {},
337 | "output_type": "execute_result"
338 | }
339 | ],
340 | "source": [
341 | "data = pd.DataFrame(dict(hypothesis=hypotheses, reference=references))\n",
342 | "data"
343 | ]
344 | },
345 | {
346 | "cell_type": "markdown",
347 | "metadata": {
348 | "id": "HPppEJRXX4ox"
349 | },
350 | "source": [
351 | "# Calculating the word error rate\n",
352 | "\n",
353 | "Now, we use our English normalizer implementation to standardize the transcription and calculate the WER."
354 | ]
355 | },
356 | {
357 | "cell_type": "code",
358 | "execution_count": 9,
359 | "metadata": {
360 | "id": "dl-KBDflMhrg"
361 | },
362 | "outputs": [],
363 | "source": [
364 | "import jiwer\n",
365 | "from whisper.normalizers import EnglishTextNormalizer\n",
366 | "\n",
367 | "normalizer = EnglishTextNormalizer()"
368 | ]
369 | },
370 | {
371 | "cell_type": "code",
372 | "execution_count": 10,
373 | "metadata": {
374 | "colab": {
375 | "base_uri": "https://localhost:8080/",
376 | "height": 641
377 | },
378 | "id": "6-O048q4WI4o",
379 | "outputId": "f2089bc9-f535-441e-f192-26e52ae82b5e"
380 | },
381 | "outputs": [
382 | {
383 | "data": {
384 | "text/html": [
385 | "\n",
386 | "\n",
399 | "
\n",
400 | " \n",
401 | " \n",
402 | " | \n",
403 | " hypothesis | \n",
404 | " reference | \n",
405 | " hypothesis_clean | \n",
406 | " reference_clean | \n",
407 | "
\n",
408 | " \n",
409 | " \n",
410 | " \n",
411 | " 0 | \n",
412 | " He hoped there would be stew for dinner, turni... | \n",
413 | " HE HOPED THERE WOULD BE STEW FOR DINNER TURNIP... | \n",
414 | " he hoped there would be stew for dinner turnip... | \n",
415 | " he hoped there would be stew for dinner turnip... | \n",
416 | "
\n",
417 | " \n",
418 | " 1 | \n",
419 | " Stuffered into you, his belly counseled him. | \n",
420 | " STUFF IT INTO YOU HIS BELLY COUNSELLED HIM | \n",
421 | " stuffered into you his belly counseled him | \n",
422 | " stuff it into you his belly counseled him | \n",
423 | "
\n",
424 | " \n",
425 | " 2 | \n",
426 | " After early nightfall the yellow lamps would l... | \n",
427 | " AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD L... | \n",
428 | " after early nightfall the yellow lamps would l... | \n",
429 | " after early nightfall the yellow lamps would l... | \n",
430 | "
\n",
431 | " \n",
432 | " 3 | \n",
433 | " Hello Bertie, any good in your mind? | \n",
434 | " HELLO BERTIE ANY GOOD IN YOUR MIND | \n",
435 | " hello bertie any good in your mind | \n",
436 | " hello bertie any good in your mind | \n",
437 | "
\n",
438 | " \n",
439 | " 4 | \n",
440 | " Number 10. Fresh Nelly is waiting on you. Good... | \n",
441 | " NUMBER TEN FRESH NELLY IS WAITING ON YOU GOOD ... | \n",
442 | " number 10 fresh nelly is waiting on you good n... | \n",
443 | " number 10 fresh nelly is waiting on you good n... | \n",
444 | "
\n",
445 | " \n",
446 | " ... | \n",
447 | " ... | \n",
448 | " ... | \n",
449 | " ... | \n",
450 | " ... | \n",
451 | "
\n",
452 | " \n",
453 | " 2615 | \n",
454 | " Oh, to shoot my soul's full meaning into futur... | \n",
455 | " OH TO SHOOT MY SOUL'S FULL MEANING INTO FUTURE... | \n",
456 | " 0 to shoot my soul is full meaning into future... | \n",
457 | " 0 to shoot my soul is full meaning into future... | \n",
458 | "
\n",
459 | " \n",
460 | " 2616 | \n",
461 | " Then I, long tried by natural ills, received t... | \n",
462 | " THEN I LONG TRIED BY NATURAL ILLS RECEIVED THE... | \n",
463 | " then i long tried by natural ills received the... | \n",
464 | " then i long tried by natural ills received the... | \n",
465 | "
\n",
466 | " \n",
467 | " 2617 | \n",
468 | " I love thee freely as men strive for right. I ... | \n",
469 | " I LOVE THEE FREELY AS MEN STRIVE FOR RIGHT I L... | \n",
470 | " i love thee freely as men strive for right i l... | \n",
471 | " i love thee freely as men strive for right i l... | \n",
472 | "
\n",
473 | " \n",
474 | " 2618 | \n",
475 | " I love thee with the passion put to use, in my... | \n",
476 | " I LOVE THEE WITH THE PASSION PUT TO USE IN MY ... | \n",
477 | " i love thee with the passion put to use in my ... | \n",
478 | " i love thee with the passion put to use in my ... | \n",
479 | "
\n",
480 | " \n",
481 | " 2619 | \n",
482 | " I love thee with the love I seemed to lose wit... | \n",
483 | " I LOVE THEE WITH A LOVE I SEEMED TO LOSE WITH ... | \n",
484 | " i love thee with the love i seemed to lose wit... | \n",
485 | " i love thee with a love i seemed to lose with ... | \n",
486 | "
\n",
487 | " \n",
488 | "
\n",
489 | "
2620 rows × 4 columns
\n",
490 | "
"
491 | ],
492 | "text/plain": [
493 | " hypothesis \\\n",
494 | "0 He hoped there would be stew for dinner, turni... \n",
495 | "1 Stuffered into you, his belly counseled him. \n",
496 | "2 After early nightfall the yellow lamps would l... \n",
497 | "3 Hello Bertie, any good in your mind? \n",
498 | "4 Number 10. Fresh Nelly is waiting on you. Good... \n",
499 | "... ... \n",
500 | "2615 Oh, to shoot my soul's full meaning into futur... \n",
501 | "2616 Then I, long tried by natural ills, received t... \n",
502 | "2617 I love thee freely as men strive for right. I ... \n",
503 | "2618 I love thee with the passion put to use, in my... \n",
504 | "2619 I love thee with the love I seemed to lose wit... \n",
505 | "\n",
506 | " reference \\\n",
507 | "0 HE HOPED THERE WOULD BE STEW FOR DINNER TURNIP... \n",
508 | "1 STUFF IT INTO YOU HIS BELLY COUNSELLED HIM \n",
509 | "2 AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD L... \n",
510 | "3 HELLO BERTIE ANY GOOD IN YOUR MIND \n",
511 | "4 NUMBER TEN FRESH NELLY IS WAITING ON YOU GOOD ... \n",
512 | "... ... \n",
513 | "2615 OH TO SHOOT MY SOUL'S FULL MEANING INTO FUTURE... \n",
514 | "2616 THEN I LONG TRIED BY NATURAL ILLS RECEIVED THE... \n",
515 | "2617 I LOVE THEE FREELY AS MEN STRIVE FOR RIGHT I L... \n",
516 | "2618 I LOVE THEE WITH THE PASSION PUT TO USE IN MY ... \n",
517 | "2619 I LOVE THEE WITH A LOVE I SEEMED TO LOSE WITH ... \n",
518 | "\n",
519 | " hypothesis_clean \\\n",
520 | "0 he hoped there would be stew for dinner turnip... \n",
521 | "1 stuffered into you his belly counseled him \n",
522 | "2 after early nightfall the yellow lamps would l... \n",
523 | "3 hello bertie any good in your mind \n",
524 | "4 number 10 fresh nelly is waiting on you good n... \n",
525 | "... ... \n",
526 | "2615 0 to shoot my soul is full meaning into future... \n",
527 | "2616 then i long tried by natural ills received the... \n",
528 | "2617 i love thee freely as men strive for right i l... \n",
529 | "2618 i love thee with the passion put to use in my ... \n",
530 | "2619 i love thee with the love i seemed to lose wit... \n",
531 | "\n",
532 | " reference_clean \n",
533 | "0 he hoped there would be stew for dinner turnip... \n",
534 | "1 stuff it into you his belly counseled him \n",
535 | "2 after early nightfall the yellow lamps would l... \n",
536 | "3 hello bertie any good in your mind \n",
537 | "4 number 10 fresh nelly is waiting on you good n... \n",
538 | "... ... \n",
539 | "2615 0 to shoot my soul is full meaning into future... \n",
540 | "2616 then i long tried by natural ills received the... \n",
541 | "2617 i love thee freely as men strive for right i l... \n",
542 | "2618 i love thee with the passion put to use in my ... \n",
543 | "2619 i love thee with a love i seemed to lose with ... \n",
544 | "\n",
545 | "[2620 rows x 4 columns]"
546 | ]
547 | },
548 | "execution_count": 10,
549 | "metadata": {},
550 | "output_type": "execute_result"
551 | }
552 | ],
553 | "source": [
554 | "data[\"hypothesis_clean\"] = [normalizer(text) for text in data[\"hypothesis\"]]\n",
555 | "data[\"reference_clean\"] = [normalizer(text) for text in data[\"reference\"]]\n",
556 | "data"
557 | ]
558 | },
559 | {
560 | "cell_type": "code",
561 | "execution_count": 11,
562 | "metadata": {
563 | "colab": {
564 | "base_uri": "https://localhost:8080/"
565 | },
566 | "id": "EBGSITeBYPTT",
567 | "outputId": "7b3dbe7c-a37e-4a07-a50a-b27d5f88b68f"
568 | },
569 | "outputs": [
570 | {
571 | "name": "stdout",
572 | "output_type": "stream",
573 | "text": [
574 | "WER: 4.26 %\n"
575 | ]
576 | }
577 | ],
578 | "source": [
579 | "wer = jiwer.wer(list(data[\"reference_clean\"]), list(data[\"hypothesis_clean\"]))\n",
580 | "\n",
581 | "print(f\"WER: {wer * 100:.2f} %\")"
582 | ]
583 | }
584 | ],
585 | "metadata": {
586 | "accelerator": "GPU",
587 | "colab": {
588 | "collapsed_sections": [],
589 | "provenance": []
590 | },
591 | "gpuClass": "standard",
592 | "kernelspec": {
593 | "display_name": "Python 3 (ipykernel)",
594 | "language": "python",
595 | "name": "python3"
596 | },
597 | "language_info": {
598 | "codemirror_mode": {
599 | "name": "ipython",
600 | "version": 3
601 | },
602 | "file_extension": ".py",
603 | "mimetype": "text/x-python",
604 | "name": "python",
605 | "nbconvert_exporter": "python",
606 | "pygments_lexer": "ipython3",
607 | "version": "3.9.9"
608 | },
609 | "widgets": {
610 | "application/vnd.jupyter.widget-state+json": {
611 | "039b53f2702c4179af7e0548018d0588": {
612 | "model_module": "@jupyter-widgets/controls",
613 | "model_module_version": "1.5.0",
614 | "model_name": "DescriptionStyleModel",
615 | "state": {
616 | "_model_module": "@jupyter-widgets/controls",
617 | "_model_module_version": "1.5.0",
618 | "_model_name": "DescriptionStyleModel",
619 | "_view_count": null,
620 | "_view_module": "@jupyter-widgets/base",
621 | "_view_module_version": "1.2.0",
622 | "_view_name": "StyleView",
623 | "description_width": ""
624 | }
625 | },
626 | "06b9aa5f49fa44ba8c93b647dc7db224": {
627 | "model_module": "@jupyter-widgets/controls",
628 | "model_module_version": "1.5.0",
629 | "model_name": "FloatProgressModel",
630 | "state": {
631 | "_dom_classes": [],
632 | "_model_module": "@jupyter-widgets/controls",
633 | "_model_module_version": "1.5.0",
634 | "_model_name": "FloatProgressModel",
635 | "_view_count": null,
636 | "_view_module": "@jupyter-widgets/controls",
637 | "_view_module_version": "1.5.0",
638 | "_view_name": "ProgressView",
639 | "bar_style": "success",
640 | "description": "",
641 | "description_tooltip": null,
642 | "layout": "IPY_MODEL_a0d10a42c753453283e5219c22239337",
643 | "max": 164,
644 | "min": 0,
645 | "orientation": "horizontal",
646 | "style": "IPY_MODEL_09f4cb79ff86465aaf48b0de24869af9",
647 | "value": 164
648 | }
649 | },
650 | "09a29a91f58d4462942505a3cc415801": {
651 | "model_module": "@jupyter-widgets/controls",
652 | "model_module_version": "1.5.0",
653 | "model_name": "HBoxModel",
654 | "state": {
655 | "_dom_classes": [],
656 | "_model_module": "@jupyter-widgets/controls",
657 | "_model_module_version": "1.5.0",
658 | "_model_name": "HBoxModel",
659 | "_view_count": null,
660 | "_view_module": "@jupyter-widgets/controls",
661 | "_view_module_version": "1.5.0",
662 | "_view_name": "HBoxView",
663 | "box_style": "",
664 | "children": [
665 | "IPY_MODEL_83391f98a240490987c397048fc1a0d4",
666 | "IPY_MODEL_06b9aa5f49fa44ba8c93b647dc7db224",
667 | "IPY_MODEL_da9c231ee67047fb89073c95326b72a5"
668 | ],
669 | "layout": "IPY_MODEL_48da931ebe7f4fd299f8c98c7d2460ff"
670 | }
671 | },
672 | "09f4cb79ff86465aaf48b0de24869af9": {
673 | "model_module": "@jupyter-widgets/controls",
674 | "model_module_version": "1.5.0",
675 | "model_name": "ProgressStyleModel",
676 | "state": {
677 | "_model_module": "@jupyter-widgets/controls",
678 | "_model_module_version": "1.5.0",
679 | "_model_name": "ProgressStyleModel",
680 | "_view_count": null,
681 | "_view_module": "@jupyter-widgets/base",
682 | "_view_module_version": "1.2.0",
683 | "_view_name": "StyleView",
684 | "bar_color": null,
685 | "description_width": ""
686 | }
687 | },
688 | "1b9cecf5b3584fba8258a81d4279a25b": {
689 | "model_module": "@jupyter-widgets/base",
690 | "model_module_version": "1.2.0",
691 | "model_name": "LayoutModel",
692 | "state": {
693 | "_model_module": "@jupyter-widgets/base",
694 | "_model_module_version": "1.2.0",
695 | "_model_name": "LayoutModel",
696 | "_view_count": null,
697 | "_view_module": "@jupyter-widgets/base",
698 | "_view_module_version": "1.2.0",
699 | "_view_name": "LayoutView",
700 | "align_content": null,
701 | "align_items": null,
702 | "align_self": null,
703 | "border": null,
704 | "bottom": null,
705 | "display": null,
706 | "flex": null,
707 | "flex_flow": null,
708 | "grid_area": null,
709 | "grid_auto_columns": null,
710 | "grid_auto_flow": null,
711 | "grid_auto_rows": null,
712 | "grid_column": null,
713 | "grid_gap": null,
714 | "grid_row": null,
715 | "grid_template_areas": null,
716 | "grid_template_columns": null,
717 | "grid_template_rows": null,
718 | "height": null,
719 | "justify_content": null,
720 | "justify_items": null,
721 | "left": null,
722 | "margin": null,
723 | "max_height": null,
724 | "max_width": null,
725 | "min_height": null,
726 | "min_width": null,
727 | "object_fit": null,
728 | "object_position": null,
729 | "order": null,
730 | "overflow": null,
731 | "overflow_x": null,
732 | "overflow_y": null,
733 | "padding": null,
734 | "right": null,
735 | "top": null,
736 | "visibility": null,
737 | "width": null
738 | }
739 | },
740 | "39f5a6ae8ba74c8598f9c6d5b8ad2d65": {
741 | "model_module": "@jupyter-widgets/controls",
742 | "model_module_version": "1.5.0",
743 | "model_name": "DescriptionStyleModel",
744 | "state": {
745 | "_model_module": "@jupyter-widgets/controls",
746 | "_model_module_version": "1.5.0",
747 | "_model_name": "DescriptionStyleModel",
748 | "_view_count": null,
749 | "_view_module": "@jupyter-widgets/base",
750 | "_view_module_version": "1.2.0",
751 | "_view_name": "StyleView",
752 | "description_width": ""
753 | }
754 | },
755 | "48da931ebe7f4fd299f8c98c7d2460ff": {
756 | "model_module": "@jupyter-widgets/base",
757 | "model_module_version": "1.2.0",
758 | "model_name": "LayoutModel",
759 | "state": {
760 | "_model_module": "@jupyter-widgets/base",
761 | "_model_module_version": "1.2.0",
762 | "_model_name": "LayoutModel",
763 | "_view_count": null,
764 | "_view_module": "@jupyter-widgets/base",
765 | "_view_module_version": "1.2.0",
766 | "_view_name": "LayoutView",
767 | "align_content": null,
768 | "align_items": null,
769 | "align_self": null,
770 | "border": null,
771 | "bottom": null,
772 | "display": null,
773 | "flex": null,
774 | "flex_flow": null,
775 | "grid_area": null,
776 | "grid_auto_columns": null,
777 | "grid_auto_flow": null,
778 | "grid_auto_rows": null,
779 | "grid_column": null,
780 | "grid_gap": null,
781 | "grid_row": null,
782 | "grid_template_areas": null,
783 | "grid_template_columns": null,
784 | "grid_template_rows": null,
785 | "height": null,
786 | "justify_content": null,
787 | "justify_items": null,
788 | "left": null,
789 | "margin": null,
790 | "max_height": null,
791 | "max_width": null,
792 | "min_height": null,
793 | "min_width": null,
794 | "object_fit": null,
795 | "object_position": null,
796 | "order": null,
797 | "overflow": null,
798 | "overflow_x": null,
799 | "overflow_y": null,
800 | "padding": null,
801 | "right": null,
802 | "top": null,
803 | "visibility": null,
804 | "width": null
805 | }
806 | },
807 | "7a901f447c1d477bb49f954e0feacedd": {
808 | "model_module": "@jupyter-widgets/base",
809 | "model_module_version": "1.2.0",
810 | "model_name": "LayoutModel",
811 | "state": {
812 | "_model_module": "@jupyter-widgets/base",
813 | "_model_module_version": "1.2.0",
814 | "_model_name": "LayoutModel",
815 | "_view_count": null,
816 | "_view_module": "@jupyter-widgets/base",
817 | "_view_module_version": "1.2.0",
818 | "_view_name": "LayoutView",
819 | "align_content": null,
820 | "align_items": null,
821 | "align_self": null,
822 | "border": null,
823 | "bottom": null,
824 | "display": null,
825 | "flex": null,
826 | "flex_flow": null,
827 | "grid_area": null,
828 | "grid_auto_columns": null,
829 | "grid_auto_flow": null,
830 | "grid_auto_rows": null,
831 | "grid_column": null,
832 | "grid_gap": null,
833 | "grid_row": null,
834 | "grid_template_areas": null,
835 | "grid_template_columns": null,
836 | "grid_template_rows": null,
837 | "height": null,
838 | "justify_content": null,
839 | "justify_items": null,
840 | "left": null,
841 | "margin": null,
842 | "max_height": null,
843 | "max_width": null,
844 | "min_height": null,
845 | "min_width": null,
846 | "object_fit": null,
847 | "object_position": null,
848 | "order": null,
849 | "overflow": null,
850 | "overflow_x": null,
851 | "overflow_y": null,
852 | "padding": null,
853 | "right": null,
854 | "top": null,
855 | "visibility": null,
856 | "width": null
857 | }
858 | },
859 | "83391f98a240490987c397048fc1a0d4": {
860 | "model_module": "@jupyter-widgets/controls",
861 | "model_module_version": "1.5.0",
862 | "model_name": "HTMLModel",
863 | "state": {
864 | "_dom_classes": [],
865 | "_model_module": "@jupyter-widgets/controls",
866 | "_model_module_version": "1.5.0",
867 | "_model_name": "HTMLModel",
868 | "_view_count": null,
869 | "_view_module": "@jupyter-widgets/controls",
870 | "_view_module_version": "1.5.0",
871 | "_view_name": "HTMLView",
872 | "description": "",
873 | "description_tooltip": null,
874 | "layout": "IPY_MODEL_7a901f447c1d477bb49f954e0feacedd",
875 | "placeholder": "",
876 | "style": "IPY_MODEL_39f5a6ae8ba74c8598f9c6d5b8ad2d65",
877 | "value": "100%"
878 | }
879 | },
880 | "a0d10a42c753453283e5219c22239337": {
881 | "model_module": "@jupyter-widgets/base",
882 | "model_module_version": "1.2.0",
883 | "model_name": "LayoutModel",
884 | "state": {
885 | "_model_module": "@jupyter-widgets/base",
886 | "_model_module_version": "1.2.0",
887 | "_model_name": "LayoutModel",
888 | "_view_count": null,
889 | "_view_module": "@jupyter-widgets/base",
890 | "_view_module_version": "1.2.0",
891 | "_view_name": "LayoutView",
892 | "align_content": null,
893 | "align_items": null,
894 | "align_self": null,
895 | "border": null,
896 | "bottom": null,
897 | "display": null,
898 | "flex": null,
899 | "flex_flow": null,
900 | "grid_area": null,
901 | "grid_auto_columns": null,
902 | "grid_auto_flow": null,
903 | "grid_auto_rows": null,
904 | "grid_column": null,
905 | "grid_gap": null,
906 | "grid_row": null,
907 | "grid_template_areas": null,
908 | "grid_template_columns": null,
909 | "grid_template_rows": null,
910 | "height": null,
911 | "justify_content": null,
912 | "justify_items": null,
913 | "left": null,
914 | "margin": null,
915 | "max_height": null,
916 | "max_width": null,
917 | "min_height": null,
918 | "min_width": null,
919 | "object_fit": null,
920 | "object_position": null,
921 | "order": null,
922 | "overflow": null,
923 | "overflow_x": null,
924 | "overflow_y": null,
925 | "padding": null,
926 | "right": null,
927 | "top": null,
928 | "visibility": null,
929 | "width": null
930 | }
931 | },
932 | "da9c231ee67047fb89073c95326b72a5": {
933 | "model_module": "@jupyter-widgets/controls",
934 | "model_module_version": "1.5.0",
935 | "model_name": "HTMLModel",
936 | "state": {
937 | "_dom_classes": [],
938 | "_model_module": "@jupyter-widgets/controls",
939 | "_model_module_version": "1.5.0",
940 | "_model_name": "HTMLModel",
941 | "_view_count": null,
942 | "_view_module": "@jupyter-widgets/controls",
943 | "_view_module_version": "1.5.0",
944 | "_view_name": "HTMLView",
945 | "description": "",
946 | "description_tooltip": null,
947 | "layout": "IPY_MODEL_1b9cecf5b3584fba8258a81d4279a25b",
948 | "placeholder": "",
949 | "style": "IPY_MODEL_039b53f2702c4179af7e0548018d0588",
950 | "value": " 164/164 [05:08<00:00, 1.86s/it]"
951 | }
952 | }
953 | }
954 | }
955 | },
956 | "nbformat": 4,
957 | "nbformat_minor": 1
958 | }
959 |
--------------------------------------------------------------------------------
/notes.txt:
--------------------------------------------------------------------------------
1 |
2 | -------------------
3 | codepath
4 |
5 | transcribe.py -> transcribe
6 | model.decode
7 | decoding.py -> decode
8 | decoding.py -> DecodingTask.run
9 | DecodingTask._get_audio_features (uses encoder only)
10 | --------------- Dichotomy here!
11 | DecodingTask._main_loop (uses decoder only)
12 |
13 | PyTorchInference.logits (uses decoder only)
14 |
15 | Whisper.install_kv_cache_hooks (uses decoder only)
16 |
17 | -----------------------
18 | model loading info
19 |
20 | checkpoint = torch.load(fp, map_location=device)
21 | ...
22 | dims = ModelDimensions(**checkpoint["dims"])
23 | model = Whisper(dims)
24 | model.load_state_dict(checkpoint["model_state_dict"])
25 |
26 | -----------------------
27 | dims, try setting to 0 for encoder / decoder separately:
28 | self.encoder = AudioEncoder(
29 | self.dims.n_mels,
30 | self.dims.n_audio_ctx,
31 | self.dims.n_audio_state,
32 | self.dims.n_audio_head,
33 | self.dims.n_audio_layer,
34 | )
35 | self.decoder = TextDecoder(
36 | self.dims.n_vocab,
37 | self.dims.n_text_ctx,
38 | self.dims.n_text_state,
39 | self.dims.n_text_head,
40 | self.dims.n_text_layer,
41 | )
42 |
43 | state dict keys: ['decoder.positional_embedding', 'encoder.positional_embedding', 'decoder.token_embedding.weight', 'decoder.blocks.0.mlp_ln.weight',
44 | 'decoder.blocks.0.mlp_ln.bias', 'decoder.blocks.0.mlp.0.weight', 'decoder.blocks.0.mlp.0.bias',
45 | 'decoder.blocks.0.mlp.2.weight', 'decoder.blocks.0.mlp.2.bias', 'decoder.blocks.0.attn_ln.weight',
46 | 'decoder.blocks.0.attn_ln.bias', 'decoder.blocks.0.attn.query.weight', 'decoder.blocks.0.attn.query.bias',
47 | 'decoder.blocks.0.attn.key.weight', 'decoder.blocks.0.attn.value.weight', 'decoder.blocks.0.attn.value.bias',
48 | 'decoder.blocks.0.attn.out.weight', 'decoder.blocks.0.attn.out.bias', 'decoder.blocks.0.cross_attn_ln.weight',
49 | 'decoder.blocks.0.cross_attn_ln.bias', 'decoder.blocks.0.cross_attn.query.weight', 'decoder.blocks.0.cross_attn.query.bias',
50 | 'decoder.blocks.0.cross_attn.key.weight', 'decoder.blocks.0.cross_attn.value.weight', 'decoder.blocks.0.cross_attn.value.bias',
51 | 'decoder.blocks.0.cross_attn.out.weight', 'decoder.blocks.0.cross_attn.out.bias', 'decoder.blocks.1.mlp_ln.weight', 'decoder.blocks.1.mlp_ln.bias',
52 | 'decoder.blocks.1.mlp.0.weight', 'decoder.blocks.1.mlp.0.bias', 'decoder.blocks.1.mlp.2.weight', 'decoder.blocks.1.mlp.2.bias',
53 | 'decoder.blocks.1.attn_ln.weight', 'decoder.blocks.1.attn_ln.bias', 'decoder.blocks.1.attn.query.weight',
54 | 'decoder.blocks.1.attn.query.bias', 'decoder.blocks.1.attn.key.weight', 'decoder.blocks.1.attn.value.weight',
55 | 'decoder.blocks.1.attn.value.bias', 'decoder.blocks.1.attn.out.weight', 'decoder.blocks.1.attn.out.bias', 'decoder.blocks.1.cross_attn_ln.weight',
56 | 'decoder.blocks.1.cross_attn_ln.bias', 'decoder.blocks.1.cross_attn.query.weight', 'decoder.blocks.1.cross_attn.query.bias',
57 | 'decoder.blocks.1.cross_attn.key.weight', 'decoder.blocks.1.cross_attn.value.weight', 'decoder.blocks.1.cross_attn.value.bias',
58 | 'decoder.blocks.1.cross_attn.out.weight', 'decoder.blocks.1.cross_attn.out.bias', 'decoder.blocks.2.mlp_ln.weight',
59 | 'decoder.blocks.2.mlp_ln.bias', 'decoder.blocks.2.mlp.0.weight', 'decoder.blocks.2.mlp.0.bias', 'decoder.blocks.2.mlp.2.weight',
60 | 'decoder.blocks.2.mlp.2.bias', 'decoder.blocks.2.attn_ln.weight', 'decoder.blocks.2.attn_ln.bias', 'decoder.blocks.2.attn.query.weight',
61 | 'decoder.blocks.2.attn.query.bias', 'decoder.blocks.2.attn.key.weight', 'decoder.blocks.2.attn.value.weight', 'decoder.blocks.2.attn.value.bias',
62 | 'decoder.blocks.2.attn.out.weight', 'decoder.blocks.2.attn.out.bias', 'decoder.blocks.2.cross_attn_ln.weight', 'decoder.blocks.2.cross_attn_ln.bias',
63 | 'decoder.blocks.2.cross_attn.query.weight', 'decoder.blocks.2.cross_attn.query.bias', 'decoder.blocks.2.cross_attn.key.weight', 'decoder.blocks.2.cross_attn.value.weight',
64 | 'decoder.blocks.2.cross_attn.value.bias', 'decoder.blocks.2.cross_attn.out.weight', 'decoder.blocks.2.cross_attn.out.bias', 'decoder.blocks.3.mlp_ln.weight',
65 | 'decoder.blocks.3.mlp_ln.bias', 'decoder.blocks.3.mlp.0.weight', 'decoder.blocks.3.mlp.0.bias', 'decoder.blocks.3.mlp.2.weight', 'decoder.blocks.3.mlp.2.bias',
66 | 'decoder.blocks.3.attn_ln.weight', 'decoder.blocks.3.attn_ln.bias', 'decoder.blocks.3.attn.query.weight', 'decoder.blocks.3.attn.query.bias',
67 | 'decoder.blocks.3.attn.key.weight', 'decoder.blocks.3.attn.value.weight', 'decoder.blocks.3.attn.value.bias', 'decoder.blocks.3.attn.out.weight',
68 | 'decoder.blocks.3.attn.out.bias', 'decoder.blocks.3.cross_attn_ln.weight', 'decoder.blocks.3.cross_attn_ln.bias', 'decoder.blocks.3.cross_attn.query.weight',
69 | 'decoder.blocks.3.cross_attn.query.bias', 'decoder.blocks.3.cross_attn.key.weight', 'decoder.blocks.3.cross_attn.value.weight', 'decoder.blocks.3.cross_attn.value.bias',
70 | 'decoder.blocks.3.cross_attn.out.weight', 'decoder.blocks.3.cross_attn.out.bias', 'decoder.blocks.4.mlp_ln.weight', 'decoder.blocks.4.mlp_ln.bias', 'decoder.blocks.4.mlp.0.weight',
71 |
72 |
73 |
74 |
75 |
76 |
77 |
78 |
79 |
80 |
81 |
82 |
83 |
84 |
85 |
86 |
87 |
88 |
89 |
90 |
91 | ----------------------- archived info
92 | Decoder functions:
93 | inference: PyTorchInference
94 | inference.logits
95 |
96 | decoding.py -> decode (= model.decode in transcribe.py)
97 | model.logits
98 |
99 | Encoder functions:
100 | model.embed_audio
101 | model._get_audio_features
102 | decoding.py -> detect_language
103 |
104 | -------------------
105 |
106 | model.forward uses both encode and decode (but is unused).
107 |
108 |
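109 | ----------------------- usage sketch for the dichotomy above
110 | rough sketch, untested; model name and audio path are placeholders:
111 |
112 | import whisper
113 |
114 | model = whisper.load_model("tiny")
115 | audio = whisper.pad_or_trim(whisper.load_audio("tests/jfk.flac"))
116 | mel = whisper.log_mel_spectrogram(audio).unsqueeze(0)  # (1, 80, 3000)
117 |
118 | # encoder-only half (DecodingTask._get_audio_features / model.embed_audio)
119 | audio_features = model.embed_audio(mel)
120 |
121 | # decoder-only half (DecodingTask._main_loop); decode() skips the encoder when it is
122 | # handed features shaped (n_audio_ctx, n_audio_state) instead of a mel spectrogram
123 | opts = whisper.DecodingOptions(language="en", fp16=False)
124 | results = whisper.decode(model, audio_features, opts)
125 | print(results[0].text)  # decode() returns a list for batched input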
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.black]
2 |
3 | [tool.isort]
4 | profile = "black"
5 | include_trailing_comma = true
6 | line_length = 88
7 | multi_line_output = 3
8 |
9 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numba
2 | numpy
3 | torch
4 | tqdm
5 | more-itertools
6 | tiktoken==0.3.1
7 | ffmpeg-python==0.2.0
8 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import os
2 | import platform
3 | import sys
4 |
5 | import pkg_resources
6 | from setuptools import find_packages, setup
7 |
8 |
9 | def read_version(fname="whisper/version.py"):
10 | exec(compile(open(fname, encoding="utf-8").read(), fname, "exec"))
11 | return locals()["__version__"]
12 |
13 |
14 | requirements = []
15 | if sys.platform.startswith("linux") and platform.machine() == "x86_64":
16 | requirements.append("triton==2.0.0")
17 |
18 | setup(
19 | name="openai-whisper",
20 | py_modules=["whisper"],
21 | version=read_version(),
22 | description="Robust Speech Recognition via Large-Scale Weak Supervision",
23 | long_description=open("README.md", encoding="utf-8").read(),
24 | long_description_content_type="text/markdown",
25 | readme="README.md",
26 | python_requires=">=3.8",
27 | author="OpenAI",
28 | url="https://github.com/openai/whisper",
29 | license="MIT",
30 | packages=find_packages(exclude=["tests*"]),
31 | install_requires=requirements
32 | + [
33 | str(r)
34 | for r in pkg_resources.parse_requirements(
35 | open(os.path.join(os.path.dirname(__file__), "requirements.txt"))
36 | )
37 | ],
38 | entry_points={
39 | "console_scripts": ["whisper=whisper.transcribe:cli"],
40 | },
41 | include_package_data=True,
42 | extras_require={"dev": ["pytest", "scipy", "black", "flake8", "isort"]},
43 | )
44 |
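45 | # Editor's note (usage sketch, not part of the build): `pip install .` installs the
46 | # `whisper` console script defined above; `pip install .[dev]` also pulls in the
47 | # test/lint tools listed in extras_require.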
--------------------------------------------------------------------------------
/tests/conftest.py:
--------------------------------------------------------------------------------
1 | import random as rand
2 |
3 | import numpy
4 | import pytest
5 |
6 |
7 | def pytest_configure(config):
8 | config.addinivalue_line("markers", "requires_cuda")
9 |
10 |
11 | @pytest.fixture
12 | def random():
13 | rand.seed(42)
14 | numpy.random.seed(42)
15 |
--------------------------------------------------------------------------------
/tests/jfk.flac:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ProjectEGU/whisper-for-low-vram/6175c21f63450a971ee75428e1dc4aeb5d1953b4/tests/jfk.flac
--------------------------------------------------------------------------------
/tests/test_audio.py:
--------------------------------------------------------------------------------
1 | import os.path
2 |
3 | import numpy as np
4 |
5 | from whisper.audio import SAMPLE_RATE, load_audio, log_mel_spectrogram
6 |
7 |
8 | def test_audio():
9 | audio_path = os.path.join(os.path.dirname(__file__), "jfk.flac")
10 | audio = load_audio(audio_path)
11 | assert audio.ndim == 1
12 | assert SAMPLE_RATE * 10 < audio.shape[0] < SAMPLE_RATE * 12
13 | assert 0 < audio.std() < 1
14 |
15 | mel_from_audio = log_mel_spectrogram(audio)
16 | mel_from_file = log_mel_spectrogram(audio_path)
17 |
18 | assert np.allclose(mel_from_audio, mel_from_file)
19 | assert mel_from_audio.max() - mel_from_audio.min() <= 2.0
20 |
--------------------------------------------------------------------------------
/tests/test_normalizer.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from whisper.normalizers import EnglishTextNormalizer
4 | from whisper.normalizers.english import (
5 | EnglishNumberNormalizer,
6 | EnglishSpellingNormalizer,
7 | )
8 |
9 |
10 | @pytest.mark.parametrize("std", [EnglishNumberNormalizer(), EnglishTextNormalizer()])
11 | def test_number_normalizer(std):
12 | assert std("two") == "2"
13 | assert std("thirty one") == "31"
14 | assert std("five twenty four") == "524"
15 | assert std("nineteen ninety nine") == "1999"
16 | assert std("twenty nineteen") == "2019"
17 |
18 | assert std("two point five million") == "2500000"
19 | assert std("four point two billions") == "4200000000s"
20 | assert std("200 thousand") == "200000"
21 | assert std("200 thousand dollars") == "$200000"
22 | assert std("$20 million") == "$20000000"
23 | assert std("€52.4 million") == "€52400000"
24 | assert std("£77 thousands") == "£77000s"
25 |
26 | assert std("two double o eight") == "2008"
27 |
28 | assert std("three thousand twenty nine") == "3029"
29 | assert std("forty three thousand two hundred sixty") == "43260"
30 | assert std("forty three thousand two hundred and sixty") == "43260"
31 |
32 | assert std("nineteen fifties") == "1950s"
33 | assert std("thirty first") == "31st"
34 | assert std("thirty three thousand and three hundred and thirty third") == "33333rd"
35 |
36 | assert std("three billion") == "3000000000"
37 | assert std("millions") == "1000000s"
38 |
39 | assert std("july third twenty twenty") == "july 3rd 2020"
40 | assert std("august twenty sixth twenty twenty one") == "august 26th 2021"
41 | assert std("3 14") == "3 14"
42 | assert std("3.14") == "3.14"
43 | assert std("3 point 2") == "3.2"
44 | assert std("3 point 14") == "3.14"
45 | assert std("fourteen point 4") == "14.4"
46 | assert std("two point two five dollars") == "$2.25"
47 | assert std("two hundred million dollars") == "$200000000"
48 | assert std("$20.1 million") == "$20100000"
49 |
50 | assert std("ninety percent") == "90%"
51 | assert std("seventy six per cent") == "76%"
52 |
53 | assert std("double oh seven") == "007"
54 | assert std("double zero seven") == "007"
55 | assert std("nine one one") == "911"
56 | assert std("nine double one") == "911"
57 | assert std("one triple oh one") == "10001"
58 |
59 | assert std("two thousandth") == "2000th"
60 | assert std("thirty two thousandth") == "32000th"
61 |
62 | assert std("minus 500") == "-500"
63 | assert std("positive twenty thousand") == "+20000"
64 |
65 | assert std("two dollars and seventy cents") == "$2.70"
66 | assert std("3 cents") == "¢3"
67 | assert std("$0.36") == "¢36"
68 | assert std("three euros and sixty five cents") == "€3.65"
69 |
70 | assert std("three and a half million") == "3500000"
71 | assert std("forty eight and a half dollars") == "$48.5"
72 | assert std("b747") == "b 747"
73 | assert std("10 th") == "10th"
74 | assert std("10th") == "10th"
75 |
76 |
77 | def test_spelling_normalizer():
78 | std = EnglishSpellingNormalizer()
79 |
80 | assert std("mobilisation") == "mobilization"
81 | assert std("cancelation") == "cancellation"
82 |
83 |
84 | def test_text_normalizer():
85 | std = EnglishTextNormalizer()
86 | assert std("Let's") == "let us"
87 | assert std("he's like") == "he is like"
88 | assert std("she's been like") == "she has been like"
89 | assert std("10km") == "10 km"
90 | assert std("10mm") == "10 mm"
91 | assert std("RC232") == "rc 232"
92 |
93 | assert (
94 | std("Mr. Park visited Assoc. Prof. Kim Jr.")
95 | == "mister park visited associate professor kim junior"
96 | )
97 |
--------------------------------------------------------------------------------
/tests/test_timing.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pytest
3 | import scipy.ndimage
4 | import torch
5 |
6 | from whisper.timing import dtw_cpu, dtw_cuda, median_filter
7 |
8 | sizes = [
9 | (10, 20),
10 | (32, 16),
11 | (123, 1500),
12 | (234, 189),
13 | ]
14 | shapes = [
15 | (10,),
16 | (1, 15),
17 | (4, 5, 345),
18 | (6, 12, 240, 512),
19 | ]
20 |
21 |
22 | @pytest.mark.parametrize("N, M", sizes)
23 | def test_dtw(N: int, M: int):
24 | steps = np.concatenate([np.zeros(N - 1), np.ones(M - 1)])
25 | np.random.shuffle(steps)
26 | x = np.random.random((N, M)).astype(np.float32)
27 |
28 | i, j, k = 0, 0, 0
29 | trace = []
30 | while True:
31 | x[i, j] -= 1
32 | trace.append((i, j))
33 |
34 | if k == len(steps):
35 | break
36 |
37 | if k + 1 < len(steps) and steps[k] != steps[k + 1]:
38 | i += 1
39 | j += 1
40 | k += 2
41 | continue
42 |
43 | if steps[k] == 0:
44 | i += 1
45 | if steps[k] == 1:
46 | j += 1
47 | k += 1
48 |
49 | trace = np.array(trace).T
50 | dtw_trace = dtw_cpu(x)
51 |
52 | assert np.allclose(trace, dtw_trace)
53 |
54 |
55 | @pytest.mark.requires_cuda
56 | @pytest.mark.parametrize("N, M", sizes)
57 | def test_dtw_cuda_equivalence(N: int, M: int):
58 | x_numpy = np.random.randn(N, M).astype(np.float32)
59 | x_cuda = torch.from_numpy(x_numpy).cuda()
60 |
61 | trace_cpu = dtw_cpu(x_numpy)
62 | trace_cuda = dtw_cuda(x_cuda)
63 |
64 | assert np.allclose(trace_cpu, trace_cuda)
65 |
66 |
67 | @pytest.mark.parametrize("shape", shapes)
68 | def test_median_filter(shape):
69 | x = torch.randn(*shape)
70 |
71 | for filter_width in [3, 5, 7, 13]:
72 | filtered = median_filter(x, filter_width)
73 |
74 | # using np.pad to reflect-pad, because Scipy's behavior is different near the edges.
75 | pad_width = filter_width // 2
76 | padded_x = np.pad(
77 | x, [(0, 0)] * (x.ndim - 1) + [(pad_width, pad_width)], mode="reflect"
78 | )
79 | scipy_filtered = scipy.ndimage.median_filter(
80 | padded_x, [1] * (x.ndim - 1) + [filter_width]
81 | )
82 | scipy_filtered = scipy_filtered[..., pad_width:-pad_width]
83 |
84 | assert np.allclose(filtered, scipy_filtered)
85 |
86 |
87 | @pytest.mark.requires_cuda
88 | @pytest.mark.parametrize("shape", shapes)
89 | def test_median_filter_equivalence(shape):
90 | x = torch.randn(*shape)
91 |
92 | for filter_width in [3, 5, 7, 13]:
93 | filtered_cpu = median_filter(x, filter_width)
94 | filtered_gpu = median_filter(x.cuda(), filter_width).cpu()
95 |
96 | assert np.allclose(filtered_cpu, filtered_gpu)
97 |
--------------------------------------------------------------------------------
/tests/test_tokenizer.py:
--------------------------------------------------------------------------------
1 | from whisper.tokenizer import get_tokenizer
2 |
3 |
4 | def test_tokenizer():
5 | gpt2_tokenizer = get_tokenizer(multilingual=False)
6 | multilingual_tokenizer = get_tokenizer(multilingual=True)
7 |
8 | text = "다람쥐 헌 쳇바퀴에 타고파"
9 | gpt2_tokens = gpt2_tokenizer.encode(text)
10 | multilingual_tokens = multilingual_tokenizer.encode(text)
11 |
12 | assert gpt2_tokenizer.decode(gpt2_tokens) == text
13 | assert multilingual_tokenizer.decode(multilingual_tokens) == text
14 | assert len(gpt2_tokens) > len(multilingual_tokens)
15 |
16 |
17 | def test_split_on_unicode():
18 | multilingual_tokenizer = get_tokenizer(multilingual=True)
19 |
20 | tokens = [8404, 871, 287, 6, 246, 526, 3210, 20378]
21 | words, word_tokens = multilingual_tokenizer.split_tokens_on_unicode(tokens)
22 |
23 | assert words == [" elle", " est", " l", "'", "�", "é", "rit", "oire"]
24 | assert word_tokens == [[8404], [871], [287], [6], [246], [526], [3210], [20378]]
25 |
--------------------------------------------------------------------------------
/tests/test_transcribe.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import pytest
4 | import torch
5 | import whisper
6 | from whisper.tokenizer import get_tokenizer
7 |
8 |
9 |
10 | @pytest.mark.parametrize("model_name", whisper.available_models())
11 | def test_transcribe(model_name: str):
12 | device = "cuda" if torch.cuda.is_available() else "cpu"
13 | model = whisper.load_model(model_name).to(device)
14 |
15 | audio_path = os.path.join(os.path.dirname(__file__), "jfk.flac")
16 |
17 | language = "en" if model_name.endswith(".en") else None
18 | result = model.transcribe(
19 | audio_path, language=language, temperature=0.0, word_timestamps=True
20 | )
21 | assert result["language"] == "en"
22 | assert result["text"] == "".join([s["text"] for s in result["segments"]])
23 |
24 | transcription = result["text"].lower()
25 | assert "my fellow americans" in transcription
26 | assert "your country" in transcription
27 | assert "do for you" in transcription
28 |
29 | tokenizer = get_tokenizer(model.is_multilingual)
30 | all_tokens = [t for s in result["segments"] for t in s["tokens"]]
31 | assert tokenizer.decode(all_tokens) == result["text"]
32 | assert tokenizer.decode_with_timestamps(all_tokens).startswith("<|0.00|>")
33 |
34 | timing_checked = False
35 | for segment in result["segments"]:
36 | for timing in segment["words"]:
37 | assert timing["start"] < timing["end"]
38 | if timing["word"].strip(" ,") == "Americans":
39 | assert timing["start"] <= 1.8
40 | assert timing["end"] >= 1.8
41 | timing_checked = True
42 |
43 | assert timing_checked
44 |
--------------------------------------------------------------------------------
/whisper/__init__.py:
--------------------------------------------------------------------------------
1 | import hashlib
2 | import io
3 | import os
4 | import urllib
5 | import warnings
6 | from typing import List, Optional, Union
7 |
8 | import torch
9 | from tqdm import tqdm
10 |
11 | from .audio import load_audio, log_mel_spectrogram, pad_or_trim
12 | from .decoding import DecodingOptions, DecodingResult, decode, detect_language
13 | from .model import ModelDimensions, Whisper
14 | from .transcribe import transcribe
15 | from .version import __version__
16 |
17 | _MODELS = {
18 | "tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
19 | "tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",
20 | "base.en": "https://openaipublic.azureedge.net/main/whisper/models/25a8566e1d0c1e2231d1c762132cd20e0f96a85d16145c3a00adf5d1ac670ead/base.en.pt",
21 | "base": "https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt",
22 | "small.en": "https://openaipublic.azureedge.net/main/whisper/models/f953ad0fd29cacd07d5a9eda5624af0f6bcf2258be67c92b79389873d91e0872/small.en.pt",
23 | "small": "https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt",
24 | "medium.en": "https://openaipublic.azureedge.net/main/whisper/models/d7440d1dc186f76616474e0ff0b3b6b879abc9d1a4926b7adfa41db2d497ab4f/medium.en.pt",
25 | "medium": "https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt",
26 | "large-v1": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large-v1.pt",
27 | "large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
28 | "large": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
29 | }
30 |
31 | # base85-encoded (n_layers, n_heads) boolean arrays indicating the cross-attention heads that are
32 | # highly correlated to the word-level timing, i.e. the alignment between audio and text tokens.
33 | _ALIGNMENT_HEADS = {
34 | "tiny.en": b"ABzY8J1N>@0{>%R00Bk>$p{7v037`oCl~+#00",
35 | "tiny": b"ABzY8bu8Lr0{>%RKn9Fp%m@SkK7Kt=7ytkO",
36 | "base.en": b"ABzY8;40c<0{>%RzzG;p*o+Vo09|#PsxSZm00",
37 |     "base": b"ABzY8KQ!870{>%RzyTQH3`Q^yNP!>##QT-<uideIjZ;?C0",
38 |     "small.en": b"ABzY8>?_)10{>%RpeA61k&I|OI3I$65C{;;pbCHh0B{qLQ;+}v00",
39 |     "small": b"ABzY8DmU6=0{>%Rpa?J`kvJ6qF(V^F86#Xh7JUGMK}P<N0000",
40 |     "medium.en": b"ABzY8usPae0{>%R7<zz_OvQ{)4kMa0BMw6u5rT}kRKX;$NfYBv00*Hl@qhsU00",
41 |     "medium": b"ABzY8B0Jh+0{>%R7}kK1fFL7w6%<-Pf*t^=N)Qr&0RR9",
42 |     "large-v1": b"ABzY8r9j$a0{>%R7#4sLmoOs{s)o3~84-RPdcFk!JR<kSfC2yj",
43 |     "large-v2": b"ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj",
44 | "large": b"ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj",
45 | }
46 |
47 |
48 | def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
49 | os.makedirs(root, exist_ok=True)
50 |
51 | expected_sha256 = url.split("/")[-2]
52 | download_target = os.path.join(root, os.path.basename(url))
53 |
54 | if os.path.exists(download_target) and not os.path.isfile(download_target):
55 | raise RuntimeError(f"{download_target} exists and is not a regular file")
56 |
57 | if os.path.isfile(download_target):
58 | with open(download_target, "rb") as f:
59 | model_bytes = f.read()
60 | if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
61 | return model_bytes if in_memory else download_target
62 | else:
63 | warnings.warn(
64 | f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
65 | )
66 |
67 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
68 | with tqdm(
69 | total=int(source.info().get("Content-Length")),
70 | ncols=80,
71 | unit="iB",
72 | unit_scale=True,
73 | unit_divisor=1024,
74 | ) as loop:
75 | while True:
76 | buffer = source.read(8192)
77 | if not buffer:
78 | break
79 |
80 | output.write(buffer)
81 | loop.update(len(buffer))
82 |
83 | model_bytes = open(download_target, "rb").read()
84 | if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
85 | raise RuntimeError(
86 |             "Model has been downloaded but the SHA256 checksum does not match. Please retry loading the model."
87 | )
88 |
89 | return model_bytes if in_memory else download_target
90 |
91 |
92 | def available_models() -> List[str]:
93 | """Returns the names of available models"""
94 | return list(_MODELS.keys())
95 |
96 |
97 | def load_model(
98 | name: str,
99 | device: Optional[Union[str, torch.device]] = None,
100 | download_root: str = None,
101 | in_memory: bool = False,
102 | ) -> Whisper:
103 | """
104 | Load a Whisper ASR model
105 |
106 | Parameters
107 | ----------
108 | name : str
109 | one of the official model names listed by `whisper.available_models()`, or
110 | path to a model checkpoint containing the model dimensions and the model state_dict.
111 | device : Union[str, torch.device]
112 | the PyTorch device to put the model into
113 | download_root: str
114 | path to download the model files; by default, it uses "~/.cache/whisper"
115 | in_memory: bool
116 | whether to preload the model weights into host memory
117 |
118 | Returns
119 | -------
120 | model : Whisper
121 | The Whisper ASR model instance
122 | """
123 |
124 | if device is None:
125 | device = "cuda" if torch.cuda.is_available() else "cpu"
126 | if download_root is None:
127 | default = os.path.join(os.path.expanduser("~"), ".cache")
128 | download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper")
129 |
130 | if name in _MODELS:
131 | checkpoint_file = _download(_MODELS[name], download_root, in_memory)
132 | alignment_heads = _ALIGNMENT_HEADS[name]
133 | elif os.path.isfile(name):
134 | checkpoint_file = open(name, "rb").read() if in_memory else name
135 | alignment_heads = None
136 | else:
137 | raise RuntimeError(
138 | f"Model {name} not found; available models = {available_models()}"
139 | )
140 |
141 | with (
142 | io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
143 | ) as fp:
144 | checkpoint = torch.load(fp, map_location=device)
145 | del checkpoint_file
146 |
147 | dims = ModelDimensions(**checkpoint["dims"])
148 | model = Whisper(dims)
149 | model.load_state_dict(checkpoint["model_state_dict"])
150 |
151 | if alignment_heads is not None:
152 | model.set_alignment_heads(alignment_heads)
153 |
154 | return model.to(device)
155 |
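156 | # Usage sketch (editor's note; model name and audio path are placeholders):
157 | #   model = load_model("tiny")                    # cached under ~/.cache/whisper by default
158 | #   result = model.transcribe("tests/jfk.flac")   # see whisper/transcribe.py
159 | #   print(result["text"])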
--------------------------------------------------------------------------------
/whisper/__main__.py:
--------------------------------------------------------------------------------
1 | from .transcribe import cli
2 |
3 | cli()
4 |
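5 | # Invocation sketch (editor's note; audio path and flags are illustrative):
6 | #   python -m whisper tests/jfk.flac --model tiny
7 | # This is equivalent to the `whisper` console script installed by setup.py.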
--------------------------------------------------------------------------------
/whisper/assets/mel_filters.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ProjectEGU/whisper-for-low-vram/6175c21f63450a971ee75428e1dc4aeb5d1953b4/whisper/assets/mel_filters.npz
--------------------------------------------------------------------------------
/whisper/audio.py:
--------------------------------------------------------------------------------
1 | import os
2 | from functools import lru_cache
3 | from typing import Optional, Union
4 |
5 | import ffmpeg
6 | import numpy as np
7 | import torch
8 | import torch.nn.functional as F
9 |
10 | from .utils import exact_div
11 |
12 | # hard-coded audio hyperparameters
13 | SAMPLE_RATE = 16000
14 | N_FFT = 400
15 | N_MELS = 80
16 | HOP_LENGTH = 160
17 | CHUNK_LENGTH = 30
18 | N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000 samples in a 30-second chunk
19 | N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH) # 3000 frames in a mel spectrogram input
20 |
21 | N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2  # the initial convolutions have stride 2
22 | FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH) # 10ms per audio frame
23 | TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN) # 20ms per audio token
24 |
25 |
26 | def load_audio(file: str, sr: int = SAMPLE_RATE):
27 | """
28 | Open an audio file and read as mono waveform, resampling as necessary
29 |
30 | Parameters
31 | ----------
32 | file: str
33 | The audio file to open
34 |
35 | sr: int
36 | The sample rate to resample the audio if necessary
37 |
38 | Returns
39 | -------
40 | A NumPy array containing the audio waveform, in float32 dtype.
41 | """
42 | try:
43 | # This launches a subprocess to decode audio while down-mixing and resampling as necessary.
44 | # Requires the ffmpeg CLI and `ffmpeg-python` package to be installed.
45 | out, _ = (
46 | ffmpeg.input(file, threads=0)
47 | .output("-", format="s16le", acodec="pcm_s16le", ac=1, ar=sr)
48 | .run(cmd=["ffmpeg", "-nostdin"], capture_stdout=True, capture_stderr=True)
49 | )
50 | except ffmpeg.Error as e:
51 | raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
52 |
53 | return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
54 |
55 |
56 | def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
57 | """
58 | Pad or trim the audio array to N_SAMPLES, as expected by the encoder.
59 | """
60 | if torch.is_tensor(array):
61 | if array.shape[axis] > length:
62 | array = array.index_select(
63 | dim=axis, index=torch.arange(length, device=array.device)
64 | )
65 |
66 | if array.shape[axis] < length:
67 | pad_widths = [(0, 0)] * array.ndim
68 | pad_widths[axis] = (0, length - array.shape[axis])
69 | array = F.pad(array, [pad for sizes in pad_widths[::-1] for pad in sizes])
70 | else:
71 | if array.shape[axis] > length:
72 | array = array.take(indices=range(length), axis=axis)
73 |
74 | if array.shape[axis] < length:
75 | pad_widths = [(0, 0)] * array.ndim
76 | pad_widths[axis] = (0, length - array.shape[axis])
77 | array = np.pad(array, pad_widths)
78 |
79 | return array
80 |
81 |
82 | @lru_cache(maxsize=None)
83 | def mel_filters(device, n_mels: int = N_MELS) -> torch.Tensor:
84 | """
85 | load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
86 |     This allows decoupling the librosa dependency; the filterbank was saved using:
87 |
88 | np.savez_compressed(
89 | "mel_filters.npz",
90 | mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
91 | )
92 | """
93 | assert n_mels == 80, f"Unsupported n_mels: {n_mels}"
94 | with np.load(
95 | os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")
96 | ) as f:
97 | return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)
98 |
99 |
100 | def log_mel_spectrogram(
101 | audio: Union[str, np.ndarray, torch.Tensor],
102 | n_mels: int = N_MELS,
103 | padding: int = 0,
104 | device: Optional[Union[str, torch.device]] = None,
105 | ):
106 | """
107 |     Compute the log-Mel spectrogram of the given audio
108 |
109 | Parameters
110 | ----------
111 | audio: Union[str, np.ndarray, torch.Tensor], shape = (*)
112 |         The path to the audio file, or a NumPy array or Tensor containing the audio waveform at 16 kHz
113 |
114 | n_mels: int
115 | The number of Mel-frequency filters, only 80 is supported
116 |
117 | padding: int
118 | Number of zero samples to pad to the right
119 |
120 | device: Optional[Union[str, torch.device]]
121 | If given, the audio tensor is moved to this device before STFT
122 |
123 | Returns
124 | -------
125 | torch.Tensor, shape = (80, n_frames)
126 | A Tensor that contains the Mel spectrogram
127 | """
128 | if not torch.is_tensor(audio):
129 | if isinstance(audio, str):
130 | audio = load_audio(audio)
131 | audio = torch.from_numpy(audio)
132 |
133 | if device is not None:
134 | audio = audio.to(device)
135 | if padding > 0:
136 | audio = F.pad(audio, (0, padding))
137 | window = torch.hann_window(N_FFT).to(audio.device)
138 | stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)
139 | magnitudes = stft[..., :-1].abs() ** 2
140 |
141 | filters = mel_filters(audio.device, n_mels)
142 | mel_spec = filters @ magnitudes
143 |
144 | log_spec = torch.clamp(mel_spec, min=1e-10).log10()
145 | log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
146 | log_spec = (log_spec + 4.0) / 4.0
147 | return log_spec
148 |
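149 | # Usage sketch (editor's note; the file name is a placeholder):
150 | #   mel = log_mel_spectrogram(pad_or_trim(load_audio("audio.wav")))
151 | # This yields an (80, 3000) log-Mel tensor covering one 30-second window; the
152 | # normalization above keeps its value range within 2.0 (see tests/test_audio.py).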
--------------------------------------------------------------------------------
/whisper/decoding.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field, replace
2 | from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Tuple, Union
3 |
4 | import numpy as np
5 | import torch
6 | import torch.nn.functional as F
7 | from torch import Tensor
8 | from torch.distributions import Categorical
9 |
10 | from .audio import CHUNK_LENGTH
11 | from .tokenizer import Tokenizer, get_tokenizer
12 | from .utils import compression_ratio
13 |
14 | if TYPE_CHECKING:
15 | from .model import Whisper
16 |
17 |
18 | @torch.no_grad()
19 | def detect_language(
20 | model: "Whisper", mel: Tensor, tokenizer: Tokenizer = None
21 | ) -> Tuple[Tensor, List[dict]]:
22 | """
23 | Detect the spoken language in the audio, and return them as list of strings, along with the ids
24 | of the most probable language tokens and the probability distribution over all language tokens.
25 | This is performed outside the main decode loop in order to not interfere with kv-caching.
26 |
27 | Returns
28 | -------
29 | language_tokens : Tensor, shape = (n_audio,)
30 |         ids of the most probable language tokens, which appear after the startoftranscript token.
31 | language_probs : List[Dict[str, float]], length = n_audio
32 | list of dictionaries containing the probability distribution over all languages.
33 | """
34 | if tokenizer is None:
35 | tokenizer = get_tokenizer(model.is_multilingual)
36 | if (
37 | tokenizer.language is None
38 | or tokenizer.language_token not in tokenizer.sot_sequence
39 | ):
40 | raise ValueError(
41 | "This model doesn't have language tokens so it can't perform lang id"
42 | )
43 |
44 | single = mel.ndim == 2
45 | if single:
46 | mel = mel.unsqueeze(0)
47 |
48 | # skip encoder forward pass if already-encoded audio features were given
49 | if mel.shape[-2:] != (model.dims.n_audio_ctx, model.dims.n_audio_state):
50 | mel = model.encoder(mel)
51 |
52 | # forward pass using a single token, startoftranscript
53 | n_audio = mel.shape[0]
54 | x = torch.tensor([[tokenizer.sot]] * n_audio).to(mel.device) # [n_audio, 1]
55 | logits = model.logits(x, mel)[:, 0]
56 |
57 | # collect detected languages; suppress all non-language tokens
58 | mask = torch.ones(logits.shape[-1], dtype=torch.bool)
59 | mask[list(tokenizer.all_language_tokens)] = False
60 | logits[:, mask] = -np.inf
61 | language_tokens = logits.argmax(dim=-1)
62 | language_token_probs = logits.softmax(dim=-1).cpu()
63 | language_probs = [
64 | {
65 | c: language_token_probs[i, j].item()
66 | for j, c in zip(tokenizer.all_language_tokens, tokenizer.all_language_codes)
67 | }
68 | for i in range(n_audio)
69 | ]
70 |
71 | if single:
72 | language_tokens = language_tokens[0]
73 | language_probs = language_probs[0]
74 |
75 | return language_tokens, language_probs
76 |
77 |
78 | @dataclass(frozen=True)
79 | class DecodingOptions:
80 | # whether to perform X->X "transcribe" or X->English "translate"
81 | task: str = "transcribe"
82 |
83 | # language that the audio is in; uses detected language if None
84 | language: Optional[str] = None
85 |
86 | # sampling-related options
87 | temperature: float = 0.0
88 | sample_len: Optional[int] = None # maximum number of tokens to sample
89 | best_of: Optional[int] = None # number of independent sample trajectories, if t > 0
90 | beam_size: Optional[int] = None # number of beams in beam search, if t == 0
91 | patience: Optional[float] = None # patience in beam search (arxiv:2204.05424)
92 |
93 | # "alpha" in Google NMT, or None for length norm, when ranking generations
94 | # to select which to return among the beams or best-of-N samples
95 | length_penalty: Optional[float] = None
96 |
97 | # text or tokens to feed as the prompt or the prefix; for more info:
98 | # https://github.com/openai/whisper/discussions/117#discussioncomment-3727051
99 | prompt: Optional[Union[str, List[int]]] = None # for the previous context
100 | prefix: Optional[Union[str, List[int]]] = None # to prefix the current context
101 |
102 |     # list of token ids (or comma-separated token ids) to suppress
103 | # "-1" will suppress a set of symbols as defined in `tokenizer.non_speech_tokens()`
104 | suppress_tokens: Optional[Union[str, Iterable[int]]] = "-1"
105 | suppress_blank: bool = True # this will suppress blank outputs
106 |
107 | # timestamp sampling options
108 | without_timestamps: bool = False # use <|notimestamps|> to sample text tokens only
109 | max_initial_timestamp: Optional[float] = 1.0
110 |
111 | # implementation details
112 | fp16: bool = True # use fp16 for most of the calculation
113 |
114 |
115 | @dataclass(frozen=True)
116 | class DecodingResult:
117 | audio_features: Tensor
118 | language: str
119 | language_probs: Optional[Dict[str, float]] = None
120 | tokens: List[int] = field(default_factory=list)
121 | text: str = ""
122 | avg_logprob: float = np.nan
123 | no_speech_prob: float = np.nan
124 | temperature: float = np.nan
125 | compression_ratio: float = np.nan
126 |
127 |
128 | class Inference:
129 | def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
130 | """Perform a forward pass on the decoder and return per-token logits"""
131 | raise NotImplementedError
132 |
133 | def rearrange_kv_cache(self, source_indices) -> None:
134 | """Update the key-value cache according to the updated beams"""
135 | raise NotImplementedError
136 |
137 | def cleanup_caching(self) -> None:
138 | """Clean up any resources or hooks after decoding is finished"""
139 | pass
140 |
141 |
142 | class PyTorchInference(Inference):
143 | def __init__(self, model: "Whisper", initial_token_length: int):
144 | self.model: "Whisper" = model
145 | self.initial_token_length = initial_token_length
146 | self.kv_cache = {}
147 | self.hooks = []
148 |
149 | def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
150 | if not self.kv_cache:
151 | self.kv_cache, self.hooks = self.model.install_kv_cache_hooks()
152 |
153 | if tokens.shape[-1] > self.initial_token_length:
154 | # only need to use the last token except in the first forward pass
155 | tokens = tokens[:, -1:]
156 |
157 | return self.model.decoder(tokens, audio_features, kv_cache=self.kv_cache)
158 |
159 | def cleanup_caching(self):
160 | for hook in self.hooks:
161 | hook.remove()
162 |
163 | self.kv_cache = {}
164 | self.hooks = []
165 |
166 | def rearrange_kv_cache(self, source_indices):
167 | for module, tensor in self.kv_cache.items():
168 | # update the key/value cache to contain the selected sequences
169 | self.kv_cache[module] = tensor[source_indices].detach()
170 |
171 |
172 | class SequenceRanker:
173 | def rank(
174 | self, tokens: List[List[Tensor]], sum_logprobs: List[List[float]]
175 | ) -> List[int]:
176 | """
177 | Given a list of groups of samples and their cumulative log probabilities,
178 | return the indices of the samples in each group to select as the final result
179 | """
180 | raise NotImplementedError
181 |
182 |
183 | class MaximumLikelihoodRanker(SequenceRanker):
184 | """
185 | Select the sample with the highest log probabilities, penalized using either
186 | a simple length normalization or Google NMT paper's length penalty
187 | """
188 |
189 | def __init__(self, length_penalty: Optional[float]):
190 | self.length_penalty = length_penalty
191 |
192 | def rank(self, tokens: List[List[Tensor]], sum_logprobs: List[List[float]]):
193 | def scores(logprobs, lengths):
194 | result = []
195 | for logprob, length in zip(logprobs, lengths):
196 | if self.length_penalty is None:
197 | penalty = length
198 | else:
199 | # from the Google NMT paper
200 | penalty = ((5 + length) / 6) ** self.length_penalty
201 | result.append(logprob / penalty)
202 | return result
203 |
204 | # get the sequence with the highest score
205 | lengths = [[len(t) for t in s] for s in tokens]
206 | return [np.argmax(scores(p, l)) for p, l in zip(sum_logprobs, lengths)]
207 |
208 |
209 | class TokenDecoder:
210 | def reset(self):
211 | """Initialize any stateful variables for decoding a new sequence"""
212 |
213 | def update(
214 | self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor
215 | ) -> Tuple[Tensor, bool]:
216 | """Specify how to select the next token, based on the current trace and logits
217 |
218 | Parameters
219 | ----------
220 | tokens : Tensor, shape = (n_batch, current_sequence_length)
221 | all tokens in the context so far, including the prefix and sot_sequence tokens
222 |
223 | logits : Tensor, shape = (n_batch, vocab_size)
224 | per-token logits of the probability distribution at the current step
225 |
226 | sum_logprobs : Tensor, shape = (n_batch)
227 | cumulative log probabilities for each sequence
228 |
229 | Returns
230 | -------
231 | tokens : Tensor, shape = (n_batch, current_sequence_length + 1)
232 | the tokens, appended with the selected next token
233 |
234 | completed : bool
235 |             True if all sequences have reached the end of text
236 |
237 | """
238 | raise NotImplementedError
239 |
240 | def finalize(
241 | self, tokens: Tensor, sum_logprobs: Tensor
242 | ) -> Tuple[Sequence[Sequence[Tensor]], List[List[float]]]:
243 | """Finalize search and return the final candidate sequences
244 |
245 | Parameters
246 | ----------
247 | tokens : Tensor, shape = (n_audio, n_group, current_sequence_length)
248 | all tokens in the context so far, including the prefix and sot_sequence
249 |
250 | sum_logprobs : Tensor, shape = (n_audio, n_group)
251 | cumulative log probabilities for each sequence
252 |
253 | Returns
254 | -------
255 | tokens : Sequence[Sequence[Tensor]], length = n_audio
256 | sequence of Tensors containing candidate token sequences, for each audio input
257 |
258 | sum_logprobs : List[List[float]], length = n_audio
259 | sequence of cumulative log probabilities corresponding to the above
260 |
261 | """
262 | raise NotImplementedError
263 |
264 |
265 | class GreedyDecoder(TokenDecoder):
266 | def __init__(self, temperature: float, eot: int):
267 | self.temperature = temperature
268 | self.eot = eot
269 |
270 | def update(
271 | self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor
272 | ) -> Tuple[Tensor, bool]:
273 | if self.temperature == 0:
274 | next_tokens = logits.argmax(dim=-1)
275 | else:
276 | next_tokens = Categorical(logits=logits / self.temperature).sample()
277 |
278 | logprobs = F.log_softmax(logits.float(), dim=-1)
279 | current_logprobs = logprobs[torch.arange(logprobs.shape[0]), next_tokens]
280 | sum_logprobs += current_logprobs * (tokens[:, -1] != self.eot)
281 |
282 | next_tokens[tokens[:, -1] == self.eot] = self.eot
283 | tokens = torch.cat([tokens, next_tokens[:, None]], dim=-1)
284 |
285 | completed = (tokens[:, -1] == self.eot).all()
286 | return tokens, completed
287 |
288 | def finalize(self, tokens: Tensor, sum_logprobs: Tensor):
289 | # make sure each sequence has at least one EOT token at the end
290 | tokens = F.pad(tokens, (0, 1), value=self.eot)
291 | return tokens, sum_logprobs.tolist()
292 |
293 |
294 | class BeamSearchDecoder(TokenDecoder):
295 | def __init__(
296 | self,
297 | beam_size: int,
298 | eot: int,
299 | inference: Inference,
300 | patience: Optional[float] = None,
301 | ):
302 | self.beam_size = beam_size
303 | self.eot = eot
304 | self.inference = inference
305 | self.patience = patience or 1.0
306 | self.max_candidates: int = round(beam_size * self.patience)
307 | self.finished_sequences = None
308 |
309 | assert (
310 | self.max_candidates > 0
311 | ), f"Invalid beam size ({beam_size}) or patience ({patience})"
312 |
313 | def reset(self):
314 | self.finished_sequences = None
315 |
316 | def update(
317 | self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor
318 | ) -> Tuple[Tensor, bool]:
319 | if tokens.shape[0] % self.beam_size != 0:
320 | raise ValueError(f"{tokens.shape}[0] % {self.beam_size} != 0")
321 |
322 | n_audio = tokens.shape[0] // self.beam_size
323 | if self.finished_sequences is None: # for the first update
324 | self.finished_sequences = [{} for _ in range(n_audio)]
325 |
326 | logprobs = F.log_softmax(logits.float(), dim=-1)
327 | next_tokens, source_indices, finished_sequences = [], [], []
328 | for i in range(n_audio):
329 | scores, sources, finished = {}, {}, {}
330 |
331 | # STEP 1: calculate the cumulative log probabilities for possible candidates
332 | for j in range(self.beam_size):
333 | idx = i * self.beam_size + j
334 | prefix = tokens[idx].tolist()
335 | for logprob, token in zip(*logprobs[idx].topk(self.beam_size + 1)):
336 | new_logprob = (sum_logprobs[idx] + logprob).item()
337 | sequence = tuple(prefix + [token.item()])
338 | scores[sequence] = new_logprob
339 | sources[sequence] = idx
340 |
341 | # STEP 2: rank the candidates and keep the top beam_size sequences for each audio
342 | saved = 0
343 | for sequence in sorted(scores, key=scores.get, reverse=True):
344 | if sequence[-1] == self.eot:
345 | finished[sequence] = scores[sequence]
346 | else:
347 | sum_logprobs[len(next_tokens)] = scores[sequence]
348 | next_tokens.append(sequence)
349 | source_indices.append(sources[sequence])
350 |
351 | saved += 1
352 | if saved == self.beam_size:
353 | break
354 |
355 | finished_sequences.append(finished)
356 |
357 | tokens = torch.tensor(next_tokens, device=tokens.device)
358 | self.inference.rearrange_kv_cache(source_indices)
359 |
360 | # add newly finished sequences to self.finished_sequences
361 | assert len(self.finished_sequences) == len(finished_sequences)
362 | for previously_finished, newly_finished in zip(
363 | self.finished_sequences, finished_sequences
364 | ):
365 | for seq in sorted(newly_finished, key=newly_finished.get, reverse=True):
366 | if len(previously_finished) >= self.max_candidates:
367 | break # the candidate list is full
368 | previously_finished[seq] = newly_finished[seq]
369 |
370 |         # mark as completed if every audio has collected enough finished sequences
371 | completed = all(
372 | len(sequences) >= self.max_candidates
373 | for sequences in self.finished_sequences
374 | )
375 | return tokens, completed
376 |
377 | def finalize(self, preceding_tokens: Tensor, sum_logprobs: Tensor):
378 |         # collect all finished sequences, including those kept by patience, and add unfinished ones if there are not enough
379 | sum_logprobs = sum_logprobs.cpu()
380 | for i, sequences in enumerate(self.finished_sequences):
381 | if (
382 | len(sequences) < self.beam_size
383 | ): # when not enough sequences are finished
384 | for j in list(np.argsort(sum_logprobs[i]))[::-1]:
385 | sequence = preceding_tokens[i, j].tolist() + [self.eot]
386 | sequences[tuple(sequence)] = sum_logprobs[i][j].item()
387 | if len(sequences) >= self.beam_size:
388 | break
389 |
390 | tokens: List[List[Tensor]] = [
391 | [torch.tensor(seq) for seq in sequences.keys()]
392 | for sequences in self.finished_sequences
393 | ]
394 | sum_logprobs: List[List[float]] = [
395 | list(sequences.values()) for sequences in self.finished_sequences
396 | ]
397 | return tokens, sum_logprobs
398 |
399 |
400 | class LogitFilter:
401 | def apply(self, logits: Tensor, tokens: Tensor) -> None:
402 | """Apply any filtering or masking to logits in-place
403 |
404 | Parameters
405 | ----------
406 | logits : Tensor, shape = (n_batch, vocab_size)
407 | per-token logits of the probability distribution at the current step
408 |
409 | tokens : Tensor, shape = (n_batch, current_sequence_length)
410 | all tokens in the context so far, including the prefix and sot_sequence tokens
411 |
412 | """
413 | raise NotImplementedError
414 |
415 |
416 | class SuppressBlank(LogitFilter):
417 | def __init__(self, tokenizer: Tokenizer, sample_begin: int):
418 | self.tokenizer = tokenizer
419 | self.sample_begin = sample_begin
420 |
421 | def apply(self, logits: Tensor, tokens: Tensor):
422 | if tokens.shape[1] == self.sample_begin:
423 | logits[:, self.tokenizer.encode(" ") + [self.tokenizer.eot]] = -np.inf
424 |
425 |
426 | class SuppressTokens(LogitFilter):
427 | def __init__(self, suppress_tokens: Sequence[int]):
428 | self.suppress_tokens = list(suppress_tokens)
429 |
430 | def apply(self, logits: Tensor, tokens: Tensor):
431 | logits[:, self.suppress_tokens] = -np.inf
432 |
433 |
434 | class ApplyTimestampRules(LogitFilter):
435 | def __init__(
436 | self,
437 | tokenizer: Tokenizer,
438 | sample_begin: int,
439 | max_initial_timestamp_index: Optional[int],
440 | ):
441 | self.tokenizer = tokenizer
442 | self.sample_begin = sample_begin
443 | self.max_initial_timestamp_index = max_initial_timestamp_index
444 |
445 | def apply(self, logits: Tensor, tokens: Tensor):
446 | # suppress <|notimestamps|> which is handled by without_timestamps
447 | if self.tokenizer.no_timestamps is not None:
448 | logits[:, self.tokenizer.no_timestamps] = -np.inf
449 |
450 | # timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
451 | for k in range(tokens.shape[0]):
452 | sampled_tokens = tokens[k, self.sample_begin :]
453 | seq = [t for t in sampled_tokens.tolist()]
454 | last_was_timestamp = (
455 | len(seq) >= 1 and seq[-1] >= self.tokenizer.timestamp_begin
456 | )
457 | penultimate_was_timestamp = (
458 | len(seq) < 2 or seq[-2] >= self.tokenizer.timestamp_begin
459 | )
460 |
461 | if last_was_timestamp:
462 | if penultimate_was_timestamp: # has to be non-timestamp
463 | logits[k, self.tokenizer.timestamp_begin :] = -np.inf
464 | else: # cannot be normal text tokens
465 | logits[k, : self.tokenizer.eot] = -np.inf
466 |
467 | timestamps = sampled_tokens[
468 | sampled_tokens.ge(self.tokenizer.timestamp_begin)
469 | ]
470 | if timestamps.numel() > 0:
471 | # timestamps shouldn't decrease; forbid timestamp tokens smaller than the last
472 | logits[k, self.tokenizer.timestamp_begin : timestamps[-1]] = -np.inf
473 |
474 | if tokens.shape[1] == self.sample_begin:
475 | # suppress generating non-timestamp tokens at the beginning
476 | logits[:, : self.tokenizer.timestamp_begin] = -np.inf
477 |
478 | # apply the `max_initial_timestamp` option
479 | if self.max_initial_timestamp_index is not None:
480 | last_allowed = (
481 | self.tokenizer.timestamp_begin + self.max_initial_timestamp_index
482 | )
483 | logits[:, last_allowed + 1 :] = -np.inf
484 |
485 | # if sum of probability over timestamps is above any other token, sample timestamp
486 | logprobs = F.log_softmax(logits.float(), dim=-1)
487 | for k in range(tokens.shape[0]):
488 | timestamp_logprob = logprobs[k, self.tokenizer.timestamp_begin :].logsumexp(
489 | dim=-1
490 | )
491 | max_text_token_logprob = logprobs[k, : self.tokenizer.timestamp_begin].max()
492 | if timestamp_logprob > max_text_token_logprob:
493 | logits[k, : self.tokenizer.timestamp_begin] = -np.inf
494 |
495 |
496 | class DecodingTask:
497 | inference: Inference
498 | sequence_ranker: SequenceRanker
499 | decoder: TokenDecoder
500 | logit_filters: List[LogitFilter]
501 |
502 | def __init__(self, model: "Whisper", options: DecodingOptions):
503 | self.model = model
504 |
505 | language = options.language or "en"
506 | tokenizer = get_tokenizer(
507 | model.is_multilingual, language=language, task=options.task
508 | )
509 | self.tokenizer: Tokenizer = tokenizer
510 | self.options: DecodingOptions = self._verify_options(options)
511 |
512 | self.n_group: int = options.beam_size or options.best_of or 1
513 | self.n_ctx: int = model.dims.n_text_ctx
514 | self.sample_len: int = options.sample_len or model.dims.n_text_ctx // 2
515 |
516 | self.sot_sequence: Tuple[int] = tokenizer.sot_sequence
517 | if self.options.without_timestamps:
518 | self.sot_sequence = tokenizer.sot_sequence_including_notimestamps
519 |
520 | self.initial_tokens: Tuple[int] = self._get_initial_tokens()
521 | self.sample_begin: int = len(self.initial_tokens)
522 | self.sot_index: int = self.initial_tokens.index(tokenizer.sot)
523 |
524 | # inference: implements the forward pass through the decoder, including kv caching
525 | self.inference = PyTorchInference(model, len(self.initial_tokens))
526 |
527 | # sequence ranker: implements how to rank a group of sampled sequences
528 | self.sequence_ranker = MaximumLikelihoodRanker(options.length_penalty)
529 |
530 | # decoder: implements how to select the next tokens, given the autoregressive distribution
531 | if options.beam_size is not None:
532 | self.decoder = BeamSearchDecoder(
533 | options.beam_size, tokenizer.eot, self.inference, options.patience
534 | )
535 | else:
536 | self.decoder = GreedyDecoder(options.temperature, tokenizer.eot)
537 |
538 | # logit filters: applies various rules to suppress or penalize certain tokens
539 | self.logit_filters = []
540 | if self.options.suppress_blank:
541 | self.logit_filters.append(SuppressBlank(self.tokenizer, self.sample_begin))
542 | if self.options.suppress_tokens:
543 | self.logit_filters.append(SuppressTokens(self._get_suppress_tokens()))
544 | if not options.without_timestamps:
545 | precision = CHUNK_LENGTH / model.dims.n_audio_ctx # usually 0.02 seconds
546 | max_initial_timestamp_index = None
547 | if options.max_initial_timestamp:
548 | max_initial_timestamp_index = round(
549 | self.options.max_initial_timestamp / precision
550 | )
551 | self.logit_filters.append(
552 | ApplyTimestampRules(
553 | tokenizer, self.sample_begin, max_initial_timestamp_index
554 | )
555 | )
556 |
557 | def _verify_options(self, options: DecodingOptions) -> DecodingOptions:
558 | if options.beam_size is not None and options.best_of is not None:
559 | raise ValueError("beam_size and best_of can't be given together")
560 | if options.temperature == 0:
561 | if options.best_of is not None:
562 | raise ValueError("best_of with greedy sampling (T=0) is not compatible")
563 | if options.patience is not None and options.beam_size is None:
564 | raise ValueError("patience requires beam_size to be given")
565 | if options.length_penalty is not None and not (
566 | 0 <= options.length_penalty <= 1
567 | ):
568 | raise ValueError("length_penalty (alpha) should be a value between 0 and 1")
569 |
570 | return options
571 |
572 | def _get_initial_tokens(self) -> Tuple[int]:
573 | tokens = list(self.sot_sequence)
574 |
575 | if prefix := self.options.prefix:
576 | prefix_tokens = (
577 | self.tokenizer.encode(" " + prefix.strip())
578 | if isinstance(prefix, str)
579 | else prefix
580 | )
581 | if self.sample_len is not None:
582 | max_prefix_len = self.n_ctx // 2 - self.sample_len
583 | prefix_tokens = prefix_tokens[-max_prefix_len:]
584 | tokens = tokens + prefix_tokens
585 |
586 | if prompt := self.options.prompt:
587 | prompt_tokens = (
588 | self.tokenizer.encode(" " + prompt.strip())
589 | if isinstance(prompt, str)
590 | else prompt
591 | )
592 | tokens = (
593 | [self.tokenizer.sot_prev]
594 | + prompt_tokens[-(self.n_ctx // 2 - 1) :]
595 | + tokens
596 | )
597 |
598 | return tuple(tokens)
599 |
600 | def _get_suppress_tokens(self) -> Tuple[int]:
601 | suppress_tokens = self.options.suppress_tokens
602 |
603 | if isinstance(suppress_tokens, str):
604 | suppress_tokens = [int(t) for t in suppress_tokens.split(",")]
605 |
606 | if -1 in suppress_tokens:
607 | suppress_tokens = [t for t in suppress_tokens if t >= 0]
608 | suppress_tokens.extend(self.tokenizer.non_speech_tokens)
609 | elif suppress_tokens is None or len(suppress_tokens) == 0:
610 | suppress_tokens = [] # interpret empty string as an empty list
611 | else:
612 | assert isinstance(suppress_tokens, list), "suppress_tokens must be a list"
613 |
614 | suppress_tokens.extend(
615 | [
616 | self.tokenizer.transcribe,
617 | self.tokenizer.translate,
618 | self.tokenizer.sot,
619 | self.tokenizer.sot_prev,
620 | self.tokenizer.sot_lm,
621 | ]
622 | )
623 | if self.tokenizer.no_speech is not None:
624 | # no-speech probability is collected separately
625 | suppress_tokens.append(self.tokenizer.no_speech)
626 |
627 | return tuple(sorted(set(suppress_tokens)))
628 |
629 | def _get_audio_features(self, mel: Tensor):
630 | if self.options.fp16:
631 | mel = mel.half()
632 |
633 | if mel.shape[-2:] == (
634 | self.model.dims.n_audio_ctx,
635 | self.model.dims.n_audio_state,
636 | ):
637 | # encoded audio features are given; skip audio encoding
638 | audio_features = mel
639 | else:
640 | audio_features = self.model.encoder(mel)
641 |
642 | if audio_features.dtype != (
643 | torch.float16 if self.options.fp16 else torch.float32
644 | ):
645 |             raise TypeError(
646 | f"audio_features has an incorrect dtype: {audio_features.dtype}"
647 | )
648 |
649 | return audio_features
650 |
651 | def _detect_language(self, audio_features: Tensor, tokens: Tensor):
652 | languages = [self.options.language] * audio_features.shape[0]
653 | lang_probs = None
654 |
655 | if self.options.language is None or self.options.task == "lang_id":
656 | lang_tokens, lang_probs = self.model.detect_language(
657 | audio_features, self.tokenizer
658 | )
659 | languages = [max(probs, key=probs.get) for probs in lang_probs]
660 | if self.options.language is None:
661 | tokens[:, self.sot_index + 1] = lang_tokens # write language tokens
662 |
663 | return languages, lang_probs
664 |
665 | def _main_loop(self, audio_features: Tensor, tokens: Tensor):
666 | assert audio_features.shape[0] == tokens.shape[0]
667 | n_batch = tokens.shape[0]
668 | sum_logprobs: Tensor = torch.zeros(n_batch, device=audio_features.device)
669 | no_speech_probs = [np.nan] * n_batch
670 |
671 | try:
672 | for i in range(self.sample_len):
673 | logits = self.inference.logits(tokens, audio_features)
674 |
675 | if (
676 | i == 0 and self.tokenizer.no_speech is not None
677 | ): # save no_speech_probs
678 | probs_at_sot = logits[:, self.sot_index].float().softmax(dim=-1)
679 | no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()
680 |
681 | # now we need to consider the logits at the last token only
682 | logits = logits[:, -1]
683 |
684 |                 # apply the logit filters, e.g. for suppressing or penalizing certain tokens
685 | for logit_filter in self.logit_filters:
686 | logit_filter.apply(logits, tokens)
687 |
688 | # expand the tokens tensor with the selected next tokens
689 | tokens, completed = self.decoder.update(tokens, logits, sum_logprobs)
690 |
691 | if completed or tokens.shape[-1] > self.n_ctx:
692 | break
693 | finally:
694 | self.inference.cleanup_caching()
695 |
696 | return tokens, sum_logprobs, no_speech_probs
697 |
698 | @torch.no_grad()
699 | def run(self, mel: Tensor) -> List[DecodingResult]:
700 | self.decoder.reset()
701 | tokenizer: Tokenizer = self.tokenizer
702 | n_audio: int = mel.shape[0]
703 |
704 | audio_features: Tensor = self._get_audio_features(mel) # encoder forward pass
705 | tokens: Tensor = torch.tensor([self.initial_tokens]).repeat(n_audio, 1)
706 |
707 | # detect language if requested, overwriting the language token
708 | languages, language_probs = self._detect_language(audio_features, tokens)
709 | if self.options.task == "lang_id":
710 | return [
711 | DecodingResult(
712 | audio_features=features, language=language, language_probs=probs
713 | )
714 | for features, language, probs in zip(
715 | audio_features, languages, language_probs
716 | )
717 | ]
718 |
719 | # repeat the audio & text tensors by the group size, for beam search or best-of-n sampling
720 | audio_features = audio_features.repeat_interleave(self.n_group, dim=0)
721 | tokens = tokens.repeat_interleave(self.n_group, dim=0).to(audio_features.device)
722 |
723 | # call the main sampling loop
724 | tokens, sum_logprobs, no_speech_probs = self._main_loop(audio_features, tokens)
725 |
726 | # reshape the tensors to have (n_audio, n_group) as the first two dimensions
727 | audio_features = audio_features[:: self.n_group]
728 | no_speech_probs = no_speech_probs[:: self.n_group]
729 | assert audio_features.shape[0] == len(no_speech_probs) == n_audio
730 |
731 | tokens = tokens.reshape(n_audio, self.n_group, -1)
732 | sum_logprobs = sum_logprobs.reshape(n_audio, self.n_group)
733 |
734 | # get the final candidates for each group, and slice between the first sampled token and EOT
735 | tokens, sum_logprobs = self.decoder.finalize(tokens, sum_logprobs)
736 | tokens: List[List[Tensor]] = [
737 | [t[self.sample_begin : (t == tokenizer.eot).nonzero()[0, 0]] for t in s]
738 | for s in tokens
739 | ]
740 |
741 | # select the top-ranked sample in each group
742 | selected = self.sequence_ranker.rank(tokens, sum_logprobs)
743 | tokens: List[List[int]] = [t[i].tolist() for i, t in zip(selected, tokens)]
744 | texts: List[str] = [tokenizer.decode(t).strip() for t in tokens]
745 |
746 | sum_logprobs: List[float] = [lp[i] for i, lp in zip(selected, sum_logprobs)]
747 | avg_logprobs: List[float] = [
748 | lp / (len(t) + 1) for t, lp in zip(tokens, sum_logprobs)
749 | ]
750 |
751 | fields = (
752 | texts,
753 | languages,
754 | tokens,
755 | audio_features,
756 | avg_logprobs,
757 | no_speech_probs,
758 | )
759 | if len(set(map(len, fields))) != 1:
760 | raise RuntimeError(f"inconsistent result lengths: {list(map(len, fields))}")
761 |
762 | return [
763 | DecodingResult(
764 | audio_features=features,
765 | language=language,
766 | tokens=tokens,
767 | text=text,
768 | avg_logprob=avg_logprob,
769 | no_speech_prob=no_speech_prob,
770 | temperature=self.options.temperature,
771 | compression_ratio=compression_ratio(text),
772 | )
773 | for text, language, tokens, features, avg_logprob, no_speech_prob in zip(
774 | *fields
775 | )
776 | ]
777 |
778 |
779 | @torch.no_grad()
780 | def decode(
781 | model: "Whisper",
782 | mel: Tensor,
783 | options: DecodingOptions = DecodingOptions(),
784 | **kwargs,
785 | ) -> Union[DecodingResult, List[DecodingResult]]:
786 | """
787 | Performs decoding of 30-second audio segment(s), provided as Mel spectrogram(s).
788 |
789 | Parameters
790 | ----------
791 | model: Whisper
792 | the Whisper model instance
793 |
794 | mel: torch.Tensor, shape = (80, 3000) or (*, 80, 3000)
795 | A tensor containing the Mel spectrogram(s)
796 |
797 | options: DecodingOptions
798 | A dataclass that contains all necessary options for decoding 30-second segments
799 |
800 | Returns
801 | -------
802 | result: Union[DecodingResult, List[DecodingResult]]
803 | The result(s) of decoding contained in `DecodingResult` dataclass instance(s)
804 | """
805 | if single := mel.ndim == 2:
806 | mel = mel.unsqueeze(0)
807 |
808 | if kwargs:
809 | options = replace(options, **kwargs)
810 |
811 | result = DecodingTask(model, options).run(mel)
812 |
813 | return result[0] if single else result
814 |
--------------------------------------------------------------------------------
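A minimal usage sketch for the `decode()` entry point above. It assumes the package-level helpers exported by `whisper/__init__.py` and `whisper/audio.py` (`load_model`, `load_audio`, `pad_or_trim`, `log_mel_spectrogram`) and a local file `audio.wav`; the model name, path, and options are illustrative only.

import whisper

model = whisper.load_model("base")

# prepare exactly 30 seconds of audio as an (80, 3000) log-Mel spectrogram
audio = whisper.pad_or_trim(whisper.load_audio("audio.wav"))
mel = whisper.log_mel_spectrogram(audio).to(model.device)

# keyword arguments passed to decode() are merged into the options via replace(options, **kwargs)
options = whisper.DecodingOptions(language="en", without_timestamps=True, fp16=False)
result = whisper.decode(model, mel, options)
print(result.text, result.avg_logprob, result.no_speech_prob)
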
/whisper/model.py:
--------------------------------------------------------------------------------
1 | import base64
2 | import gzip
3 | from dataclasses import dataclass
4 | from typing import Dict, Iterable, Optional
5 |
6 | import numpy as np
7 | import torch
8 | import torch.nn.functional as F
9 | from torch import Tensor, nn
10 |
11 | from .decoding import decode as decode_function
12 | from .decoding import detect_language as detect_language_function
13 | from .transcribe import transcribe as transcribe_function
14 |
15 | import gc, time
16 |
17 | @dataclass
18 | class ModelDimensions:
19 | n_mels: int
20 | n_audio_ctx: int
21 | n_audio_state: int
22 | n_audio_head: int
23 | n_audio_layer: int
24 | n_vocab: int
25 | n_text_ctx: int
26 | n_text_state: int
27 | n_text_head: int
28 | n_text_layer: int
29 |
30 |
31 | class LayerNorm(nn.LayerNorm):
32 | def forward(self, x: Tensor) -> Tensor:
33 | return super().forward(x.float()).type(x.dtype)
34 |
35 |
36 | class Linear(nn.Linear):
37 | def forward(self, x: Tensor) -> Tensor:
38 | return F.linear(
39 | x,
40 | self.weight.to(x.dtype),
41 | None if self.bias is None else self.bias.to(x.dtype),
42 | )
43 |
44 |
45 | class Conv1d(nn.Conv1d):
46 | def _conv_forward(
47 | self, x: Tensor, weight: Tensor, bias: Optional[Tensor]
48 | ) -> Tensor:
49 | return super()._conv_forward(
50 | x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype)
51 | )
52 |
53 |
54 | def sinusoids(length, channels, max_timescale=10000):
55 | """Returns sinusoids for positional embedding"""
56 | assert channels % 2 == 0
57 | log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
58 | inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
59 | scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
60 | return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
61 |
62 |
63 | class MultiHeadAttention(nn.Module):
64 | def __init__(self, n_state: int, n_head: int):
65 | super().__init__()
66 | self.n_head = n_head
67 | self.query = Linear(n_state, n_state)
68 | self.key = Linear(n_state, n_state, bias=False)
69 | self.value = Linear(n_state, n_state)
70 | self.out = Linear(n_state, n_state)
71 |
72 | def forward(
73 | self,
74 | x: Tensor,
75 | xa: Optional[Tensor] = None,
76 | mask: Optional[Tensor] = None,
77 | kv_cache: Optional[dict] = None,
78 | ):
79 | q = self.query(x)
80 |
81 | if kv_cache is None or xa is None or self.key not in kv_cache:
82 | # hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
83 | # otherwise, perform key/value projections for self- or cross-attention as usual.
84 | k = self.key(x if xa is None else xa)
85 | v = self.value(x if xa is None else xa)
86 | else:
87 | # for cross-attention, calculate keys and values once and reuse in subsequent calls.
88 | k = kv_cache[self.key]
89 | v = kv_cache[self.value]
90 |
91 | wv, qk = self.qkv_attention(q, k, v, mask)
92 | return self.out(wv), qk
93 |
94 | def qkv_attention(
95 | self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
96 | ):
97 | n_batch, n_ctx, n_state = q.shape
98 | scale = (n_state // self.n_head) ** -0.25
99 | q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
100 | k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale
101 | v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
102 |
103 | qk = q @ k
104 | if mask is not None:
105 | qk = qk + mask[:n_ctx, :n_ctx]
106 | qk = qk.float()
107 |
108 | w = F.softmax(qk, dim=-1).to(q.dtype)
109 | return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()
110 |
111 |
112 | class ResidualAttentionBlock(nn.Module):
113 | def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
114 | super().__init__()
115 |
116 | self.attn = MultiHeadAttention(n_state, n_head)
117 | self.attn_ln = LayerNorm(n_state)
118 |
119 | self.cross_attn = (
120 | MultiHeadAttention(n_state, n_head) if cross_attention else None
121 | )
122 | self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None
123 |
124 | n_mlp = n_state * 4
125 | self.mlp = nn.Sequential(
126 | Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state)
127 | )
128 | self.mlp_ln = LayerNorm(n_state)
129 |
130 | def forward(
131 | self,
132 | x: Tensor,
133 | xa: Optional[Tensor] = None,
134 | mask: Optional[Tensor] = None,
135 | kv_cache: Optional[dict] = None,
136 | ):
137 | x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0]
138 | if self.cross_attn:
139 | x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0]
140 | x = x + self.mlp(self.mlp_ln(x))
141 | return x
142 |
143 |
144 | class AudioEncoder(nn.Module):
145 | def __init__(
146 | self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
147 | ):
148 | super().__init__()
149 | self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
150 | self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
151 | self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))
152 |
153 | self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
154 | [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
155 | )
156 | self.ln_post = LayerNorm(n_state)
157 |
158 | def forward(self, x: Tensor):
159 | """
160 | x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
161 | the mel spectrogram of the audio
162 | """
163 | x = F.gelu(self.conv1(x))
164 | x = F.gelu(self.conv2(x))
165 | x = x.permute(0, 2, 1)
166 |
167 | assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
168 | x = (x + self.positional_embedding).to(x.dtype)
169 |
170 | for block in self.blocks:
171 | x = block(x)
172 |
173 | x = self.ln_post(x)
174 | return x
175 |
176 |
177 | class TextDecoder(nn.Module):
178 | def __init__(
179 | self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
180 | ):
181 | super().__init__()
182 |
183 | self.token_embedding = nn.Embedding(n_vocab, n_state)
184 | self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))
185 |
186 | self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
187 | [
188 | ResidualAttentionBlock(n_state, n_head, cross_attention=True)
189 | for _ in range(n_layer)
190 | ]
191 | )
192 | self.ln = LayerNorm(n_state)
193 |
194 | mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
195 | self.register_buffer("mask", mask, persistent=False)
196 |
197 | def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
198 | """
199 | x : torch.LongTensor, shape = (batch_size, <= n_ctx)
200 | the text tokens
201 |         xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)
202 | the encoded audio features to be attended on
203 | """
204 | offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
205 | x = (
206 | self.token_embedding(x)
207 | + self.positional_embedding[offset : offset + x.shape[-1]]
208 | )
209 | x = x.to(xa.dtype)
210 |
211 | for block in self.blocks:
212 | x = block(x, xa, mask=self.mask, kv_cache=kv_cache)
213 |
214 | x = self.ln(x)
215 | logits = (
216 | x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
217 | ).float()
218 |
219 | return logits
220 |
221 |
222 | class Whisper(nn.Module):
223 | def __init__(self, dims: ModelDimensions):
224 | super().__init__()
225 | self.dims = dims
226 | self.loaded_model = None
227 | self.loaded_model_obj = None
228 | self.target_device = None
229 | self._encoder = AudioEncoder(
230 | self.dims.n_mels,
231 | self.dims.n_audio_ctx,
232 | self.dims.n_audio_state,
233 | self.dims.n_audio_head,
234 | self.dims.n_audio_layer,
235 | )
236 | self._decoder = TextDecoder(
237 | self.dims.n_vocab,
238 | self.dims.n_text_ctx,
239 | self.dims.n_text_state,
240 | self.dims.n_text_head,
241 | self.dims.n_text_layer,
242 | )
243 |
244 | # use the last half layers for alignment by default; see `set_alignment_heads()` below
245 | all_heads = torch.zeros(
246 | self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool
247 | )
248 | all_heads[self.dims.n_text_layer // 2 :] = True
249 | self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False)
250 |
251 | def set_alignment_heads(self, dump: bytes):
252 | array = np.frombuffer(
253 | gzip.decompress(base64.b85decode(dump)), dtype=bool
254 | ).copy()
255 | mask = torch.from_numpy(array).reshape(
256 | self.dims.n_text_layer, self.dims.n_text_head
257 | )
258 | self.register_buffer("alignment_heads", mask.to_sparse(), persistent=False)
259 |
260 | @property
261 | def encoder(self):
262 | if self.loaded_model is None or self.loaded_model != "encoder":
263 | # print("Load encoder to " + self.target_device)
264 | # start = time.process_time()
265 |
266 |             if self.loaded_model == "decoder":
267 |                 self.loaded_model_obj.to("cpu")
268 | del self.loaded_model_obj
269 | torch.cuda.empty_cache()
270 | gc.collect()
271 | self.loaded_model_obj = self._encoder.to(self.target_device)
272 | self.loaded_model = "encoder"
273 |
274 | # print(time.process_time() - start)
275 | return self.loaded_model_obj
276 |
277 | @property
278 | def decoder(self):
279 | if self.loaded_model is None or self.loaded_model != "decoder":
280 | # print("Load decoder to " + self.target_device)
281 | # start = time.process_time()
282 |
283 |             if self.loaded_model == "encoder":
284 |                 self.loaded_model_obj.to("cpu")
285 | del self.loaded_model_obj
286 | torch.cuda.empty_cache()
287 | gc.collect()
288 | self.loaded_model_obj = self._decoder.to(self.target_device)
289 | self.loaded_model = "decoder"
290 |
291 | # print(time.process_time() - start)
292 | return self.loaded_model_obj
293 |
294 | def load_state_dict(self, state_dict):
295 | for k in list(state_dict.keys()):
296 |             new_key = k.replace("encoder.", "_encoder.").replace("decoder.", "_decoder.")
297 |             state_dict[new_key] = state_dict[k]
298 | del state_dict[k]
299 | return super().load_state_dict(state_dict)
300 |
301 | def embed_audio(self, mel: torch.Tensor):
302 | return self.encoder(mel)
303 |
304 | def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor):
305 | return self.decoder(tokens, audio_features)
306 |
307 | def forward(
308 | self, mel: torch.Tensor, tokens: torch.Tensor
309 | ) -> Dict[str, torch.Tensor]:
310 | return self.decoder(tokens, self.encoder(mel))
311 |
312 | def to(self, device):
313 | self.target_device = device
314 | return self
315 |
316 | @property
317 | def device(self):
318 | return self.target_device
319 |
320 | @property
321 | def is_multilingual(self):
322 | return self.dims.n_vocab == 51865
323 |
324 | def install_kv_cache_hooks(self, cache: Optional[dict] = None):
325 | """
326 | The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value
327 | tensors calculated for the previous positions. This method returns a dictionary that stores
328 | all caches, and the necessary hooks for the key and value projection modules that save the
329 | intermediate tensors to be reused during later calculations.
330 |
331 | Returns
332 | -------
333 | cache : Dict[nn.Module, torch.Tensor]
334 |             A dictionary object mapping each key/value projection module to its cache
335 |         hooks : List[RemovableHandle]
336 |             List of PyTorch RemovableHandle objects that can be used to remove the hooks
337 | """
338 | cache = {**cache} if cache is not None else {}
339 | hooks = []
340 |
341 | def save_to_cache(module, _, output):
342 | if module not in cache or output.shape[1] > self.dims.n_text_ctx:
343 | # save as-is, for the first token or cross attention
344 | cache[module] = output
345 | else:
346 | cache[module] = torch.cat([cache[module], output], dim=1).detach()
347 | return cache[module]
348 |
349 | def install_hooks(layer: nn.Module):
350 | if isinstance(layer, MultiHeadAttention):
351 | hooks.append(layer.key.register_forward_hook(save_to_cache))
352 | hooks.append(layer.value.register_forward_hook(save_to_cache))
353 |
354 | self.decoder.apply(install_hooks)
355 | return cache, hooks
356 |
357 | detect_language = detect_language_function
358 | transcribe = transcribe_function
359 | decode = decode_function
360 |
--------------------------------------------------------------------------------
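The `Whisper` module above keeps only one of its two sub-modules on the target device at a time: `to(device)` merely records the device, and the `encoder`/`decoder` properties swap the corresponding sub-module in (and the other out), presumably to reduce GPU memory use. A small sketch with untrained weights, using dimensions matching the multilingual "tiny" checkpoint purely for illustration, and exercising `install_kv_cache_hooks()`:

import torch

from whisper.model import ModelDimensions, Whisper

dims = ModelDimensions(
    n_mels=80, n_audio_ctx=1500, n_audio_state=384, n_audio_head=6, n_audio_layer=4,
    n_vocab=51865, n_text_ctx=448, n_text_state=384, n_text_head=6, n_text_layer=4,
)
model = Whisper(dims).to("cpu")  # records the target device; nothing is moved yet

mel = torch.randn(1, dims.n_mels, 3000)
audio_features = model.embed_audio(mel)  # loads the encoder onto the target device

# cache key/value projections across successive decoder calls via forward hooks
cache, hooks = model.install_kv_cache_hooks()  # touches .decoder, so the encoder is swapped out
tokens = torch.tensor([[50258]])  # <|startoftranscript|> in the multilingual vocabulary
logits = model.decoder(tokens, audio_features, kv_cache=cache)  # untrained weights, values are meaningless
for hook in hooks:
    hook.remove()
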
/whisper/normalizers/__init__.py:
--------------------------------------------------------------------------------
1 | from .basic import BasicTextNormalizer as BasicTextNormalizer
2 | from .english import EnglishTextNormalizer as EnglishTextNormalizer
3 |
--------------------------------------------------------------------------------
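The two classes re-exported here are typically used to normalize both reference and predicted transcripts before computing word error rates. A quick sketch; outputs are printed rather than asserted, since the exact strings depend on the rules defined in the modules below:

from whisper.normalizers import BasicTextNormalizer, EnglishTextNormalizer

basic = BasicTextNormalizer()
english = EnglishTextNormalizer()

sample = "Mr. Smith paid $20 million, didn't he?"
print(basic(sample))    # lowercases and strips symbols; language-agnostic
print(english(sample))  # additionally expands titles/contractions and standardizes numbers
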
/whisper/normalizers/basic.py:
--------------------------------------------------------------------------------
1 | import re
2 | import unicodedata
3 |
4 | import regex
5 |
6 | # non-ASCII letters that are not separated by "NFKD" normalization
7 | ADDITIONAL_DIACRITICS = {
8 | "œ": "oe",
9 | "Œ": "OE",
10 | "ø": "o",
11 | "Ø": "O",
12 | "æ": "ae",
13 | "Æ": "AE",
14 | "ß": "ss",
15 | "ẞ": "SS",
16 | "đ": "d",
17 | "Đ": "D",
18 | "ð": "d",
19 | "Ð": "D",
20 | "þ": "th",
21 | "Þ": "th",
22 | "ł": "l",
23 | "Ł": "L",
24 | }
25 |
26 |
27 | def remove_symbols_and_diacritics(s: str, keep=""):
28 | """
29 |     Replace any other markers, symbols, and punctuation with a space,
30 | and drop any diacritics (category 'Mn' and some manual mappings)
31 | """
32 | return "".join(
33 | c
34 | if c in keep
35 | else ADDITIONAL_DIACRITICS[c]
36 | if c in ADDITIONAL_DIACRITICS
37 | else ""
38 | if unicodedata.category(c) == "Mn"
39 | else " "
40 | if unicodedata.category(c)[0] in "MSP"
41 | else c
42 | for c in unicodedata.normalize("NFKD", s)
43 | )
44 |
45 |
46 | def remove_symbols(s: str):
47 | """
48 |     Replace any other markers, symbols, and punctuation with a space, keeping diacritics
49 | """
50 | return "".join(
51 | " " if unicodedata.category(c)[0] in "MSP" else c
52 | for c in unicodedata.normalize("NFKC", s)
53 | )
54 |
55 |
56 | class BasicTextNormalizer:
57 | def __init__(self, remove_diacritics: bool = False, split_letters: bool = False):
58 | self.clean = (
59 | remove_symbols_and_diacritics if remove_diacritics else remove_symbols
60 | )
61 | self.split_letters = split_letters
62 |
63 | def __call__(self, s: str):
64 | s = s.lower()
65 | s = re.sub(r"[<\[][^>\]]*[>\]]", "", s) # remove words between brackets
66 |         s = re.sub(r"\(([^)]+?)\)", "", s)  # remove words between parentheses
67 | s = self.clean(s).lower()
68 |
69 | if self.split_letters:
70 | s = " ".join(regex.findall(r"\X", s, regex.U))
71 |
72 | s = re.sub(
73 | r"\s+", " ", s
74 | ) # replace any successive whitespace characters with a space
75 |
76 | return s
77 |
--------------------------------------------------------------------------------
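A short sketch of the helpers above; the outputs noted in the comments are approximate, since the bare helper replaces symbols with spaces without collapsing whitespace (that only happens inside the normalizer class):

from whisper.normalizers.basic import (
    BasicTextNormalizer,
    remove_symbols_and_diacritics,
)

# ADDITIONAL_DIACRITICS covers letters that NFKD alone does not decompose (ø, œ, ł, ...)
print(remove_symbols_and_diacritics("Søren's œuvre: złoty & naïve"))
# -> roughly "Soren s oeuvre zloty naive" (with extra spaces where symbols were dropped)

# split_letters separates grapheme clusters, which is convenient for CER-style scoring
print(BasicTextNormalizer(split_letters=True)("año 2024"))
# -> roughly "a ñ o 2 0 2 4"
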
/whisper/normalizers/english.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import re
4 | from fractions import Fraction
5 | from typing import Iterator, List, Match, Optional, Union
6 |
7 | from more_itertools import windowed
8 |
9 | from .basic import remove_symbols_and_diacritics
10 |
11 |
12 | class EnglishNumberNormalizer:
13 | """
14 |     Convert any spelled-out numbers into Arabic numerals, while handling:
15 |
16 |     - removing any commas
17 |     - keeping suffixes such as `1960s`, `274th`, `32nd`, etc.
18 |     - spelling out currency symbols after the number, e.g. `$20 million` -> `20000000 dollars`
19 |     - spelling out `one` and `ones`
20 |     - interpreting successive single-digit numbers as nominal, e.g. `one oh one` -> `101`
21 | """
22 |
23 | def __init__(self):
24 | super().__init__()
25 |
26 | self.zeros = {"o", "oh", "zero"}
27 | self.ones = {
28 | name: i
29 | for i, name in enumerate(
30 | [
31 | "one",
32 | "two",
33 | "three",
34 | "four",
35 | "five",
36 | "six",
37 | "seven",
38 | "eight",
39 | "nine",
40 | "ten",
41 | "eleven",
42 | "twelve",
43 | "thirteen",
44 | "fourteen",
45 | "fifteen",
46 | "sixteen",
47 | "seventeen",
48 | "eighteen",
49 | "nineteen",
50 | ],
51 | start=1,
52 | )
53 | }
54 | self.ones_plural = {
55 | "sixes" if name == "six" else name + "s": (value, "s")
56 | for name, value in self.ones.items()
57 | }
58 | self.ones_ordinal = {
59 | "zeroth": (0, "th"),
60 | "first": (1, "st"),
61 | "second": (2, "nd"),
62 | "third": (3, "rd"),
63 | "fifth": (5, "th"),
64 | "twelfth": (12, "th"),
65 | **{
66 | name + ("h" if name.endswith("t") else "th"): (value, "th")
67 | for name, value in self.ones.items()
68 | if value > 3 and value != 5 and value != 12
69 | },
70 | }
71 | self.ones_suffixed = {**self.ones_plural, **self.ones_ordinal}
72 |
73 | self.tens = {
74 | "twenty": 20,
75 | "thirty": 30,
76 | "forty": 40,
77 | "fifty": 50,
78 | "sixty": 60,
79 | "seventy": 70,
80 | "eighty": 80,
81 | "ninety": 90,
82 | }
83 | self.tens_plural = {
84 | name.replace("y", "ies"): (value, "s") for name, value in self.tens.items()
85 | }
86 | self.tens_ordinal = {
87 | name.replace("y", "ieth"): (value, "th")
88 | for name, value in self.tens.items()
89 | }
90 | self.tens_suffixed = {**self.tens_plural, **self.tens_ordinal}
91 |
92 | self.multipliers = {
93 | "hundred": 100,
94 | "thousand": 1_000,
95 | "million": 1_000_000,
96 | "billion": 1_000_000_000,
97 | "trillion": 1_000_000_000_000,
98 | "quadrillion": 1_000_000_000_000_000,
99 | "quintillion": 1_000_000_000_000_000_000,
100 | "sextillion": 1_000_000_000_000_000_000_000,
101 | "septillion": 1_000_000_000_000_000_000_000_000,
102 | "octillion": 1_000_000_000_000_000_000_000_000_000,
103 | "nonillion": 1_000_000_000_000_000_000_000_000_000_000,
104 | "decillion": 1_000_000_000_000_000_000_000_000_000_000_000,
105 | }
106 | self.multipliers_plural = {
107 | name + "s": (value, "s") for name, value in self.multipliers.items()
108 | }
109 | self.multipliers_ordinal = {
110 | name + "th": (value, "th") for name, value in self.multipliers.items()
111 | }
112 | self.multipliers_suffixed = {
113 | **self.multipliers_plural,
114 | **self.multipliers_ordinal,
115 | }
116 | self.decimals = {*self.ones, *self.tens, *self.zeros}
117 |
118 | self.preceding_prefixers = {
119 | "minus": "-",
120 | "negative": "-",
121 | "plus": "+",
122 | "positive": "+",
123 | }
124 | self.following_prefixers = {
125 | "pound": "£",
126 | "pounds": "£",
127 | "euro": "€",
128 | "euros": "€",
129 | "dollar": "$",
130 | "dollars": "$",
131 | "cent": "¢",
132 | "cents": "¢",
133 | }
134 | self.prefixes = set(
135 | list(self.preceding_prefixers.values())
136 | + list(self.following_prefixers.values())
137 | )
138 | self.suffixers = {
139 | "per": {"cent": "%"},
140 | "percent": "%",
141 | }
142 | self.specials = {"and", "double", "triple", "point"}
143 |
144 | self.words = set(
145 | [
146 | key
147 | for mapping in [
148 | self.zeros,
149 | self.ones,
150 | self.ones_suffixed,
151 | self.tens,
152 | self.tens_suffixed,
153 | self.multipliers,
154 | self.multipliers_suffixed,
155 | self.preceding_prefixers,
156 | self.following_prefixers,
157 | self.suffixers,
158 | self.specials,
159 | ]
160 | for key in mapping
161 | ]
162 | )
163 | self.literal_words = {"one", "ones"}
164 |
165 | def process_words(self, words: List[str]) -> Iterator[str]:
166 | prefix: Optional[str] = None
167 | value: Optional[Union[str, int]] = None
168 | skip = False
169 |
170 | def to_fraction(s: str):
171 | try:
172 | return Fraction(s)
173 | except ValueError:
174 | return None
175 |
176 | def output(result: Union[str, int]):
177 | nonlocal prefix, value
178 | result = str(result)
179 | if prefix is not None:
180 | result = prefix + result
181 | value = None
182 | prefix = None
183 | return result
184 |
185 | if len(words) == 0:
186 | return
187 |
188 | for prev, current, next in windowed([None] + words + [None], 3):
189 | if skip:
190 | skip = False
191 | continue
192 |
193 | next_is_numeric = next is not None and re.match(r"^\d+(\.\d+)?$", next)
194 | has_prefix = current[0] in self.prefixes
195 | current_without_prefix = current[1:] if has_prefix else current
196 | if re.match(r"^\d+(\.\d+)?$", current_without_prefix):
197 | # arabic numbers (potentially with signs and fractions)
198 | f = to_fraction(current_without_prefix)
199 | assert f is not None
200 | if value is not None:
201 | if isinstance(value, str) and value.endswith("."):
202 | # concatenate decimals / ip address components
203 | value = str(value) + str(current)
204 | continue
205 | else:
206 | yield output(value)
207 |
208 | prefix = current[0] if has_prefix else prefix
209 | if f.denominator == 1:
210 | value = f.numerator # store integers as int
211 | else:
212 | value = current_without_prefix
213 | elif current not in self.words:
214 | # non-numeric words
215 | if value is not None:
216 | yield output(value)
217 | yield output(current)
218 | elif current in self.zeros:
219 | value = str(value or "") + "0"
220 | elif current in self.ones:
221 | ones = self.ones[current]
222 |
223 | if value is None:
224 | value = ones
225 | elif isinstance(value, str) or prev in self.ones:
226 | if (
227 | prev in self.tens and ones < 10
228 | ): # replace the last zero with the digit
229 | assert value[-1] == "0"
230 | value = value[:-1] + str(ones)
231 | else:
232 | value = str(value) + str(ones)
233 | elif ones < 10:
234 | if value % 10 == 0:
235 | value += ones
236 | else:
237 | value = str(value) + str(ones)
238 | else: # eleven to nineteen
239 | if value % 100 == 0:
240 | value += ones
241 | else:
242 | value = str(value) + str(ones)
243 | elif current in self.ones_suffixed:
244 | # ordinal or cardinal; yield the number right away
245 | ones, suffix = self.ones_suffixed[current]
246 | if value is None:
247 | yield output(str(ones) + suffix)
248 | elif isinstance(value, str) or prev in self.ones:
249 | if prev in self.tens and ones < 10:
250 | assert value[-1] == "0"
251 | yield output(value[:-1] + str(ones) + suffix)
252 | else:
253 | yield output(str(value) + str(ones) + suffix)
254 | elif ones < 10:
255 | if value % 10 == 0:
256 | yield output(str(value + ones) + suffix)
257 | else:
258 | yield output(str(value) + str(ones) + suffix)
259 | else: # eleven to nineteen
260 | if value % 100 == 0:
261 | yield output(str(value + ones) + suffix)
262 | else:
263 | yield output(str(value) + str(ones) + suffix)
264 | value = None
265 | elif current in self.tens:
266 | tens = self.tens[current]
267 | if value is None:
268 | value = tens
269 | elif isinstance(value, str):
270 | value = str(value) + str(tens)
271 | else:
272 | if value % 100 == 0:
273 | value += tens
274 | else:
275 | value = str(value) + str(tens)
276 | elif current in self.tens_suffixed:
277 | # ordinal or cardinal; yield the number right away
278 | tens, suffix = self.tens_suffixed[current]
279 | if value is None:
280 | yield output(str(tens) + suffix)
281 | elif isinstance(value, str):
282 | yield output(str(value) + str(tens) + suffix)
283 | else:
284 | if value % 100 == 0:
285 | yield output(str(value + tens) + suffix)
286 | else:
287 | yield output(str(value) + str(tens) + suffix)
288 | elif current in self.multipliers:
289 | multiplier = self.multipliers[current]
290 | if value is None:
291 | value = multiplier
292 | elif isinstance(value, str) or value == 0:
293 | f = to_fraction(value)
294 | p = f * multiplier if f is not None else None
295 | if f is not None and p.denominator == 1:
296 | value = p.numerator
297 | else:
298 | yield output(value)
299 | value = multiplier
300 | else:
301 | before = value // 1000 * 1000
302 | residual = value % 1000
303 | value = before + residual * multiplier
304 | elif current in self.multipliers_suffixed:
305 | multiplier, suffix = self.multipliers_suffixed[current]
306 | if value is None:
307 | yield output(str(multiplier) + suffix)
308 | elif isinstance(value, str):
309 | f = to_fraction(value)
310 | p = f * multiplier if f is not None else None
311 | if f is not None and p.denominator == 1:
312 | yield output(str(p.numerator) + suffix)
313 | else:
314 | yield output(value)
315 | yield output(str(multiplier) + suffix)
316 | else: # int
317 | before = value // 1000 * 1000
318 | residual = value % 1000
319 | value = before + residual * multiplier
320 | yield output(str(value) + suffix)
321 | value = None
322 | elif current in self.preceding_prefixers:
323 | # apply prefix (positive, minus, etc.) if it precedes a number
324 | if value is not None:
325 | yield output(value)
326 |
327 | if next in self.words or next_is_numeric:
328 | prefix = self.preceding_prefixers[current]
329 | else:
330 | yield output(current)
331 | elif current in self.following_prefixers:
332 | # apply prefix (dollars, cents, etc.) only after a number
333 | if value is not None:
334 | prefix = self.following_prefixers[current]
335 | yield output(value)
336 | else:
337 | yield output(current)
338 | elif current in self.suffixers:
339 | # apply suffix symbols (percent -> '%')
340 | if value is not None:
341 | suffix = self.suffixers[current]
342 | if isinstance(suffix, dict):
343 | if next in suffix:
344 | yield output(str(value) + suffix[next])
345 | skip = True
346 | else:
347 | yield output(value)
348 | yield output(current)
349 | else:
350 | yield output(str(value) + suffix)
351 | else:
352 | yield output(current)
353 | elif current in self.specials:
354 | if next not in self.words and not next_is_numeric:
355 | # apply special handling only if the next word can be numeric
356 | if value is not None:
357 | yield output(value)
358 | yield output(current)
359 | elif current == "and":
360 | # ignore "and" after hundreds, thousands, etc.
361 | if prev not in self.multipliers:
362 | if value is not None:
363 | yield output(value)
364 | yield output(current)
365 | elif current == "double" or current == "triple":
366 | if next in self.ones or next in self.zeros:
367 | repeats = 2 if current == "double" else 3
368 | ones = self.ones.get(next, 0)
369 | value = str(value or "") + str(ones) * repeats
370 | skip = True
371 | else:
372 | if value is not None:
373 | yield output(value)
374 | yield output(current)
375 | elif current == "point":
376 | if next in self.decimals or next_is_numeric:
377 | value = str(value or "") + "."
378 | else:
379 | # should all have been covered at this point
380 | raise ValueError(f"Unexpected token: {current}")
381 | else:
382 | # all should have been covered at this point
383 | raise ValueError(f"Unexpected token: {current}")
384 |
385 | if value is not None:
386 | yield output(value)
387 |
388 | def preprocess(self, s: str):
389 | # replace " and a half" with " point five"
390 | results = []
391 |
392 | segments = re.split(r"\band\s+a\s+half\b", s)
393 | for i, segment in enumerate(segments):
394 | if len(segment.strip()) == 0:
395 | continue
396 | if i == len(segments) - 1:
397 | results.append(segment)
398 | else:
399 | results.append(segment)
400 | last_word = segment.rsplit(maxsplit=2)[-1]
401 | if last_word in self.decimals or last_word in self.multipliers:
402 | results.append("point five")
403 | else:
404 | results.append("and a half")
405 |
406 | s = " ".join(results)
407 |
408 | # put a space at number/letter boundary
409 | s = re.sub(r"([a-z])([0-9])", r"\1 \2", s)
410 | s = re.sub(r"([0-9])([a-z])", r"\1 \2", s)
411 |
412 | # but remove spaces which could be a suffix
413 | s = re.sub(r"([0-9])\s+(st|nd|rd|th|s)\b", r"\1\2", s)
414 |
415 | return s
416 |
417 | def postprocess(self, s: str):
418 | def combine_cents(m: Match):
419 | try:
420 | currency = m.group(1)
421 | integer = m.group(2)
422 | cents = int(m.group(3))
423 | return f"{currency}{integer}.{cents:02d}"
424 | except ValueError:
425 | return m.string
426 |
427 | def extract_cents(m: Match):
428 | try:
429 | return f"¢{int(m.group(1))}"
430 | except ValueError:
431 | return m.string
432 |
433 | # apply currency postprocessing; "$2 and ¢7" -> "$2.07"
434 | s = re.sub(r"([€£$])([0-9]+) (?:and )?¢([0-9]{1,2})\b", combine_cents, s)
435 |         s = re.sub(r"[€£$]0\.([0-9]{1,2})\b", extract_cents, s)
436 |
437 | # write "one(s)" instead of "1(s)", just for the readability
438 | s = re.sub(r"\b1(s?)\b", r"one\1", s)
439 |
440 | return s
441 |
442 | def __call__(self, s: str):
443 | s = self.preprocess(s)
444 | s = " ".join(word for word in self.process_words(s.split()) if word is not None)
445 | s = self.postprocess(s)
446 |
447 | return s
448 |
449 |
450 | class EnglishSpellingNormalizer:
451 | """
452 | Applies British-American spelling mappings as listed in [1].
453 |
454 | [1] https://www.tysto.com/uk-us-spelling-list.html
455 | """
456 |
457 | def __init__(self):
458 | mapping_path = os.path.join(os.path.dirname(__file__), "english.json")
459 | self.mapping = json.load(open(mapping_path))
460 |
461 | def __call__(self, s: str):
462 | return " ".join(self.mapping.get(word, word) for word in s.split())
463 |
464 |
465 | class EnglishTextNormalizer:
466 | def __init__(self):
467 | self.ignore_patterns = r"\b(hmm|mm|mhm|mmm|uh|um)\b"
468 | self.replacers = {
469 | # common contractions
470 | r"\bwon't\b": "will not",
471 | r"\bcan't\b": "can not",
472 | r"\blet's\b": "let us",
473 | r"\bain't\b": "aint",
474 | r"\by'all\b": "you all",
475 | r"\bwanna\b": "want to",
476 | r"\bgotta\b": "got to",
477 | r"\bgonna\b": "going to",
478 | r"\bi'ma\b": "i am going to",
479 | r"\bimma\b": "i am going to",
480 | r"\bwoulda\b": "would have",
481 | r"\bcoulda\b": "could have",
482 | r"\bshoulda\b": "should have",
483 | r"\bma'am\b": "madam",
484 | # contractions in titles/prefixes
485 | r"\bmr\b": "mister ",
486 | r"\bmrs\b": "missus ",
487 | r"\bst\b": "saint ",
488 | r"\bdr\b": "doctor ",
489 | r"\bprof\b": "professor ",
490 | r"\bcapt\b": "captain ",
491 | r"\bgov\b": "governor ",
492 | r"\bald\b": "alderman ",
493 | r"\bgen\b": "general ",
494 | r"\bsen\b": "senator ",
495 | r"\brep\b": "representative ",
496 | r"\bpres\b": "president ",
497 | r"\brev\b": "reverend ",
498 | r"\bhon\b": "honorable ",
499 | r"\basst\b": "assistant ",
500 | r"\bassoc\b": "associate ",
501 | r"\blt\b": "lieutenant ",
502 | r"\bcol\b": "colonel ",
503 | r"\bjr\b": "junior ",
504 | r"\bsr\b": "senior ",
505 | r"\besq\b": "esquire ",
506 |             # perfect tenses; ideally this should cover any past participle, but it's harder.
507 | r"'d been\b": " had been",
508 | r"'s been\b": " has been",
509 | r"'d gone\b": " had gone",
510 | r"'s gone\b": " has gone",
511 | r"'d done\b": " had done", # "'s done" is ambiguous
512 | r"'s got\b": " has got",
513 | # general contractions
514 | r"n't\b": " not",
515 | r"'re\b": " are",
516 | r"'s\b": " is",
517 | r"'d\b": " would",
518 | r"'ll\b": " will",
519 | r"'t\b": " not",
520 | r"'ve\b": " have",
521 | r"'m\b": " am",
522 | }
523 | self.standardize_numbers = EnglishNumberNormalizer()
524 | self.standardize_spellings = EnglishSpellingNormalizer()
525 |
526 | def __call__(self, s: str):
527 | s = s.lower()
528 |
529 | s = re.sub(r"[<\[][^>\]]*[>\]]", "", s) # remove words between brackets
530 |         s = re.sub(r"\(([^)]+?)\)", "", s)  # remove words between parentheses
531 | s = re.sub(self.ignore_patterns, "", s)
532 | s = re.sub(r"\s+'", "'", s) # when there's a space before an apostrophe
533 |
534 | for pattern, replacement in self.replacers.items():
535 | s = re.sub(pattern, replacement, s)
536 |
537 | s = re.sub(r"(\d),(\d)", r"\1\2", s) # remove commas between digits
538 | s = re.sub(r"\.([^0-9]|$)", r" \1", s) # remove periods not followed by numbers
539 | s = remove_symbols_and_diacritics(s, keep=".%$¢€£") # keep numeric symbols
540 |
541 | s = self.standardize_numbers(s)
542 | s = self.standardize_spellings(s)
543 |
544 | # now remove prefix/suffix symbols that are not preceded/followed by numbers
545 | s = re.sub(r"[.$¢€£]([^0-9])", r" \1", s)
546 | s = re.sub(r"([^0-9])%", r"\1 ", s)
547 |
548 | s = re.sub(r"\s+", " ", s) # replace any successive whitespaces with a space
549 |
550 | return s
551 |
--------------------------------------------------------------------------------
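A quick sketch of the number normalizer above. The first expected output is taken from the class docstring; the second follows from the currency rules in `process_words()` and `postprocess()` (both inputs are hypothetical):

from whisper.normalizers.english import EnglishNumberNormalizer

norm = EnglishNumberNormalizer()
print(norm("one oh one"))                            # "101"
print(norm("twenty three dollars and seven cents"))  # "$23.07"
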
/whisper/timing.py:
--------------------------------------------------------------------------------
1 | import itertools
2 | import subprocess
3 | import warnings
4 | from dataclasses import dataclass
5 | from typing import TYPE_CHECKING, List
6 |
7 | import numba
8 | import numpy as np
9 | import torch
10 | import torch.nn.functional as F
11 |
12 | from .audio import HOP_LENGTH, SAMPLE_RATE, TOKENS_PER_SECOND
13 | from .tokenizer import Tokenizer
14 |
15 | if TYPE_CHECKING:
16 | from .model import Whisper
17 |
18 |
19 | def median_filter(x: torch.Tensor, filter_width: int):
20 | """Apply a median filter of width `filter_width` along the last dimension of `x`"""
21 | pad_width = filter_width // 2
22 | if x.shape[-1] <= pad_width:
23 | # F.pad requires the padding width to be smaller than the input dimension
24 | return x
25 |
26 | if (ndim := x.ndim) <= 2:
27 | # `F.pad` does not support 1D or 2D inputs for reflect padding but supports 3D and 4D
28 | x = x[None, None, :]
29 |
30 | assert (
31 | filter_width > 0 and filter_width % 2 == 1
32 | ), "`filter_width` should be an odd number"
33 |
34 | result = None
35 | x = F.pad(x, (filter_width // 2, filter_width // 2, 0, 0), mode="reflect")
36 | if x.is_cuda:
37 | try:
38 | from .triton_ops import median_filter_cuda
39 |
40 | result = median_filter_cuda(x, filter_width)
41 | except (RuntimeError, subprocess.CalledProcessError):
42 | warnings.warn(
43 | "Failed to launch Triton kernels, likely due to missing CUDA toolkit; "
44 | "falling back to a slower median kernel implementation..."
45 | )
46 |
47 | if result is None:
48 | # sort() is faster than torch.median (https://github.com/pytorch/pytorch/issues/51450)
49 | result = x.unfold(-1, filter_width, 1).sort()[0][..., filter_width // 2]
50 |
51 | if ndim <= 2:
52 | result = result[0, 0]
53 |
54 | return result
55 |
56 |
57 | @numba.jit
58 | def backtrace(trace: np.ndarray):
59 | i = trace.shape[0] - 1
60 | j = trace.shape[1] - 1
61 | trace[0, :] = 2
62 | trace[:, 0] = 1
63 |
64 | result = []
65 | while i > 0 or j > 0:
66 | result.append((i - 1, j - 1))
67 |
68 | if trace[i, j] == 0:
69 | i -= 1
70 | j -= 1
71 | elif trace[i, j] == 1:
72 | i -= 1
73 | elif trace[i, j] == 2:
74 | j -= 1
75 | else:
76 | raise ValueError("Unexpected trace[i, j]")
77 |
78 | result = np.array(result)
79 | return result[::-1, :].T
80 |
81 |
82 | @numba.jit(nopython=True, parallel=True)
83 | def dtw_cpu(x: np.ndarray):
84 | N, M = x.shape
85 | cost = np.ones((N + 1, M + 1), dtype=np.float32) * np.inf
86 | trace = -np.ones((N + 1, M + 1), dtype=np.float32)
87 |
88 | cost[0, 0] = 0
89 | for j in range(1, M + 1):
90 | for i in range(1, N + 1):
91 | c0 = cost[i - 1, j - 1]
92 | c1 = cost[i - 1, j]
93 | c2 = cost[i, j - 1]
94 |
95 | if c0 < c1 and c0 < c2:
96 | c, t = c0, 0
97 | elif c1 < c0 and c1 < c2:
98 | c, t = c1, 1
99 | else:
100 | c, t = c2, 2
101 |
102 | cost[i, j] = x[i - 1, j - 1] + c
103 | trace[i, j] = t
104 |
105 | return backtrace(trace)
106 |
107 |
108 | def dtw_cuda(x, BLOCK_SIZE=1024):
109 | from .triton_ops import dtw_kernel
110 |
111 | M, N = x.shape
112 | assert M < BLOCK_SIZE, f"M should be smaller than {BLOCK_SIZE=}"
113 |
114 | x_skew = (
115 | F.pad(x, (0, M + 1), value=np.inf).flatten()[: M * (N + M)].reshape(M, N + M)
116 | )
117 | x_skew = x_skew.T.contiguous()
118 | cost = torch.ones(N + M + 2, M + 2) * np.inf
119 | cost[0, 0] = 0
120 | cost = cost.cuda()
121 | trace = torch.zeros_like(cost, dtype=torch.int32)
122 |
123 | dtw_kernel[(1,)](
124 | cost,
125 | trace,
126 | x_skew,
127 | x_skew.stride(0),
128 | cost.stride(0),
129 | trace.stride(0),
130 | N,
131 | M,
132 | BLOCK_SIZE=BLOCK_SIZE,
133 | )
134 |
135 | trace = trace.T.flatten()[: (M + 1) * (M + N + 3)].reshape(M + 1, M + N + 3)[
136 | :, : N + 1
137 | ]
138 | return backtrace(trace.cpu().numpy())
139 |
140 |
141 | def dtw(x: torch.Tensor) -> np.ndarray:
142 | if x.is_cuda:
143 | try:
144 | return dtw_cuda(x)
145 | except (RuntimeError, subprocess.CalledProcessError):
146 | warnings.warn(
147 | "Failed to launch Triton kernels, likely due to missing CUDA toolkit; "
148 | "falling back to a slower DTW implementation..."
149 | )
150 |
151 | return dtw_cpu(x.double().cpu().numpy())
152 |
153 |
154 | @dataclass
155 | class WordTiming:
156 | word: str
157 | tokens: List[int]
158 | start: float
159 | end: float
160 | probability: float
161 |
162 |
163 | def find_alignment(
164 | model: "Whisper",
165 | tokenizer: Tokenizer,
166 | text_tokens: List[int],
167 | mel: torch.Tensor,
168 | num_frames: int,
169 | *,
170 | medfilt_width: int = 7,
171 | qk_scale: float = 1.0,
172 | ) -> List[WordTiming]:
173 | if len(text_tokens) == 0:
174 | return []
175 |
176 | tokens = torch.tensor(
177 | [
178 | *tokenizer.sot_sequence,
179 | tokenizer.no_timestamps,
180 | *text_tokens,
181 | tokenizer.eot,
182 | ]
183 | ).to(model.device)
184 |
185 | # install hooks on the cross attention layers to retrieve the attention weights
186 | QKs = [None] * model.dims.n_text_layer
187 | hooks = [
188 | block.cross_attn.register_forward_hook(
189 | lambda _, ins, outs, index=i: QKs.__setitem__(index, outs[-1][0])
190 | )
191 | for i, block in enumerate(model.decoder.blocks)
192 | ]
193 |
194 | with torch.no_grad():
195 | logits = model(mel.unsqueeze(0), tokens.unsqueeze(0))[0]
196 | sampled_logits = logits[len(tokenizer.sot_sequence) :, : tokenizer.eot]
197 | token_probs = sampled_logits.softmax(dim=-1)
198 | text_token_probs = token_probs[np.arange(len(text_tokens)), text_tokens]
199 | text_token_probs = text_token_probs.tolist()
200 |
201 | for hook in hooks:
202 | hook.remove()
203 |
204 | # heads * tokens * frames
205 | weights = torch.stack([QKs[l][h] for l, h in model.alignment_heads.indices().T])
206 | weights = weights[:, :, : num_frames // 2]
207 | weights = (weights * qk_scale).softmax(dim=-1)
208 | std, mean = torch.std_mean(weights, dim=-2, keepdim=True, unbiased=False)
209 | weights = (weights - mean) / std
210 | weights = median_filter(weights, medfilt_width)
211 |
212 | matrix = weights.mean(axis=0)
213 | matrix = matrix[len(tokenizer.sot_sequence) : -1]
214 | text_indices, time_indices = dtw(-matrix)
215 |
216 | words, word_tokens = tokenizer.split_to_word_tokens(text_tokens + [tokenizer.eot])
217 | word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens[:-1]]), (1, 0))
218 |
219 | jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool)
220 | jump_times = time_indices[jumps] / TOKENS_PER_SECOND
221 | start_times = jump_times[word_boundaries[:-1]]
222 | end_times = jump_times[word_boundaries[1:]]
223 | word_probabilities = [
224 | np.mean(text_token_probs[i:j])
225 | for i, j in zip(word_boundaries[:-1], word_boundaries[1:])
226 | ]
227 |
228 |     # hack: ensure the first and second words are not longer than twice the median word duration.
229 | # a better segmentation algorithm based on VAD should be able to replace this.
230 | word_durations = end_times - start_times
231 | word_durations = word_durations[word_durations.nonzero()]
232 | if len(word_durations) > 0:
233 | median_duration = np.median(word_durations)
234 | max_duration = median_duration * 2
235 | if len(word_durations) >= 2 and word_durations[1] > max_duration:
236 | boundary = max(end_times[2] / 2, end_times[2] - max_duration)
237 | end_times[0] = start_times[1] = boundary
238 | if len(word_durations) >= 1 and end_times[0] - start_times[0] > max_duration:
239 | start_times[0] = max(0, end_times[0] - max_duration)
240 |
241 | return [
242 | WordTiming(word, tokens, start, end, probability)
243 | for word, tokens, start, end, probability in zip(
244 | words, word_tokens, start_times, end_times, word_probabilities
245 | )
246 | ]
247 |
248 |
249 | def merge_punctuations(alignment: List[WordTiming], prepended: str, appended: str):
250 | # merge prepended punctuations
251 | i = len(alignment) - 2
252 | j = len(alignment) - 1
253 | while i >= 0:
254 | previous = alignment[i]
255 | following = alignment[j]
256 | if previous.word.startswith(" ") and previous.word.strip() in prepended:
257 | # prepend it to the following word
258 | following.word = previous.word + following.word
259 | following.tokens = previous.tokens + following.tokens
260 | previous.word = ""
261 | previous.tokens = []
262 | else:
263 | j = i
264 | i -= 1
265 |
266 | # merge appended punctuations
267 | i = 0
268 | j = 1
269 | while j < len(alignment):
270 | previous = alignment[i]
271 | following = alignment[j]
272 | if not previous.word.endswith(" ") and following.word in appended:
273 | # append it to the previous word
274 | previous.word = previous.word + following.word
275 | previous.tokens = previous.tokens + following.tokens
276 | following.word = ""
277 | following.tokens = []
278 | else:
279 | i = j
280 | j += 1
281 |
282 |
283 | def add_word_timestamps(
284 | *,
285 | segments: List[dict],
286 | model: "Whisper",
287 | tokenizer: Tokenizer,
288 | mel: torch.Tensor,
289 | num_frames: int,
290 | prepend_punctuations: str = "\"'“¿([{-",
291 | append_punctuations: str = "\"'.。,,!!??::”)]}、",
292 | **kwargs,
293 | ):
294 | if len(segments) == 0:
295 | return
296 |
297 | text_tokens_per_segment = [
298 | [token for token in segment["tokens"] if token < tokenizer.eot]
299 | for segment in segments
300 | ]
301 |
302 | text_tokens = list(itertools.chain.from_iterable(text_tokens_per_segment))
303 | alignment = find_alignment(model, tokenizer, text_tokens, mel, num_frames, **kwargs)
304 | merge_punctuations(alignment, prepend_punctuations, append_punctuations)
305 |
306 | time_offset = segments[0]["seek"] * HOP_LENGTH / SAMPLE_RATE
307 | word_index = 0
308 |
309 | for segment, text_tokens in zip(segments, text_tokens_per_segment):
310 | saved_tokens = 0
311 | words = []
312 |
313 | while word_index < len(alignment) and saved_tokens < len(text_tokens):
314 | timing = alignment[word_index]
315 |
316 | if timing.word:
317 | words.append(
318 | dict(
319 | word=timing.word,
320 | start=round(time_offset + timing.start, 2),
321 | end=round(time_offset + timing.end, 2),
322 | probability=timing.probability,
323 | )
324 | )
325 |
326 | saved_tokens += len(timing.tokens)
327 | word_index += 1
328 |
329 | if len(words) > 0:
330 | # adjust the segment-level timestamps based on the word-level timestamps
331 | segment["start"] = words[0]["start"]
332 | segment["end"] = words[-1]["end"]
333 |
334 | segment["words"] = words
335 |
--------------------------------------------------------------------------------
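A small sketch of the alignment primitives above, using random CPU tensors (the Triton kernels are only attempted for CUDA inputs); the shapes mirror how `find_alignment()` calls them:

import torch

from whisper.timing import dtw, median_filter

# cross-attention weights: (heads, text tokens, audio frames); smooth along the frame axis
weights = torch.rand(2, 5, 30)
smoothed = median_filter(weights, filter_width=7)

# dtw() expects a cost matrix, so the negated weights yield a monotonic alignment path
text_indices, time_indices = dtw(-smoothed.mean(dim=0))
print(text_indices.shape, time_indices.shape)  # two equal-length index arrays
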
/whisper/tokenizer.py:
--------------------------------------------------------------------------------
1 | import base64
2 | import os
3 | import string
4 | from dataclasses import dataclass, field
5 | from functools import cached_property, lru_cache
6 | from typing import Dict, List, Optional, Tuple
7 |
8 | import tiktoken
9 | from tiktoken_ext.openai_public import gpt2
10 |
11 | LANGUAGES = {
12 | "en": "english",
13 | "zh": "chinese",
14 | "de": "german",
15 | "es": "spanish",
16 | "ru": "russian",
17 | "ko": "korean",
18 | "fr": "french",
19 | "ja": "japanese",
20 | "pt": "portuguese",
21 | "tr": "turkish",
22 | "pl": "polish",
23 | "ca": "catalan",
24 | "nl": "dutch",
25 | "ar": "arabic",
26 | "sv": "swedish",
27 | "it": "italian",
28 | "id": "indonesian",
29 | "hi": "hindi",
30 | "fi": "finnish",
31 | "vi": "vietnamese",
32 | "he": "hebrew",
33 | "uk": "ukrainian",
34 | "el": "greek",
35 | "ms": "malay",
36 | "cs": "czech",
37 | "ro": "romanian",
38 | "da": "danish",
39 | "hu": "hungarian",
40 | "ta": "tamil",
41 | "no": "norwegian",
42 | "th": "thai",
43 | "ur": "urdu",
44 | "hr": "croatian",
45 | "bg": "bulgarian",
46 | "lt": "lithuanian",
47 | "la": "latin",
48 | "mi": "maori",
49 | "ml": "malayalam",
50 | "cy": "welsh",
51 | "sk": "slovak",
52 | "te": "telugu",
53 | "fa": "persian",
54 | "lv": "latvian",
55 | "bn": "bengali",
56 | "sr": "serbian",
57 | "az": "azerbaijani",
58 | "sl": "slovenian",
59 | "kn": "kannada",
60 | "et": "estonian",
61 | "mk": "macedonian",
62 | "br": "breton",
63 | "eu": "basque",
64 | "is": "icelandic",
65 | "hy": "armenian",
66 | "ne": "nepali",
67 | "mn": "mongolian",
68 | "bs": "bosnian",
69 | "kk": "kazakh",
70 | "sq": "albanian",
71 | "sw": "swahili",
72 | "gl": "galician",
73 | "mr": "marathi",
74 | "pa": "punjabi",
75 | "si": "sinhala",
76 | "km": "khmer",
77 | "sn": "shona",
78 | "yo": "yoruba",
79 | "so": "somali",
80 | "af": "afrikaans",
81 | "oc": "occitan",
82 | "ka": "georgian",
83 | "be": "belarusian",
84 | "tg": "tajik",
85 | "sd": "sindhi",
86 | "gu": "gujarati",
87 | "am": "amharic",
88 | "yi": "yiddish",
89 | "lo": "lao",
90 | "uz": "uzbek",
91 | "fo": "faroese",
92 | "ht": "haitian creole",
93 | "ps": "pashto",
94 | "tk": "turkmen",
95 | "nn": "nynorsk",
96 | "mt": "maltese",
97 | "sa": "sanskrit",
98 | "lb": "luxembourgish",
99 | "my": "myanmar",
100 | "bo": "tibetan",
101 | "tl": "tagalog",
102 | "mg": "malagasy",
103 | "as": "assamese",
104 | "tt": "tatar",
105 | "haw": "hawaiian",
106 | "ln": "lingala",
107 | "ha": "hausa",
108 | "ba": "bashkir",
109 | "jw": "javanese",
110 | "su": "sundanese",
111 | }
112 |
113 | # language code lookup by name, with a few language aliases
114 | TO_LANGUAGE_CODE = {
115 | **{language: code for code, language in LANGUAGES.items()},
116 | "burmese": "my",
117 | "valencian": "ca",
118 | "flemish": "nl",
119 | "haitian": "ht",
120 | "letzeburgesch": "lb",
121 | "pushto": "ps",
122 | "panjabi": "pa",
123 | "moldavian": "ro",
124 | "moldovan": "ro",
125 | "sinhalese": "si",
126 | "castilian": "es",
127 | }
128 |
129 |
130 | @dataclass
131 | class Tokenizer:
132 | """A thin wrapper around `tiktoken` providing quick access to special tokens"""
133 |
134 | encoding: tiktoken.Encoding
135 | language: Optional[str] = None
136 | task: Optional[str] = None
137 | sot_sequence: Tuple[int] = ()
138 | special_tokens: Dict[str, int] = field(default_factory=dict)
139 |
140 | def __post_init__(self):
141 | for special in self.encoding.special_tokens_set:
142 | special_token = self.encoding.encode_single_token(special)
143 | self.special_tokens[special] = special_token
144 |
145 | sot: int = self.special_tokens["<|startoftranscript|>"]
146 | translate: int = self.special_tokens["<|translate|>"]
147 | transcribe: int = self.special_tokens["<|transcribe|>"]
148 |
149 | langs = tuple(LANGUAGES.keys())
150 | sot_sequence = [sot]
151 | if self.language is not None:
152 | sot_sequence.append(sot + 1 + langs.index(self.language))
153 | if self.task is not None:
154 | task_token: int = transcribe if self.task == "transcribe" else translate
155 | sot_sequence.append(task_token)
156 |
157 | self.sot_sequence = tuple(sot_sequence)
158 |
159 | def encode(self, text, **kwargs):
160 | return self.encoding.encode(text, **kwargs)
161 |
162 | def decode(self, token_ids: List[int], **kwargs) -> str:
163 | token_ids = [t for t in token_ids if t < self.timestamp_begin]
164 | return self.encoding.decode(token_ids, **kwargs)
165 |
166 | def decode_with_timestamps(self, token_ids: List[int], **kwargs) -> str:
167 | """
168 | Timestamp tokens are above other special tokens' id range and are ignored by `decode()`.
169 |         This method decodes the given tokens with timestamp tokens annotated, e.g. "<|1.08|>".
170 | """
171 | return self.encoding.decode(token_ids, **kwargs)
172 |
173 | @cached_property
174 | def eot(self) -> int:
175 | return self.encoding.eot_token
176 |
177 | @cached_property
178 | def transcribe(self) -> int:
179 | return self.special_tokens["<|transcribe|>"]
180 |
181 | @cached_property
182 | def translate(self) -> int:
183 | return self.special_tokens["<|translate|>"]
184 |
185 | @cached_property
186 | def sot(self) -> int:
187 | return self.special_tokens["<|startoftranscript|>"]
188 |
189 | @cached_property
190 | def sot_lm(self) -> int:
191 | return self.special_tokens["<|startoflm|>"]
192 |
193 | @cached_property
194 | def sot_prev(self) -> int:
195 | return self.special_tokens["<|startofprev|>"]
196 |
197 | @cached_property
198 | def no_speech(self) -> int:
199 | return self.special_tokens["<|nospeech|>"]
200 |
201 | @cached_property
202 | def no_timestamps(self) -> int:
203 | return self.special_tokens["<|notimestamps|>"]
204 |
205 | @cached_property
206 | def timestamp_begin(self) -> int:
207 | return self.special_tokens["<|0.00|>"]
208 |
209 | @cached_property
210 | def language_token(self) -> int:
211 | """Returns the token id corresponding to the value of the `language` field"""
212 | if self.language is None:
213 |             raise ValueError("This tokenizer does not have a language token configured")
214 |
215 | if token := self.special_tokens.get(f"<|{self.language}|>", None):
216 | return token
217 |
218 | raise KeyError(f"Language {self.language} not found in tokenizer.")
219 |
220 | @cached_property
221 | def all_language_tokens(self) -> Tuple[int]:
222 | result = []
223 | for token, token_id in self.special_tokens.items():
224 | if token.strip("<|>") in LANGUAGES:
225 | result.append(token_id)
226 | return tuple(result)
227 |
228 | @cached_property
229 | def all_language_codes(self) -> Tuple[str]:
230 | return tuple(self.decode([l]).strip("<|>") for l in self.all_language_tokens)
231 |
232 | @cached_property
233 | def sot_sequence_including_notimestamps(self) -> Tuple[int]:
234 | return tuple(list(self.sot_sequence) + [self.no_timestamps])
235 |
236 | @cached_property
237 | def non_speech_tokens(self) -> Tuple[int]:
238 | """
239 | Returns the list of tokens to suppress in order to avoid any speaker tags or non-speech
240 | annotations, to prevent sampling texts that are not actually spoken in the audio, e.g.
241 |
242 | - ♪♪♪
243 | - ( SPEAKING FOREIGN LANGUAGE )
244 | - [DAVID] Hey there,
245 |
246 |         keeping basic punctuation like commas, periods, question marks, exclamation points, etc.
247 | """
248 | symbols = list('"#()*+/:;<=>@[\\]^_`{|}~「」『』')
249 | symbols += (
250 | "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split()
251 | )
252 |
253 | # symbols that may be a single token or multiple tokens depending on the tokenizer.
254 |         # In case they're multiple tokens, suppressing the first token is safe because:
255 |         # these are miscellaneous symbols between U+2640 and U+267F that are okay to suppress
256 |         # in generations, and in their 3-byte UTF-8 representations they share the first two bytes.
257 | miscellaneous = set("♩♪♫♬♭♮♯")
258 | assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous)
259 |
260 | # allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
261 | result = {self.encoding.encode(" -")[0], self.encoding.encode(" '")[0]}
262 | for symbol in symbols + list(miscellaneous):
263 | for tokens in [
264 | self.encoding.encode(symbol),
265 | self.encoding.encode(" " + symbol),
266 | ]:
267 | if len(tokens) == 1 or symbol in miscellaneous:
268 | result.add(tokens[0])
269 |
270 | return tuple(sorted(result))
271 |
272 | def split_to_word_tokens(self, tokens: List[int]):
273 | if self.language in {"zh", "ja", "th", "lo", "my"}:
274 | # These languages don't typically use spaces, so it is difficult to split words
275 | # without morpheme analysis. Here, we instead split words at any
276 | # position where the tokens are decoded as valid unicode points
277 | return self.split_tokens_on_unicode(tokens)
278 |
279 | return self.split_tokens_on_spaces(tokens)
280 |
281 | def split_tokens_on_unicode(self, tokens: List[int]):
282 | decoded_full = self.decode_with_timestamps(tokens)
283 | replacement_char = "\ufffd"
284 |
285 | words = []
286 | word_tokens = []
287 | current_tokens = []
288 | unicode_offset = 0
289 |
290 | for token in tokens:
291 | current_tokens.append(token)
292 | decoded = self.decode_with_timestamps(current_tokens)
293 |
294 | if (
295 | replacement_char not in decoded
296 | or decoded_full[unicode_offset + decoded.index(replacement_char)]
297 | == replacement_char
298 | ):
299 | words.append(decoded)
300 | word_tokens.append(current_tokens)
301 | current_tokens = []
302 | unicode_offset += len(decoded)
303 |
304 | return words, word_tokens
305 |
306 | def split_tokens_on_spaces(self, tokens: List[int]):
307 | subwords, subword_tokens_list = self.split_tokens_on_unicode(tokens)
308 | words = []
309 | word_tokens = []
310 |
311 | for subword, subword_tokens in zip(subwords, subword_tokens_list):
312 | special = subword_tokens[0] >= self.eot
313 | with_space = subword.startswith(" ")
314 | punctuation = subword.strip() in string.punctuation
315 | if special or with_space or punctuation or len(words) == 0:
316 | words.append(subword)
317 | word_tokens.append(subword_tokens)
318 | else:
319 | words[-1] = words[-1] + subword
320 | word_tokens[-1].extend(subword_tokens)
321 |
322 | return words, word_tokens
323 |
324 |
325 | @lru_cache(maxsize=None)
326 | def get_encoding(name: str = "gpt2"):
327 | vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{name}.tiktoken")
328 | ranks = {
329 | base64.b64decode(token): int(rank)
330 | for token, rank in (line.split() for line in open(vocab_path) if line)
331 | }
332 | n_vocab = len(ranks)
333 | special_tokens = {}
334 |
335 | specials = [
336 | "<|endoftext|>",
337 | "<|startoftranscript|>",
338 | *[f"<|{lang}|>" for lang in LANGUAGES.keys()],
339 | "<|translate|>",
340 | "<|transcribe|>",
341 | "<|startoflm|>",
342 | "<|startofprev|>",
343 | "<|nospeech|>",
344 | "<|notimestamps|>",
345 | *[f"<|{i * 0.02:.2f}|>" for i in range(1501)],
346 | ]
347 |
348 | for token in specials:
349 | special_tokens[token] = n_vocab
350 | n_vocab += 1
351 |
352 | return tiktoken.Encoding(
353 | name=os.path.basename(vocab_path),
354 | explicit_n_vocab=n_vocab,
355 | pat_str=gpt2()["pat_str"],
356 | mergeable_ranks=ranks,
357 | special_tokens=special_tokens,
358 | )
359 |
360 |
361 | @lru_cache(maxsize=None)
362 | def get_tokenizer(
363 | multilingual: bool,
364 | *,
365 | language: Optional[str] = None,
366 | task: Optional[str] = None, # Literal["transcribe", "translate", None]
367 | ) -> Tokenizer:
368 | if language is not None:
369 | language = language.lower()
370 | if language not in LANGUAGES:
371 | if language in TO_LANGUAGE_CODE:
372 | language = TO_LANGUAGE_CODE[language]
373 | else:
374 | raise ValueError(f"Unsupported language: {language}")
375 |
376 | if multilingual:
377 | encoding_name = "multilingual"
378 | language = language or "en"
379 | task = task or "transcribe"
380 | else:
381 | encoding_name = "gpt2"
382 | language = None
383 | task = None
384 |
385 | encoding = get_encoding(name=encoding_name)
386 |
387 | return Tokenizer(encoding=encoding, language=language, task=task)
388 |
--------------------------------------------------------------------------------
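
A minimal usage sketch of the tokenizer defined above, assuming the package is importable as `whisper`; the language, task, and sample text are arbitrary example values, not defaults taken from this file:

    from whisper.tokenizer import get_tokenizer

    # multilingual tokenizer with an explicit language and task (example values)
    tokenizer = get_tokenizer(multilingual=True, language="de", task="transcribe")

    # sot_sequence is built in __post_init__: (<|startoftranscript|>, <|de|>, <|transcribe|>)
    print(tokenizer.decode_with_timestamps(list(tokenizer.sot_sequence)))

    # decode() drops timestamp tokens; decode_with_timestamps() keeps them annotated
    tokens = tokenizer.encode(" Hallo Welt") + [tokenizer.timestamp_begin + 54]
    print(tokenizer.decode(tokens))                  # " Hallo Welt"
    print(tokenizer.decode_with_timestamps(tokens))  # " Hallo Welt<|1.08|>"
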
/whisper/transcribe.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import warnings
4 | from typing import TYPE_CHECKING, Optional, Tuple, Union
5 |
6 | import numpy as np
7 | import torch
8 | import tqdm
9 |
10 | from .audio import (
11 | FRAMES_PER_SECOND,
12 | HOP_LENGTH,
13 | N_FRAMES,
14 | N_SAMPLES,
15 | SAMPLE_RATE,
16 | log_mel_spectrogram,
17 | pad_or_trim,
18 | )
19 | from .decoding import DecodingOptions, DecodingResult
20 | from .timing import add_word_timestamps
21 | from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
22 | from .utils import (
23 | exact_div,
24 | format_timestamp,
25 | get_writer,
26 | make_safe,
27 | optional_float,
28 | optional_int,
29 | str2bool,
30 | )
31 |
32 | if TYPE_CHECKING:
33 | from .model import Whisper
34 |
35 |
36 | def transcribe(
37 | model: "Whisper",
38 | audio: Union[str, np.ndarray, torch.Tensor],
39 | *,
40 | verbose: Optional[bool] = None,
41 | temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
42 | compression_ratio_threshold: Optional[float] = 2.4,
43 | logprob_threshold: Optional[float] = -1.0,
44 | no_speech_threshold: Optional[float] = 0.6,
45 | condition_on_previous_text: bool = True,
46 | initial_prompt: Optional[str] = None,
47 | word_timestamps: bool = False,
48 | prepend_punctuations: str = "\"'“¿([{-",
49 | append_punctuations: str = "\"'.。,,!!??::”)]}、",
50 | **decode_options,
51 | ):
52 | """
53 | Transcribe an audio file using Whisper
54 |
55 | Parameters
56 | ----------
57 | model: Whisper
58 | The Whisper model instance
59 |
60 | audio: Union[str, np.ndarray, torch.Tensor]
61 | The path to the audio file to open, or the audio waveform
62 |
63 | verbose: bool
64 |         Whether to display the text being decoded to the console. If True, displays all the details;
65 |         if False, displays minimal details. If None, does not display anything.
66 |
67 | temperature: Union[float, Tuple[float, ...]]
68 | Temperature for sampling. It can be a tuple of temperatures, which will be successively used
69 | upon failures according to either `compression_ratio_threshold` or `logprob_threshold`.
70 |
71 | compression_ratio_threshold: float
72 | If the gzip compression ratio is above this value, treat as failed
73 |
74 | logprob_threshold: float
75 | If the average log probability over sampled tokens is below this value, treat as failed
76 |
77 | no_speech_threshold: float
78 | If the no_speech probability is higher than this value AND the average log probability
79 | over sampled tokens is below `logprob_threshold`, consider the segment as silent
80 |
81 | condition_on_previous_text: bool
82 | if True, the previous output of the model is provided as a prompt for the next window;
83 | disabling may make the text inconsistent across windows, but the model becomes less prone to
84 | getting stuck in a failure loop, such as repetition looping or timestamps going out of sync.
85 |
86 | word_timestamps: bool
87 | Extract word-level timestamps using the cross-attention pattern and dynamic time warping,
88 | and include the timestamps for each word in each segment.
89 |
90 | prepend_punctuations: str
91 | If word_timestamps is True, merge these punctuation symbols with the next word
92 |
93 | append_punctuations: str
94 | If word_timestamps is True, merge these punctuation symbols with the previous word
95 |
96 | initial_prompt: Optional[str]
97 | Optional text to provide as a prompt for the first window. This can be used to provide, or
98 | "prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns
99 |         to make it more likely to predict those words correctly.
100 |
101 | decode_options: dict
102 | Keyword arguments to construct `DecodingOptions` instances
103 |
104 | Returns
105 | -------
106 | A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
107 | the spoken language ("language"), which is detected when `decode_options["language"]` is None.
108 | """
109 | dtype = torch.float16 if decode_options.get("fp16", True) else torch.float32
110 | if model.device == torch.device("cpu"):
111 | if torch.cuda.is_available():
112 | warnings.warn("Performing inference on CPU when CUDA is available")
113 | if dtype == torch.float16:
114 | warnings.warn("FP16 is not supported on CPU; using FP32 instead")
115 | dtype = torch.float32
116 |
117 | if dtype == torch.float32:
118 | decode_options["fp16"] = False
119 |
120 |     # Pad 30 seconds of silence to the input audio, for slicing
121 | mel = log_mel_spectrogram(audio, padding=N_SAMPLES)
122 | content_frames = mel.shape[-1] - N_FRAMES
123 |
124 | if decode_options.get("language", None) is None:
125 | if not model.is_multilingual:
126 | decode_options["language"] = "en"
127 | else:
128 | if verbose:
129 | print(
130 | "Detecting language using up to the first 30 seconds. Use `--language` to specify the language"
131 | )
132 | mel_segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
133 | _, probs = model.detect_language(mel_segment)
134 | decode_options["language"] = max(probs, key=probs.get)
135 | if verbose is not None:
136 | print(
137 | f"Detected language: {LANGUAGES[decode_options['language']].title()}"
138 | )
139 |
140 | language: str = decode_options["language"]
141 | task: str = decode_options.get("task", "transcribe")
142 | tokenizer = get_tokenizer(model.is_multilingual, language=language, task=task)
143 |
144 | if word_timestamps and task == "translate":
145 | warnings.warn("Word-level timestamps on translations may not be reliable.")
146 |
147 | def decode_with_fallback(segment: torch.Tensor) -> DecodingResult:
148 | temperatures = (
149 | [temperature] if isinstance(temperature, (int, float)) else temperature
150 | )
151 | decode_result = None
152 |
153 | for t in temperatures:
154 | kwargs = {**decode_options}
155 | if t > 0:
156 | # disable beam_size and patience when t > 0
157 | kwargs.pop("beam_size", None)
158 | kwargs.pop("patience", None)
159 | else:
160 | # disable best_of when t == 0
161 | kwargs.pop("best_of", None)
162 |
163 | options = DecodingOptions(**kwargs, temperature=t)
164 | decode_result = model.decode(segment, options)
165 |
166 | needs_fallback = False
167 | if (
168 | compression_ratio_threshold is not None
169 | and decode_result.compression_ratio > compression_ratio_threshold
170 | ):
171 | needs_fallback = True # too repetitive
172 | if (
173 | logprob_threshold is not None
174 | and decode_result.avg_logprob < logprob_threshold
175 | ):
176 | needs_fallback = True # average log probability is too low
177 |
178 | if not needs_fallback:
179 | break
180 |
181 | return decode_result
182 |
183 | seek = 0
184 | input_stride = exact_div(
185 | N_FRAMES, model.dims.n_audio_ctx
186 | ) # mel frames per output token: 2
187 | time_precision = (
188 | input_stride * HOP_LENGTH / SAMPLE_RATE
189 | ) # time per output token: 0.02 (seconds)
190 | all_tokens = []
191 | all_segments = []
192 | prompt_reset_since = 0
193 |
194 | if initial_prompt is not None:
195 | initial_prompt_tokens = tokenizer.encode(" " + initial_prompt.strip())
196 | all_tokens.extend(initial_prompt_tokens)
197 | else:
198 | initial_prompt_tokens = []
199 |
200 | def new_segment(
201 | *, start: float, end: float, tokens: torch.Tensor, result: DecodingResult
202 | ):
203 | tokens = tokens.tolist()
204 | text_tokens = [token for token in tokens if token < tokenizer.eot]
205 | return {
206 | "seek": seek,
207 | "start": start,
208 | "end": end,
209 | "text": tokenizer.decode(text_tokens),
210 | "tokens": tokens,
211 | "temperature": result.temperature,
212 | "avg_logprob": result.avg_logprob,
213 | "compression_ratio": result.compression_ratio,
214 | "no_speech_prob": result.no_speech_prob,
215 | }
216 |
217 | # show the progress bar when verbose is False (if True, transcribed text will be printed)
218 | with tqdm.tqdm(
219 | total=content_frames, unit="frames", disable=verbose is not False
220 | ) as pbar:
221 | while seek < content_frames:
222 | time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
223 | mel_segment = mel[:, seek : seek + N_FRAMES]
224 | segment_size = min(N_FRAMES, content_frames - seek)
225 | segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE
226 | mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype)
227 |
228 | decode_options["prompt"] = all_tokens[prompt_reset_since:]
229 | result: DecodingResult = decode_with_fallback(mel_segment)
230 | tokens = torch.tensor(result.tokens)
231 |
232 | if no_speech_threshold is not None:
233 |                 # voice activity check: tentatively skip the segment if it is likely silent
234 | should_skip = result.no_speech_prob > no_speech_threshold
235 | if (
236 | logprob_threshold is not None
237 | and result.avg_logprob > logprob_threshold
238 | ):
239 | # don't skip if the logprob is high enough, despite the no_speech_prob
240 | should_skip = False
241 |
242 | if should_skip:
243 | seek += segment_size # fast-forward to the next segment boundary
244 | continue
245 |
246 | previous_seek = seek
247 | current_segments = []
248 |
249 | timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
250 | single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]
251 |
252 | consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
253 | consecutive.add_(1)
254 | if len(consecutive) > 0:
255 | # if the output contains two consecutive timestamp tokens
256 | slices = consecutive.tolist()
257 | if single_timestamp_ending:
258 | slices.append(len(tokens))
259 |
260 | last_slice = 0
261 | for current_slice in slices:
262 | sliced_tokens = tokens[last_slice:current_slice]
263 | start_timestamp_pos = (
264 | sliced_tokens[0].item() - tokenizer.timestamp_begin
265 | )
266 | end_timestamp_pos = (
267 | sliced_tokens[-1].item() - tokenizer.timestamp_begin
268 | )
269 | current_segments.append(
270 | new_segment(
271 | start=time_offset + start_timestamp_pos * time_precision,
272 | end=time_offset + end_timestamp_pos * time_precision,
273 | tokens=sliced_tokens,
274 | result=result,
275 | )
276 | )
277 | last_slice = current_slice
278 |
279 | if single_timestamp_ending:
280 | # single timestamp at the end means no speech after the last timestamp.
281 | seek += segment_size
282 | else:
283 | # otherwise, ignore the unfinished segment and seek to the last timestamp
284 | last_timestamp_pos = (
285 | tokens[last_slice - 1].item() - tokenizer.timestamp_begin
286 | )
287 | seek += last_timestamp_pos * input_stride
288 | else:
289 | duration = segment_duration
290 | timestamps = tokens[timestamp_tokens.nonzero().flatten()]
291 | if (
292 | len(timestamps) > 0
293 | and timestamps[-1].item() != tokenizer.timestamp_begin
294 | ):
295 | # no consecutive timestamps but it has a timestamp; use the last one.
296 | last_timestamp_pos = (
297 | timestamps[-1].item() - tokenizer.timestamp_begin
298 | )
299 | duration = last_timestamp_pos * time_precision
300 |
301 | current_segments.append(
302 | new_segment(
303 | start=time_offset,
304 | end=time_offset + duration,
305 | tokens=tokens,
306 | result=result,
307 | )
308 | )
309 | seek += segment_size
310 |
311 | if not condition_on_previous_text or result.temperature > 0.5:
312 | # do not feed the prompt tokens if a high temperature was used
313 | prompt_reset_since = len(all_tokens)
314 |
315 | if word_timestamps:
316 | add_word_timestamps(
317 | segments=current_segments,
318 | model=model,
319 | tokenizer=tokenizer,
320 | mel=mel_segment,
321 | num_frames=segment_size,
322 | prepend_punctuations=prepend_punctuations,
323 | append_punctuations=append_punctuations,
324 | )
325 | word_end_timestamps = [
326 | w["end"] for s in current_segments for w in s["words"]
327 | ]
328 | if not single_timestamp_ending and len(word_end_timestamps) > 0:
329 | seek_shift = round(
330 | (word_end_timestamps[-1] - time_offset) * FRAMES_PER_SECOND
331 | )
332 | if seek_shift > 0:
333 | seek = previous_seek + seek_shift
334 |
335 | if verbose:
336 | for segment in current_segments:
337 | start, end, text = segment["start"], segment["end"], segment["text"]
338 | line = f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}"
339 | print(make_safe(line))
340 |
341 | # if a segment is instantaneous or does not contain text, clear it
342 | for i, segment in enumerate(current_segments):
343 | if segment["start"] == segment["end"] or segment["text"].strip() == "":
344 | segment["text"] = ""
345 | segment["tokens"] = []
346 | segment["words"] = []
347 |
348 | all_segments.extend(
349 | [
350 | {"id": i, **segment}
351 | for i, segment in enumerate(
352 | current_segments, start=len(all_segments)
353 | )
354 | ]
355 | )
356 | all_tokens.extend(
357 | [token for segment in current_segments for token in segment["tokens"]]
358 | )
359 |
360 | # update progress bar
361 | pbar.update(min(content_frames, seek) - previous_seek)
362 |
363 | return dict(
364 | text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]),
365 | segments=all_segments,
366 | language=language,
367 | )
368 |
369 |
370 | def cli():
371 | from . import available_models
372 |
373 | # fmt: off
374 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
375 | parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
376 | parser.add_argument("--model", default="small", choices=available_models(), help="name of the Whisper model to use")
377 | parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
378 | parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
379 | parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
380 | parser.add_argument("--output_format", "-f", type=str, default="all", choices=["txt", "vtt", "srt", "tsv", "json", "all"], help="format of the output file; if not specified, all available formats will be produced")
381 | parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")
382 |
383 | parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
384 | parser.add_argument("--language", type=str, default=None, choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]), help="language spoken in the audio, specify None to perform language detection")
385 |
386 | parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling")
387 | parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature")
388 | parser.add_argument("--beam_size", type=optional_int, default=5, help="number of beams in beam search, only applicable when temperature is zero")
389 | parser.add_argument("--patience", type=float, default=None, help="optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search")
390 | parser.add_argument("--length_penalty", type=float, default=None, help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple length normalization by default")
391 |
392 |     parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuation")
393 | parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.")
394 | parser.add_argument("--condition_on_previous_text", type=str2bool, default=True, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop")
395 | parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default")
396 |
397 | parser.add_argument("--temperature_increment_on_fallback", type=optional_float, default=0.2, help="temperature to increase when falling back when the decoding fails to meet either of the thresholds below")
398 | parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, help="if the gzip compression ratio is higher than this value, treat the decoding as failed")
399 | parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed")
400 | parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
401 | parser.add_argument("--word_timestamps", type=str2bool, default=False, help="(experimental) extract word-level timestamps and refine the results based on them")
402 | parser.add_argument("--prepend_punctuations", type=str, default="\"\'“¿([{-", help="if word_timestamps is True, merge these punctuation symbols with the next word")
403 | parser.add_argument("--append_punctuations", type=str, default="\"\'.。,,!!??::”)]}、", help="if word_timestamps is True, merge these punctuation symbols with the previous word")
404 |     parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supersedes MKL_NUM_THREADS/OMP_NUM_THREADS")
405 | # fmt: on
406 |
407 | args = parser.parse_args().__dict__
408 | model_name: str = args.pop("model")
409 | model_dir: str = args.pop("model_dir")
410 | output_dir: str = args.pop("output_dir")
411 | output_format: str = args.pop("output_format")
412 | device: str = args.pop("device")
413 | os.makedirs(output_dir, exist_ok=True)
414 |
415 | if model_name.endswith(".en") and args["language"] not in {"en", "English"}:
416 | if args["language"] is not None:
417 | warnings.warn(
418 |                 f"{model_name} is an English-only model but received '{args['language']}'; using English instead."
419 | )
420 | args["language"] = "en"
421 |
422 | temperature = args.pop("temperature")
423 | if (increment := args.pop("temperature_increment_on_fallback")) is not None:
424 | temperature = tuple(np.arange(temperature, 1.0 + 1e-6, increment))
425 | else:
426 | temperature = [temperature]
427 |
428 | if (threads := args.pop("threads")) > 0:
429 | torch.set_num_threads(threads)
430 |
431 | from . import load_model
432 |
433 | model = load_model(model_name, device=device, download_root=model_dir)
434 |
435 | writer = get_writer(output_format, output_dir)
436 | for audio_path in args.pop("audio"):
437 | result = transcribe(model, audio_path, temperature=temperature, **args)
438 | writer(result, audio_path)
439 |
440 |
441 | if __name__ == "__main__":
442 | cli()
443 |
--------------------------------------------------------------------------------
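
A short usage sketch of transcribe() above, assuming a placeholder audio file "audio.mp3" and "base" as an example model size; this is roughly what the cli() entry point does for a single input, minus writing the output files:

    import whisper

    model = whisper.load_model("base")  # example model size
    result = whisper.transcribe(
        model,
        "audio.mp3",                                 # placeholder path
        temperature=(0.0, 0.2, 0.4, 0.6, 0.8, 1.0),  # fallback schedule (the documented default)
        word_timestamps=True,                        # also fills result["segments"][i]["words"]
    )

    print(result["language"])
    for segment in result["segments"]:
        print(f"[{segment['start']:.2f} --> {segment['end']:.2f}] {segment['text'].strip()}")
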
/whisper/triton_ops.py:
--------------------------------------------------------------------------------
1 | from functools import lru_cache
2 |
3 | import numpy as np
4 | import torch
5 |
6 | try:
7 | import triton
8 | import triton.language as tl
9 | except ImportError:
10 | raise RuntimeError("triton import failed; try `pip install --pre triton`")
11 |
12 |
13 | @triton.jit
14 | def dtw_kernel(
15 | cost, trace, x, x_stride, cost_stride, trace_stride, N, M, BLOCK_SIZE: tl.constexpr
16 | ):
17 | offsets = tl.arange(0, BLOCK_SIZE)
18 | mask = offsets < M
19 |
20 | for k in range(1, N + M + 1): # k = i + j
21 | tl.debug_barrier()
22 |
23 | p0 = cost + (k - 1) * cost_stride
24 | p1 = cost + k * cost_stride
25 | p2 = cost + k * cost_stride + 1
26 |
27 | c0 = tl.load(p0 + offsets, mask=mask)
28 | c1 = tl.load(p1 + offsets, mask=mask)
29 | c2 = tl.load(p2 + offsets, mask=mask)
30 |
31 | x_row = tl.load(x + (k - 1) * x_stride + offsets, mask=mask, other=0)
32 | cost_row = x_row + tl.minimum(tl.minimum(c0, c1), c2)
33 |
34 | cost_ptr = cost + (k + 1) * cost_stride + 1
35 | tl.store(cost_ptr + offsets, cost_row, mask=mask)
36 |
37 | trace_ptr = trace + (k + 1) * trace_stride + 1
38 | tl.store(trace_ptr + offsets, 2, mask=mask & (c2 <= c0) & (c2 <= c1))
39 | tl.store(trace_ptr + offsets, 1, mask=mask & (c1 <= c0) & (c1 <= c2))
40 | tl.store(trace_ptr + offsets, 0, mask=mask & (c0 <= c1) & (c0 <= c2))
41 |
42 |
43 | @lru_cache(maxsize=None)
44 | def median_kernel(filter_width: int):
45 | @triton.jit
46 | def kernel(
47 | y, x, x_stride, y_stride, BLOCK_SIZE: tl.constexpr
48 | ): # x.shape[-1] == filter_width
49 | row_idx = tl.program_id(0)
50 | offsets = tl.arange(0, BLOCK_SIZE)
51 | mask = offsets < y_stride
52 |
53 | x_ptr = x + row_idx * x_stride # noqa: F841
54 | y_ptr = y + row_idx * y_stride
55 |
56 | LOAD_ALL_ROWS_HERE # noqa: F821
57 |
58 | BUBBLESORT_HERE # noqa: F821
59 |
60 | tl.store(y_ptr + offsets, MIDDLE_ROW_HERE, mask=mask) # noqa: F821
61 |
62 | kernel = triton.JITFunction(kernel.fn)
63 | kernel.src = kernel.src.replace(
64 | " LOAD_ALL_ROWS_HERE",
65 | "\n".join(
66 | [
67 | f" row{i} = tl.load(x_ptr + offsets + {i}, mask=mask)"
68 | for i in range(filter_width)
69 | ]
70 | ),
71 | )
72 | kernel.src = kernel.src.replace(
73 | " BUBBLESORT_HERE",
74 | "\n\n".join(
75 | [
76 | "\n\n".join(
77 | [
78 | "\n".join(
79 | [
80 | f" smaller = tl.where(row{j} < row{j + 1}, row{j}, row{j + 1})",
81 | f" larger = tl.where(row{j} > row{j + 1}, row{j}, row{j + 1})",
82 | f" row{j} = smaller",
83 | f" row{j + 1} = larger",
84 | ]
85 | )
86 | for j in range(filter_width - i - 1)
87 | ]
88 | )
89 | for i in range(filter_width // 2 + 1)
90 | ]
91 | ),
92 | )
93 | kernel.src = kernel.src.replace("MIDDLE_ROW_HERE", f"row{filter_width // 2}")
94 |
95 | return kernel
96 |
97 |
98 | def median_filter_cuda(x: torch.Tensor, filter_width: int):
99 | """Apply a median filter of given width along the last dimension of x"""
100 | slices = x.contiguous().unfold(-1, filter_width, 1)
101 | grid = np.prod(slices.shape[:-2])
102 |
103 | kernel = median_kernel(filter_width)
104 | y = torch.empty_like(slices[..., 0])
105 |
106 | BLOCK_SIZE = 1 << (y.stride(-2) - 1).bit_length()
107 | kernel[(grid,)](y, x, x.stride(-2), y.stride(-2), BLOCK_SIZE=BLOCK_SIZE)
108 |
109 | return y
110 |
--------------------------------------------------------------------------------
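
For reference, a plain PyTorch sketch that computes the same sliding-window medians as median_filter_cuda above; it can be used to sanity-check the Triton kernel on a CUDA device (assuming an odd filter_width, for which the kernel's partially-sorted middle element is the exact median):

    import torch

    def median_filter_reference(x: torch.Tensor, filter_width: int) -> torch.Tensor:
        # sliding windows over the last dimension: (..., L - filter_width + 1, filter_width)
        slices = x.contiguous().unfold(-1, filter_width, 1)
        # median of each window; same output shape as median_filter_cuda
        return slices.median(dim=-1).values

For example, `median_filter_cuda(x, 7)` and `median_filter_reference(x, 7)` should agree element-wise for a CUDA tensor `x`.
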
/whisper/utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import sys
4 | import zlib
5 | from typing import Callable, TextIO
6 |
7 | system_encoding = sys.getdefaultencoding()
8 |
9 | if system_encoding != "utf-8":
10 |
11 | def make_safe(string):
12 |         # replaces any character not representable using the system default encoding with a '?',
13 | # avoiding UnicodeEncodeError (https://github.com/openai/whisper/discussions/729).
14 | return string.encode(system_encoding, errors="replace").decode(system_encoding)
15 |
16 | else:
17 |
18 | def make_safe(string):
19 | # utf-8 can encode any Unicode code point, so no need to do the round-trip encoding
20 | return string
21 |
22 |
23 | def exact_div(x, y):
24 | assert x % y == 0
25 | return x // y
26 |
27 |
28 | def str2bool(string):
29 | str2val = {"True": True, "False": False}
30 | if string in str2val:
31 | return str2val[string]
32 | else:
33 | raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}")
34 |
35 |
36 | def optional_int(string):
37 | return None if string == "None" else int(string)
38 |
39 |
40 | def optional_float(string):
41 | return None if string == "None" else float(string)
42 |
43 |
44 | def compression_ratio(text) -> float:
45 | text_bytes = text.encode("utf-8")
46 | return len(text_bytes) / len(zlib.compress(text_bytes))
47 |
48 |
49 | def format_timestamp(
50 | seconds: float, always_include_hours: bool = False, decimal_marker: str = "."
51 | ):
52 | assert seconds >= 0, "non-negative timestamp expected"
53 | milliseconds = round(seconds * 1000.0)
54 |
55 | hours = milliseconds // 3_600_000
56 | milliseconds -= hours * 3_600_000
57 |
58 | minutes = milliseconds // 60_000
59 | milliseconds -= minutes * 60_000
60 |
61 | seconds = milliseconds // 1_000
62 | milliseconds -= seconds * 1_000
63 |
64 | hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
65 | return (
66 | f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
67 | )
68 |
69 |
70 | class ResultWriter:
71 | extension: str
72 |
73 | def __init__(self, output_dir: str):
74 | self.output_dir = output_dir
75 |
76 | def __call__(self, result: dict, audio_path: str):
77 | audio_basename = os.path.basename(audio_path)
78 | audio_basename = os.path.splitext(audio_basename)[0]
79 | output_path = os.path.join(
80 | self.output_dir, audio_basename + "." + self.extension
81 | )
82 |
83 | with open(output_path, "w", encoding="utf-8") as f:
84 | self.write_result(result, file=f)
85 |
86 | def write_result(self, result: dict, file: TextIO):
87 | raise NotImplementedError
88 |
89 |
90 | class WriteTXT(ResultWriter):
91 | extension: str = "txt"
92 |
93 | def write_result(self, result: dict, file: TextIO):
94 | for segment in result["segments"]:
95 | print(segment["text"].strip(), file=file, flush=True)
96 |
97 |
98 | class SubtitlesWriter(ResultWriter):
99 | always_include_hours: bool
100 | decimal_marker: str
101 |
102 | def iterate_result(self, result: dict):
103 | for segment in result["segments"]:
104 | segment_start = self.format_timestamp(segment["start"])
105 | segment_end = self.format_timestamp(segment["end"])
106 | segment_text = segment["text"].strip().replace("-->", "->")
107 |
108 | if word_timings := segment.get("words", None):
109 | all_words = [timing["word"] for timing in word_timings]
110 | all_words[0] = all_words[0].strip() # remove the leading space, if any
111 | last = segment_start
112 | for i, this_word in enumerate(word_timings):
113 | start = self.format_timestamp(this_word["start"])
114 | end = self.format_timestamp(this_word["end"])
115 | if last != start:
116 | yield last, start, segment_text
117 |
118 | yield start, end, "".join(
119 | [
120 |                             f"<u>{word}</u>" if j == i else word
121 | for j, word in enumerate(all_words)
122 | ]
123 | )
124 | last = end
125 |
126 | if last != segment_end:
127 | yield last, segment_end, segment_text
128 | else:
129 | yield segment_start, segment_end, segment_text
130 |
131 | def format_timestamp(self, seconds: float):
132 | return format_timestamp(
133 | seconds=seconds,
134 | always_include_hours=self.always_include_hours,
135 | decimal_marker=self.decimal_marker,
136 | )
137 |
138 |
139 | class WriteVTT(SubtitlesWriter):
140 | extension: str = "vtt"
141 | always_include_hours: bool = False
142 | decimal_marker: str = "."
143 |
144 | def write_result(self, result: dict, file: TextIO):
145 | print("WEBVTT\n", file=file)
146 | for start, end, text in self.iterate_result(result):
147 | print(f"{start} --> {end}\n{text}\n", file=file, flush=True)
148 |
149 |
150 | class WriteSRT(SubtitlesWriter):
151 | extension: str = "srt"
152 | always_include_hours: bool = True
153 | decimal_marker: str = ","
154 |
155 | def write_result(self, result: dict, file: TextIO):
156 | for i, (start, end, text) in enumerate(self.iterate_result(result), start=1):
157 | print(f"{i}\n{start} --> {end}\n{text}\n", file=file, flush=True)
158 |
159 |
160 | class WriteTSV(ResultWriter):
161 | """
162 | Write a transcript to a file in TSV (tab-separated values) format containing lines like:
163 |     <start time in integer milliseconds>\t<end time in integer milliseconds>\t<transcript text>
164 |
165 | Using integer milliseconds as start and end times means there's no chance of interference from
166 | an environment setting a language encoding that causes the decimal in a floating point number
167 |     to appear as a comma; it is also faster and more efficient to parse & store, e.g., in C++.
168 | """
169 |
170 | extension: str = "tsv"
171 |
172 | def write_result(self, result: dict, file: TextIO):
173 | print("start", "end", "text", sep="\t", file=file)
174 | for segment in result["segments"]:
175 | print(round(1000 * segment["start"]), file=file, end="\t")
176 | print(round(1000 * segment["end"]), file=file, end="\t")
177 | print(segment["text"].strip().replace("\t", " "), file=file, flush=True)
178 |
179 |
180 | class WriteJSON(ResultWriter):
181 | extension: str = "json"
182 |
183 | def write_result(self, result: dict, file: TextIO):
184 | json.dump(result, file)
185 |
186 |
187 | def get_writer(output_format: str, output_dir: str) -> Callable[[dict, str], None]:
188 | writers = {
189 | "txt": WriteTXT,
190 | "vtt": WriteVTT,
191 | "srt": WriteSRT,
192 | "tsv": WriteTSV,
193 | "json": WriteJSON,
194 | }
195 |
196 | if output_format == "all":
197 | all_writers = [writer(output_dir) for writer in writers.values()]
198 |
199 |         def write_all(result: dict, audio_path: str):
200 |             for writer in all_writers:
201 |                 writer(result, audio_path)
202 |
203 | return write_all
204 |
205 | return writers[output_format](output_dir)
206 |
--------------------------------------------------------------------------------
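
A small sketch of how the writers above are used, with a hand-made result dict standing in for the output of transcribe(); the segment values and file names are placeholders:

    from whisper.utils import format_timestamp, get_writer

    result = {
        "text": " Hello world.",
        "segments": [{"start": 0.0, "end": 1.5, "text": " Hello world."}],
        "language": "en",
    }

    writer = get_writer("srt", ".")   # a WriteSRT instance
    writer(result, "audio.wav")       # writes ./audio.srt

    print(format_timestamp(61.5))                              # 01:01.500
    print(format_timestamp(61.5, always_include_hours=True))   # 00:01:01.500
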
/whisper/version.py:
--------------------------------------------------------------------------------
1 | __version__ = "20230308"
2 |
--------------------------------------------------------------------------------