├── .gitignore ├── LICENSE ├── MANIFEST.in ├── README.md ├── approach.png ├── data ├── README.md └── meanwhile.json ├── language-breakdown.svg ├── model-card.md ├── notebooks ├── LibriSpeech.ipynb ├── Multilingual_ASR.ipynb └── efficient_whisper.ipynb ├── requirements.txt ├── setup.py ├── tests ├── jfk.flac ├── test_audio.py ├── test_normalizer.py ├── test_tokenizer.py └── test_transcribe.py └── whisper ├── __init__.py ├── __main__.py ├── assets ├── gpt2 │ ├── merges.txt │ ├── special_tokens_map.json │ ├── tokenizer_config.json │ └── vocab.json ├── mel_filters.npz └── multilingual │ ├── added_tokens.json │ ├── merges.txt │ ├── special_tokens_map.json │ ├── tokenizer_config.json │ └── vocab.json ├── audio.py ├── decoding.py ├── model.py ├── normalizers ├── __init__.py ├── basic.py ├── english.json └── english.py ├── tokenizer.py ├── transcribe.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | *.py[cod] 3 | *$py.class 4 | *.egg-info 5 | .pytest_cache 6 | .ipynb_checkpoints 7 | 8 | thumbs.db 9 | .DS_Store 10 | .idea 11 | 12 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 OpenAI 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include whisper/assets/* 2 | include whisper/assets/gpt2/* 3 | include whisper/assets/multilingual/* 4 | include whisper/normalizers/english.json 5 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Whisper 2 | 3 | [[Blog]](https://openai.com/blog/whisper) 4 | [[Paper]](https://cdn.openai.com/papers/whisper.pdf) 5 | [[Model card]](model-card.md) 6 | [[Colab example]](https://colab.research.google.com/github/openai/whisper/blob/master/notebooks/LibriSpeech.ipynb) 7 | 8 | Whisper is a general-purpose speech recognition model. It is trained on a large dataset of diverse audio and is also a multi-task model that can perform multilingual speech recognition as well as speech translation and language identification. 
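As a quick illustration of those capabilities, here is a minimal sketch using the Python API that is documented in more detail later in this README (the audio file name is a placeholder):

```python
import whisper

model = whisper.load_model("base")

# multilingual speech recognition
result = model.transcribe("speech.mp3")
print(result["text"])

# speech translation: transcribe speech in any supported language into English
translated = model.transcribe("speech.mp3", task="translate")
print(translated["text"])

# language identification on a 30-second log-Mel spectrogram
audio = whisper.pad_or_trim(whisper.load_audio("speech.mp3"))
mel = whisper.log_mel_spectrogram(audio).to(model.device)
_, probs = model.detect_language(mel)
print(f"Detected language: {max(probs, key=probs.get)}")
```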
9 | 10 | 11 | ## Approach 12 | 13 | ![Approach](approach.png) 14 | 15 | A Transformer sequence-to-sequence model is trained on various speech processing tasks, including multilingual speech recognition, speech translation, spoken language identification, and voice activity detection. All of these tasks are jointly represented as a sequence of tokens to be predicted by the decoder, allowing for a single model to replace many different stages of a traditional speech processing pipeline. The multitask training format uses a set of special tokens that serve as task specifiers or classification targets. 16 | 17 | 18 | ## Setup 19 | 20 | We used Python 3.9.9 and [PyTorch](https://pytorch.org/) 1.10.1 to train and test our models, but the codebase is expected to be compatible with Python 3.7 or later and recent PyTorch versions. The codebase also depends on a few Python packages, most notably [HuggingFace Transformers](https://huggingface.co/docs/transformers/index) for their fast tokenizer implementation and [ffmpeg-python](https://github.com/kkroening/ffmpeg-python) for reading audio files. The following command will pull and install the latest commit from this repository, along with its Python dependencies 21 | 22 | pip install git+https://github.com/openai/whisper.git 23 | 24 | To update the package to the latest version of this repository, please run: 25 | 26 | pip install --upgrade --no-deps --force-reinstall git+https://github.com/openai/whisper.git 27 | 28 | It also requires the command-line tool [`ffmpeg`](https://ffmpeg.org/) to be installed on your system, which is available from most package managers: 29 | 30 | ```bash 31 | # on Ubuntu or Debian 32 | sudo apt update && sudo apt install ffmpeg 33 | 34 | # on Arch Linux 35 | sudo pacman -S ffmpeg 36 | 37 | # on MacOS using Homebrew (https://brew.sh/) 38 | brew install ffmpeg 39 | 40 | # on Windows using Chocolatey (https://chocolatey.org/) 41 | choco install ffmpeg 42 | 43 | # on Windows using Scoop (https://scoop.sh/) 44 | scoop install ffmpeg 45 | ``` 46 | 47 | You may need [`rust`](http://rust-lang.org) installed as well, in case [tokenizers](https://pypi.org/project/tokenizers/) does not provide a pre-built wheel for your platform. If you see installation errors during the `pip install` command above, please follow the [Getting started page](https://www.rust-lang.org/learn/get-started) to install Rust development environment. Additionally, you may need to configure the `PATH` environment variable, e.g. `export PATH="$HOME/.cargo/bin:$PATH"`. If the installation fails with `No module named 'setuptools_rust'`, you need to install `setuptools_rust`, e.g. by running: 48 | 49 | ```bash 50 | pip install setuptools-rust 51 | ``` 52 | 53 | 54 | ## Available models and languages 55 | 56 | There are five model sizes, four with English-only versions, offering speed and accuracy tradeoffs. Below are the names of the available models and their approximate memory requirements and relative speed. 
57 | 58 | 59 | | Size | Parameters | English-only model | Multilingual model | Required VRAM | Relative speed | 60 | |:------:|:----------:|:------------------:|:------------------:|:-------------:|:--------------:| 61 | | tiny | 39 M | `tiny.en` | `tiny` | ~1 GB | ~32x | 62 | | base | 74 M | `base.en` | `base` | ~1 GB | ~16x | 63 | | small | 244 M | `small.en` | `small` | ~2 GB | ~6x | 64 | | medium | 769 M | `medium.en` | `medium` | ~5 GB | ~2x | 65 | | large | 1550 M | N/A | `large` | ~10 GB | 1x | 66 | 67 | For English-only applications, the `.en` models tend to perform better, especially for the `tiny.en` and `base.en` models. We observed that the difference becomes less significant for the `small.en` and `medium.en` models. 68 | 69 | Whisper's performance varies widely depending on the language. The figure below shows a WER breakdown by languages of Fleurs dataset, using the `large` model. More WER and BLEU scores corresponding to the other models and datasets can be found in Appendix D in [the paper](https://cdn.openai.com/papers/whisper.pdf). 70 | 71 | ![WER breakdown by language](language-breakdown.svg) 72 | 73 | 74 | 75 | ## Command-line usage 76 | 77 | The following command will transcribe speech in audio files, using the `medium` model: 78 | 79 | whisper audio.flac audio.mp3 audio.wav --model medium 80 | 81 | The default setting (which selects the `small` model) works well for transcribing English. To transcribe an audio file containing non-English speech, you can specify the language using the `--language` option: 82 | 83 | whisper japanese.wav --language Japanese 84 | 85 | Adding `--task translate` will translate the speech into English: 86 | 87 | whisper japanese.wav --language Japanese --task translate 88 | 89 | Run the following to view all available options: 90 | 91 | whisper --help 92 | 93 | See [tokenizer.py](whisper/tokenizer.py) for the list of all available languages. 94 | 95 | 96 | ## Python usage 97 | 98 | Transcription can also be performed within Python: 99 | 100 | ```python 101 | import whisper 102 | 103 | model = whisper.load_model("base") 104 | result = model.transcribe("audio.mp3") 105 | print(result["text"]) 106 | ``` 107 | 108 | Internally, the `transcribe()` method reads the entire file and processes the audio with a sliding 30-second window, performing autoregressive sequence-to-sequence predictions on each window. 109 | 110 | Below is an example usage of `whisper.detect_language()` and `whisper.decode()` which provide lower-level access to the model. 
111 | 112 | ```python 113 | import whisper 114 | 115 | model = whisper.load_model("base") 116 | 117 | # load audio and pad/trim it to fit 30 seconds 118 | audio = whisper.load_audio("audio.mp3") 119 | audio = whisper.pad_or_trim(audio) 120 | 121 | # make log-Mel spectrogram and move to the same device as the model 122 | mel = whisper.log_mel_spectrogram(audio).to(model.device) 123 | 124 | # detect the spoken language 125 | _, probs = model.detect_language(mel) 126 | print(f"Detected language: {max(probs, key=probs.get)}") 127 | 128 | # decode the audio 129 | options = whisper.DecodingOptions() 130 | result = whisper.decode(model, mel, options) 131 | 132 | # print the recognized text 133 | print(result.text) 134 | ``` 135 | 136 | ## More examples 137 | 138 | Please use the [🙌 Show and tell](https://github.com/openai/whisper/discussions/categories/show-and-tell) category in Discussions for sharing more example usages of Whisper and third-party extensions such as web demos, integrations with other tools, ports for different platforms, etc. 139 | 140 | 141 | ## License 142 | 143 | The code and the model weights of Whisper are released under the MIT License. See [LICENSE](LICENSE) for further details. 144 | -------------------------------------------------------------------------------- /approach.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/projectlucas/efficient_whisper/ed0c6cefbab7cad208ca33c91facd5674e1101a7/approach.png -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | This directory supplements the paper with more details on how we prepared the data for evaluation, to help replicate our experiments. 2 | 3 | ## Short-form English-only datasets 4 | 5 | ### LibriSpeech 6 | 7 | We used the test-clean and test-other splits from the [LibriSpeech ASR corpus](https://www.openslr.org/12). 8 | 9 | ### TED-LIUM 3 10 | 11 | We used the test split of [TED-LIUM Release 3](https://www.openslr.org/51/), using the segmented manual transcripts included in the release. 12 | 13 | ### Common Voice 5.1 14 | 15 | We downloaded the English subset of Common Voice Corpus 5.1 from [the official website](https://commonvoice.mozilla.org/en/datasets) 16 | 17 | ### Artie 18 | 19 | We used the [Artie bias corpus](https://github.com/artie-inc/artie-bias-corpus). This is a subset of the Common Voice dataset. 20 | 21 | ### CallHome & Switchboard 22 | 23 | We used the two corpora from [LDC2002S09](https://catalog.ldc.upenn.edu/LDC2002S09) and [LDC2002T43](https://catalog.ldc.upenn.edu/LDC2002T43) and followed the [eval2000_data_prep.sh](https://github.com/kaldi-asr/kaldi/blob/master/egs/fisher_swbd/s5/local/eval2000_data_prep.sh) script for preprocessing. The `wav.scp` files can be converted to WAV files with the following bash commands: 24 | 25 | ```bash 26 | mkdir -p wav 27 | while read name cmd; do 28 | echo $name 29 | echo ${cmd/\|/} wav/$name.wav | bash 30 | done < wav.scp 31 | ``` 32 | 33 | 34 | ### WSJ 35 | 36 | We used [LDC93S6B](https://catalog.ldc.upenn.edu/LDC93S6B) and [LDC94S13B](https://catalog.ldc.upenn.edu/LDC94S13B) and followed the [s5 recipe](https://github.com/kaldi-asr/kaldi/tree/master/egs/wsj/s5) to preprocess the dataset. 37 | 38 | ### CORAAL 39 | 40 | We used the 231 interviews from [CORAAL (v. 
2021.07)](https://oraal.uoregon.edu/coraal) and used the segmentations from [the FairSpeech project](https://github.com/stanford-policylab/asr-disparities/blob/master/input/CORAAL_transcripts.csv). 41 | 42 | ### CHiME-6 43 | 44 | We downloaded the [CHiME-5 dataset](https://spandh.dcs.shef.ac.uk//chime_challenge/CHiME5/download.html) and followed the stage 0 of the [s5_track1 recipe](https://github.com/kaldi-asr/kaldi/tree/master/egs/chime6/s5_track1) to create the CHiME-6 dataset which fixes synchronization. We then used the binaural recordings (`*_P??.wav`) and the corresponding transcripts. 45 | 46 | ### AMI-IHM, AMI-SDM1 47 | 48 | We preprocessed the [AMI Corpus](https://groups.inf.ed.ac.uk/ami/corpus/overview.shtml) by following the stage 0 ad 2 of the [s5b recipe](https://github.com/kaldi-asr/kaldi/tree/master/egs/ami/s5b). 49 | 50 | 51 | ## Long-form English-only datasets 52 | 53 | ### TED-LIUM 3 54 | 55 | To create a long-form transcription dataset from the [TED-LIUM3](https://www.openslr.org/51/) dataset, we sliced the audio between the beginning of the first labeled segment and the end of the last labeled segment of each talk, and we used the concatenated text as the label. Below are the timestamps used for slicing each of the 11 TED talks in the test split. 56 | 57 | | Filename | Begin time (s) | End time (s) | 58 | |---------------------|----------------|--------------| 59 | | DanBarber_2010 | 16.09 | 1116.24 | 60 | | JaneMcGonigal_2010 | 15.476 | 1187.61 | 61 | | BillGates_2010 | 15.861 | 1656.94 | 62 | | TomWujec_2010U | 16.26 | 402.17 | 63 | | GaryFlake_2010 | 16.06 | 367.14 | 64 | | EricMead_2009P | 18.434 | 536.44 | 65 | | MichaelSpecter_2010 | 16.11 | 979.312 | 66 | | DanielKahneman_2010 | 15.8 | 1199.44 | 67 | | AimeeMullins_2009P | 17.82 | 1296.59 | 68 | | JamesCameron_2010 | 16.75 | 1010.65 | 69 | | RobertGupta_2010U | 16.8 | 387.03 | 70 | 71 | ### Meanwhile 72 | 73 | This dataset consists of 64 segments from The Late Show with Stephen Colbert. The YouTube video ID, start and end timestamps, and the labels can be found in [meanwhile.json](meanwhile.json). The labels are collected from the closed-caption data for each video and corrected with manual inspection. 74 | 75 | ### Rev16 76 | 77 | We use a subset of 16 files from the 30 podcast episodes in [Rev.AI's Podcast Transcription Benchmark](https://www.rev.ai/blog/podcast-transcription-benchmark-part-1/), after finding that there are multiple cases where a significant portion of the audio and the labels did not match, mostly on the parts introducing the sponsors. We selected 16 episodes that do not have this error, whose "file number" are: 78 | 79 | 3 4 9 10 11 14 17 18 20 21 23 24 26 27 29 32 80 | 81 | ### Kincaid46 82 | 83 | This dataset consists of 46 audio files and the corresponding transcripts compiled in the blog article [Which automatic transcription service is the most accurate - 2018](https://medium.com/descript/which-automatic-transcription-service-is-the-most-accurate-2018-2e859b23ed19) by Jason Kincaid. We used the 46 audio files and reference transcripts from the Airtable widget in the article. 84 | 85 | For the human transcription benchmark in the paper, we use a subset of 25 examples from this data, whose "Ref ID" are: 86 | 87 | 2 4 5 8 9 10 12 13 14 16 19 21 23 25 26 28 29 30 33 35 36 37 42 43 45 88 | 89 | ### Earnings-21, Earnings-22 90 | 91 | For these datasets, we used the files available in [the speech-datasets repository](https://github.com/revdotcom/speech-datasets), as of their `202206` version. 
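The exact preprocessing scripts for the long-form sets are not included in this repository. As a rough sketch, the TED-LIUM slicing described above amounts to cutting each recording between its begin and end timestamps, which can be done with this codebase's own audio loader (the file path below is a placeholder):

```python
from whisper.audio import load_audio, SAMPLE_RATE

# begin/end timestamps taken from the table above
begin, end = 16.09, 1116.24  # DanBarber_2010

# load_audio returns 16 kHz mono float32 samples, so slicing by sample index
# between the two timestamps yields the long-form audio for the talk
audio = load_audio("TEDLIUM_release-3/test/DanBarber_2010.sph")
long_form_audio = audio[int(begin * SAMPLE_RATE) : int(end * SAMPLE_RATE)]
```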
92 | 93 | ### CORAAL 94 | 95 | We used the 231 interviews from [CORAAL (v. 2021.07)](https://oraal.uoregon.edu/coraal) and used the full-length interview files and transcripts. 96 | 97 | 98 | ## Multilingual datasets 99 | 100 | ### Multilingual LibriSpeech 101 | 102 | We used the test splits from each language in [the Multilingual LibriSpeech (MLS) corpus](https://www.openslr.org/94/). 103 | 104 | ### Fleurs 105 | 106 | We collected audio files and transcripts using the implementation available as [HuggingFace datasets](https://huggingface.co/datasets/google/fleurs/blob/main/fleurs.py). To use as a translation dataset, we matched the numerical utterance IDs to find the corresponding transcript in English. 107 | 108 | ### VoxPopuli 109 | 110 | We used the `get_asr_data.py` script from [the official repository](https://github.com/facebookresearch/voxpopuli) to collect the ASR data in 14 languages. 111 | 112 | ### Common Voice 9 113 | 114 | We downloaded the Common Voice Corpus 9 from [the official website](https://commonvoice.mozilla.org/en/datasets) 115 | 116 | ### CoVOST 2 117 | 118 | We collected the `X into English` data collected using [the official repository](https://github.com/facebookresearch/covost). 119 | -------------------------------------------------------------------------------- /model-card.md: -------------------------------------------------------------------------------- 1 | # Model Card: Whisper 2 | 3 | This is the official codebase for running the automatic speech recognition (ASR) models (Whisper models) trained and released by OpenAI. 4 | 5 | Following [Model Cards for Model Reporting (Mitchell et al.)](https://arxiv.org/abs/1810.03993), we're providing some information about the automatic speech recognition model. More information on how these models were trained and evaluated can be found [in the paper](https://cdn.openai.com/papers/whisper.pdf). 6 | 7 | 8 | ## Model Details 9 | 10 | The Whisper models are trained for speech recognition and translation tasks, capable of transcribing speech audio into the text in the language it is spoken (ASR) as well as translated into English (speech translation). Researchers at OpenAI developed the models to study the robustness of speech processing systems trained under large-scale weak supervision. There are 9 models of different sizes and capabilities, summarized in the following table. 11 | 12 | | Size | Parameters | English-only model | Multilingual model | 13 | |:------:|:----------:|:------------------:|:------------------:| 14 | | tiny | 39 M | ✓ | ✓ | 15 | | base | 74 M | ✓ | ✓ | 16 | | small | 244 M | ✓ | ✓ | 17 | | medium | 769 M | ✓ | ✓ | 18 | | large | 1550 M | | ✓ | 19 | 20 | 21 | ### Release date 22 | 23 | September 2022 24 | 25 | ### Model type 26 | 27 | Sequence-to-sequence ASR (automatic speech recognition) and speech translation model 28 | 29 | ### Paper & samples 30 | 31 | [Paper](https://cdn.openai.com/papers/whisper.pdf) / [Blog](https://openai.com/blog/whisper) 32 | 33 | 34 | ## Model Use 35 | 36 | ### Evaluated Use 37 | 38 | The primary intended users of these models are AI researchers studying robustness, generalization, capabilities, biases, and constraints of the current model. However, Whisper is also potentially quite useful as an ASR solution for developers, especially for English speech recognition. We recognize that once models are released, it is impossible to restrict access to only “intended” uses or to draw reasonable guidelines around what is or is not research. 
39 | 40 | The models are primarily trained and evaluated on ASR and speech translation to English tasks. They show strong ASR results in ~10 languages. They may exhibit additional capabilities, particularly if fine-tuned on certain tasks like voice activity detection, speaker classification, or speaker diarization but have not been robustly evaluated in these areas. We strongly recommend that users perform robust evaluations of the models in a particular context and domain before deploying them. 41 | 42 | In particular, we caution against using Whisper models to transcribe recordings of individuals taken without their consent or purporting to use these models for any kind of subjective classification. We recommend against use in high-risk domains like decision-making contexts, where flaws in accuracy can lead to pronounced flaws in outcomes. The models are intended to transcribe and translate speech, use of the model for classification is not only not evaluated but also not appropriate, particularly to infer human attributes. 43 | 44 | 45 | ## Training Data 46 | 47 | The models are trained on 680,000 hours of audio and the corresponding transcripts collected from the internet. 65% of this data (or 438,000 hours) represents English-language audio and matched English transcripts, roughly 18% (or 126,000 hours) represents non-English audio and English transcripts, while the final 17% (or 117,000 hours) represents non-English audio and the corresponding transcript. This non-English data represents 98 different languages. 48 | 49 | As discussed in [the accompanying paper](https://cdn.openai.com/papers/whisper.pdf), we see that performance on transcription in a given language is directly correlated with the amount of training data we employ in that language. 50 | 51 | 52 | ## Performance and Limitations 53 | 54 | Our studies show that, over many existing ASR systems, the models exhibit improved robustness to accents, background noise, technical language, as well as zero shot translation from multiple languages into English; and that accuracy on speech recognition and translation is near the state-of-the-art level. 55 | 56 | However, because the models are trained in a weakly supervised manner using large-scale noisy data, the predictions may include texts that are not actually spoken in the audio input (i.e. hallucination). We hypothesize that this happens because, given their general knowledge of language, the models combine trying to predict the next word in audio with trying to transcribe the audio itself. 57 | 58 | Our models perform unevenly across languages, and we observe lower accuracy on low-resource and/or low-discoverability languages or languages where we have less training data. The models also exhibit disparate performance on different accents and dialects of particular languages, which may include higher word error rate across speakers of different genders, races, ages, or other demographic criteria. Our full evaluation results are presented in [the paper accompanying this release](https://cdn.openai.com/papers/whisper.pdf). 59 | 60 | In addition, the sequence-to-sequence architecture of the model makes it prone to generating repetitive texts, which can be mitigated to some degree by beam search and temperature scheduling but not perfectly. Further analysis on these limitations are provided in [the paper](https://cdn.openai.com/papers/whisper.pdf). It is likely that this behavior and hallucinations may be worse on lower-resource and/or lower-discoverability languages. 
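In this codebase, the mitigations mentioned above are exposed through the decoding options of `transcribe()`. The following is a usage sketch with illustrative values, not a tuning recommendation:

```python
import whisper

model = whisper.load_model("base")
result = model.transcribe(
    "audio.mp3",
    beam_size=5,                                 # beam search instead of greedy decoding
    temperature=(0.0, 0.2, 0.4, 0.6, 0.8, 1.0),  # retry at higher temperatures when a decode fails
    compression_ratio_threshold=2.4,             # treat highly repetitive output as a failed decode
    logprob_threshold=-1.0,                      # treat low-confidence output as a failed decode
)
print(result["text"])
```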
61 | 62 | 63 | ## Broader Implications 64 | 65 | We anticipate that Whisper models’ transcription capabilities may be used for improving accessibility tools. While Whisper models cannot be used for real-time transcription out of the box – their speed and size suggest that others may be able to build applications on top of them that allow for near-real-time speech recognition and translation. The real value of beneficial applications built on top of Whisper models suggests that the disparate performance of these models may have real economic implications. 66 | 67 | There are also potential dual use concerns that come with releasing Whisper. While we hope the technology will be used primarily for beneficial purposes, making ASR technology more accessible could enable more actors to build capable surveillance technologies or scale up existing surveillance efforts, as the speed and accuracy allow for affordable automatic transcription and translation of large volumes of audio communication. Moreover, these models may have some capabilities to recognize specific individuals out of the box, which in turn presents safety concerns related both to dual use and disparate performance. In practice, we expect that the cost of transcription is not the limiting factor of scaling up surveillance projects. 68 | -------------------------------------------------------------------------------- /notebooks/LibriSpeech.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "v5hvo8QWN-a9" 7 | }, 8 | "source": [ 9 | "# Installing Whisper\n", 10 | "\n", 11 | "The commands below will install the Python packages needed to use Whisper models and evaluate the transcription results." 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": 1, 17 | "metadata": { 18 | "id": "ZsJUxc0aRsAf" 19 | }, 20 | "outputs": [], 21 | "source": [ 22 | "! pip install git+https://github.com/openai/whisper.git\n", 23 | "! pip install jiwer" 24 | ] 25 | }, 26 | { 27 | "cell_type": "markdown", 28 | "metadata": { 29 | "id": "1IMEkgyagYto" 30 | }, 31 | "source": [ 32 | "# Loading the LibriSpeech dataset\n", 33 | "\n", 34 | "The following will load the test-clean split of the LibriSpeech corpus using torchaudio." 
35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": 2, 40 | "metadata": { 41 | "id": "3CqtR2Fi5-vP" 42 | }, 43 | "outputs": [], 44 | "source": [ 45 | "import os\n", 46 | "import numpy as np\n", 47 | "\n", 48 | "try:\n", 49 | " import tensorflow # required in Colab to avoid protobuf compatibility issues\n", 50 | "except ImportError:\n", 51 | " pass\n", 52 | "\n", 53 | "import torch\n", 54 | "import pandas as pd\n", 55 | "import whisper\n", 56 | "import torchaudio\n", 57 | "\n", 58 | "from tqdm.notebook import tqdm\n", 59 | "\n", 60 | "\n", 61 | "DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"" 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": 3, 67 | "metadata": { 68 | "id": "GuCCB2KYOJCE" 69 | }, 70 | "outputs": [], 71 | "source": [ 72 | "class LibriSpeech(torch.utils.data.Dataset):\n", 73 | " \"\"\"\n", 74 | " A simple class to wrap LibriSpeech and trim/pad the audio to 30 seconds.\n", 75 | " It will drop the last few seconds of a very small portion of the utterances.\n", 76 | " \"\"\"\n", 77 | " def __init__(self, split=\"test-clean\", device=DEVICE):\n", 78 | " self.dataset = torchaudio.datasets.LIBRISPEECH(\n", 79 | " root=os.path.expanduser(\"~/.cache\"),\n", 80 | " url=split,\n", 81 | " download=True,\n", 82 | " )\n", 83 | " self.device = device\n", 84 | "\n", 85 | " def __len__(self):\n", 86 | " return len(self.dataset)\n", 87 | "\n", 88 | " def __getitem__(self, item):\n", 89 | " audio, sample_rate, text, _, _, _ = self.dataset[item]\n", 90 | " assert sample_rate == 16000\n", 91 | " audio = whisper.pad_or_trim(audio.flatten()).to(self.device)\n", 92 | " mel = whisper.log_mel_spectrogram(audio)\n", 93 | " \n", 94 | " return (mel, text)" 95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": 4, 100 | "metadata": { 101 | "id": "-YcRU5jqNqo2" 102 | }, 103 | "outputs": [], 104 | "source": [ 105 | "dataset = LibriSpeech(\"test-clean\")\n", 106 | "loader = torch.utils.data.DataLoader(dataset, batch_size=16)" 107 | ] 108 | }, 109 | { 110 | "cell_type": "markdown", 111 | "metadata": { 112 | "id": "0ljocCNuUAde" 113 | }, 114 | "source": [ 115 | "# Running inference on the dataset using a base Whisper model\n", 116 | "\n", 117 | "The following will take a few minutes to transcribe all utterances in the dataset." 
118 | ] 119 | }, 120 | { 121 | "cell_type": "code", 122 | "execution_count": 5, 123 | "metadata": { 124 | "colab": { 125 | "base_uri": "https://localhost:8080/" 126 | }, 127 | "id": "_PokfNJtOYNu", 128 | "outputId": "2c53ec44-bc93-4107-b4fa-214e3f71fe8e" 129 | }, 130 | "outputs": [ 131 | { 132 | "name": "stdout", 133 | "output_type": "stream", 134 | "text": [ 135 | "Model is English-only and has 71,825,408 parameters.\n" 136 | ] 137 | } 138 | ], 139 | "source": [ 140 | "model = whisper.load_model(\"base.en\")\n", 141 | "print(\n", 142 | " f\"Model is {'multilingual' if model.is_multilingual else 'English-only'} \"\n", 143 | " f\"and has {sum(np.prod(p.shape) for p in model.parameters()):,} parameters.\"\n", 144 | ")" 145 | ] 146 | }, 147 | { 148 | "cell_type": "code", 149 | "execution_count": 6, 150 | "metadata": {}, 151 | "outputs": [], 152 | "source": [ 153 | "# predict without timestamps for short-form transcription\n", 154 | "options = whisper.DecodingOptions(language=\"en\", without_timestamps=True)" 155 | ] 156 | }, 157 | { 158 | "cell_type": "code", 159 | "execution_count": 7, 160 | "metadata": { 161 | "colab": { 162 | "base_uri": "https://localhost:8080/", 163 | "height": 49, 164 | "referenced_widgets": [ 165 | "09a29a91f58d4462942505a3cc415801", 166 | "83391f98a240490987c397048fc1a0d4", 167 | "06b9aa5f49fa44ba8c93b647dc7db224", 168 | "da9c231ee67047fb89073c95326b72a5", 169 | "48da931ebe7f4fd299f8c98c7d2460ff", 170 | "7a901f447c1d477bb49f954e0feacedd", 171 | "39f5a6ae8ba74c8598f9c6d5b8ad2d65", 172 | "a0d10a42c753453283e5219c22239337", 173 | "09f4cb79ff86465aaf48b0de24869af9", 174 | "1b9cecf5b3584fba8258a81d4279a25b", 175 | "039b53f2702c4179af7e0548018d0588" 176 | ] 177 | }, 178 | "id": "7OWTn_KvNk59", 179 | "outputId": "a813a792-3c91-4144-f11f-054fd6778023" 180 | }, 181 | "outputs": [ 182 | { 183 | "data": { 184 | "application/vnd.jupyter.widget-view+json": { 185 | "model_id": "9df048b46f764cf68cbe0045b8ff73a8", 186 | "version_major": 2, 187 | "version_minor": 0 188 | }, 189 | "text/plain": [ 190 | " 0%| | 0/164 [00:00\n", 223 | "\n", 236 | "\n", 237 | " \n", 238 | " \n", 239 | " \n", 240 | " \n", 241 | " \n", 242 | " \n", 243 | " \n", 244 | " \n", 245 | " \n", 246 | " \n", 247 | " \n", 248 | " \n", 249 | " \n", 250 | " \n", 251 | " \n", 252 | " \n", 253 | " \n", 254 | " \n", 255 | " \n", 256 | " \n", 257 | " \n", 258 | " \n", 259 | " \n", 260 | " \n", 261 | " \n", 262 | " \n", 263 | " \n", 264 | " \n", 265 | " \n", 266 | " \n", 267 | " \n", 268 | " \n", 269 | " \n", 270 | " \n", 271 | " \n", 272 | " \n", 273 | " \n", 274 | " \n", 275 | " \n", 276 | " \n", 277 | " \n", 278 | " \n", 279 | " \n", 280 | " \n", 281 | " \n", 282 | " \n", 283 | " \n", 284 | " \n", 285 | " \n", 286 | " \n", 287 | " \n", 288 | " \n", 289 | " \n", 290 | " \n", 291 | " \n", 292 | " \n", 293 | " \n", 294 | " \n", 295 | " \n", 296 | " \n", 297 | " \n", 298 | " \n", 299 | " \n", 300 | " \n", 301 | "
hypothesisreference
0He hoped there would be stew for dinner, turni...HE HOPED THERE WOULD BE STEW FOR DINNER TURNIP...
1Stuffered into you, his belly counseled him.STUFF IT INTO YOU HIS BELLY COUNSELLED HIM
2After early nightfall the yellow lamps would l...AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD L...
3Hello Bertie, any good in your mind?HELLO BERTIE ANY GOOD IN YOUR MIND
4Number 10. Fresh Nelly is waiting on you. Good...NUMBER TEN FRESH NELLY IS WAITING ON YOU GOOD ...
.........
2615Oh, to shoot my soul's full meaning into futur...OH TO SHOOT MY SOUL'S FULL MEANING INTO FUTURE...
2616Then I, long tried by natural ills, received t...THEN I LONG TRIED BY NATURAL ILLS RECEIVED THE...
2617I love thee freely as men strive for right. I ...I LOVE THEE FREELY AS MEN STRIVE FOR RIGHT I L...
2618I love thee with the passion put to use, in my...I LOVE THEE WITH THE PASSION PUT TO USE IN MY ...
2619I love thee with the love I seemed to lose wit...I LOVE THEE WITH A LOVE I SEEMED TO LOSE WITH ...
\n", 302 | "

2620 rows × 2 columns

\n", 303 | "" 304 | ], 305 | "text/plain": [ 306 | " hypothesis \\\n", 307 | "0 He hoped there would be stew for dinner, turni... \n", 308 | "1 Stuffered into you, his belly counseled him. \n", 309 | "2 After early nightfall the yellow lamps would l... \n", 310 | "3 Hello Bertie, any good in your mind? \n", 311 | "4 Number 10. Fresh Nelly is waiting on you. Good... \n", 312 | "... ... \n", 313 | "2615 Oh, to shoot my soul's full meaning into futur... \n", 314 | "2616 Then I, long tried by natural ills, received t... \n", 315 | "2617 I love thee freely as men strive for right. I ... \n", 316 | "2618 I love thee with the passion put to use, in my... \n", 317 | "2619 I love thee with the love I seemed to lose wit... \n", 318 | "\n", 319 | " reference \n", 320 | "0 HE HOPED THERE WOULD BE STEW FOR DINNER TURNIP... \n", 321 | "1 STUFF IT INTO YOU HIS BELLY COUNSELLED HIM \n", 322 | "2 AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD L... \n", 323 | "3 HELLO BERTIE ANY GOOD IN YOUR MIND \n", 324 | "4 NUMBER TEN FRESH NELLY IS WAITING ON YOU GOOD ... \n", 325 | "... ... \n", 326 | "2615 OH TO SHOOT MY SOUL'S FULL MEANING INTO FUTURE... \n", 327 | "2616 THEN I LONG TRIED BY NATURAL ILLS RECEIVED THE... \n", 328 | "2617 I LOVE THEE FREELY AS MEN STRIVE FOR RIGHT I L... \n", 329 | "2618 I LOVE THEE WITH THE PASSION PUT TO USE IN MY ... \n", 330 | "2619 I LOVE THEE WITH A LOVE I SEEMED TO LOSE WITH ... \n", 331 | "\n", 332 | "[2620 rows x 2 columns]" 333 | ] 334 | }, 335 | "execution_count": 8, 336 | "metadata": {}, 337 | "output_type": "execute_result" 338 | } 339 | ], 340 | "source": [ 341 | "data = pd.DataFrame(dict(hypothesis=hypotheses, reference=references))\n", 342 | "data" 343 | ] 344 | }, 345 | { 346 | "cell_type": "markdown", 347 | "metadata": { 348 | "id": "HPppEJRXX4ox" 349 | }, 350 | "source": [ 351 | "# Calculating the word error rate\n", 352 | "\n", 353 | "Now, we use our English normalizer implementation to standardize the transcription and calculate the WER." 354 | ] 355 | }, 356 | { 357 | "cell_type": "code", 358 | "execution_count": 9, 359 | "metadata": { 360 | "id": "dl-KBDflMhrg" 361 | }, 362 | "outputs": [], 363 | "source": [ 364 | "import jiwer\n", 365 | "from whisper.normalizers import EnglishTextNormalizer\n", 366 | "\n", 367 | "normalizer = EnglishTextNormalizer()" 368 | ] 369 | }, 370 | { 371 | "cell_type": "code", 372 | "execution_count": 10, 373 | "metadata": { 374 | "colab": { 375 | "base_uri": "https://localhost:8080/", 376 | "height": 641 377 | }, 378 | "id": "6-O048q4WI4o", 379 | "outputId": "f2089bc9-f535-441e-f192-26e52ae82b5e" 380 | }, 381 | "outputs": [ 382 | { 383 | "data": { 384 | "text/html": [ 385 | "
\n", 386 | "\n", 399 | "\n", 400 | " \n", 401 | " \n", 402 | " \n", 403 | " \n", 404 | " \n", 405 | " \n", 406 | " \n", 407 | " \n", 408 | " \n", 409 | " \n", 410 | " \n", 411 | " \n", 412 | " \n", 413 | " \n", 414 | " \n", 415 | " \n", 416 | " \n", 417 | " \n", 418 | " \n", 419 | " \n", 420 | " \n", 421 | " \n", 422 | " \n", 423 | " \n", 424 | " \n", 425 | " \n", 426 | " \n", 427 | " \n", 428 | " \n", 429 | " \n", 430 | " \n", 431 | " \n", 432 | " \n", 433 | " \n", 434 | " \n", 435 | " \n", 436 | " \n", 437 | " \n", 438 | " \n", 439 | " \n", 440 | " \n", 441 | " \n", 442 | " \n", 443 | " \n", 444 | " \n", 445 | " \n", 446 | " \n", 447 | " \n", 448 | " \n", 449 | " \n", 450 | " \n", 451 | " \n", 452 | " \n", 453 | " \n", 454 | " \n", 455 | " \n", 456 | " \n", 457 | " \n", 458 | " \n", 459 | " \n", 460 | " \n", 461 | " \n", 462 | " \n", 463 | " \n", 464 | " \n", 465 | " \n", 466 | " \n", 467 | " \n", 468 | " \n", 469 | " \n", 470 | " \n", 471 | " \n", 472 | " \n", 473 | " \n", 474 | " \n", 475 | " \n", 476 | " \n", 477 | " \n", 478 | " \n", 479 | " \n", 480 | " \n", 481 | " \n", 482 | " \n", 483 | " \n", 484 | " \n", 485 | " \n", 486 | " \n", 487 | " \n", 488 | "
hypothesisreferencehypothesis_cleanreference_clean
0He hoped there would be stew for dinner, turni...HE HOPED THERE WOULD BE STEW FOR DINNER TURNIP...he hoped there would be stew for dinner turnip...he hoped there would be stew for dinner turnip...
1Stuffered into you, his belly counseled him.STUFF IT INTO YOU HIS BELLY COUNSELLED HIMstuffered into you his belly counseled himstuff it into you his belly counseled him
2After early nightfall the yellow lamps would l...AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD L...after early nightfall the yellow lamps would l...after early nightfall the yellow lamps would l...
3Hello Bertie, any good in your mind?HELLO BERTIE ANY GOOD IN YOUR MINDhello bertie any good in your mindhello bertie any good in your mind
4Number 10. Fresh Nelly is waiting on you. Good...NUMBER TEN FRESH NELLY IS WAITING ON YOU GOOD ...number 10 fresh nelly is waiting on you good n...number 10 fresh nelly is waiting on you good n...
...............
2615Oh, to shoot my soul's full meaning into futur...OH TO SHOOT MY SOUL'S FULL MEANING INTO FUTURE...0 to shoot my soul is full meaning into future...0 to shoot my soul is full meaning into future...
2616Then I, long tried by natural ills, received t...THEN I LONG TRIED BY NATURAL ILLS RECEIVED THE...then i long tried by natural ills received the...then i long tried by natural ills received the...
2617I love thee freely as men strive for right. I ...I LOVE THEE FREELY AS MEN STRIVE FOR RIGHT I L...i love thee freely as men strive for right i l...i love thee freely as men strive for right i l...
2618I love thee with the passion put to use, in my...I LOVE THEE WITH THE PASSION PUT TO USE IN MY ...i love thee with the passion put to use in my ...i love thee with the passion put to use in my ...
2619I love thee with the love I seemed to lose wit...I LOVE THEE WITH A LOVE I SEEMED TO LOSE WITH ...i love thee with the love i seemed to lose wit...i love thee with a love i seemed to lose with ...
\n", 489 | "

2620 rows × 4 columns

\n", 490 | "
" 491 | ], 492 | "text/plain": [ 493 | " hypothesis \\\n", 494 | "0 He hoped there would be stew for dinner, turni... \n", 495 | "1 Stuffered into you, his belly counseled him. \n", 496 | "2 After early nightfall the yellow lamps would l... \n", 497 | "3 Hello Bertie, any good in your mind? \n", 498 | "4 Number 10. Fresh Nelly is waiting on you. Good... \n", 499 | "... ... \n", 500 | "2615 Oh, to shoot my soul's full meaning into futur... \n", 501 | "2616 Then I, long tried by natural ills, received t... \n", 502 | "2617 I love thee freely as men strive for right. I ... \n", 503 | "2618 I love thee with the passion put to use, in my... \n", 504 | "2619 I love thee with the love I seemed to lose wit... \n", 505 | "\n", 506 | " reference \\\n", 507 | "0 HE HOPED THERE WOULD BE STEW FOR DINNER TURNIP... \n", 508 | "1 STUFF IT INTO YOU HIS BELLY COUNSELLED HIM \n", 509 | "2 AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD L... \n", 510 | "3 HELLO BERTIE ANY GOOD IN YOUR MIND \n", 511 | "4 NUMBER TEN FRESH NELLY IS WAITING ON YOU GOOD ... \n", 512 | "... ... \n", 513 | "2615 OH TO SHOOT MY SOUL'S FULL MEANING INTO FUTURE... \n", 514 | "2616 THEN I LONG TRIED BY NATURAL ILLS RECEIVED THE... \n", 515 | "2617 I LOVE THEE FREELY AS MEN STRIVE FOR RIGHT I L... \n", 516 | "2618 I LOVE THEE WITH THE PASSION PUT TO USE IN MY ... \n", 517 | "2619 I LOVE THEE WITH A LOVE I SEEMED TO LOSE WITH ... \n", 518 | "\n", 519 | " hypothesis_clean \\\n", 520 | "0 he hoped there would be stew for dinner turnip... \n", 521 | "1 stuffered into you his belly counseled him \n", 522 | "2 after early nightfall the yellow lamps would l... \n", 523 | "3 hello bertie any good in your mind \n", 524 | "4 number 10 fresh nelly is waiting on you good n... \n", 525 | "... ... \n", 526 | "2615 0 to shoot my soul is full meaning into future... \n", 527 | "2616 then i long tried by natural ills received the... \n", 528 | "2617 i love thee freely as men strive for right i l... \n", 529 | "2618 i love thee with the passion put to use in my ... \n", 530 | "2619 i love thee with the love i seemed to lose wit... \n", 531 | "\n", 532 | " reference_clean \n", 533 | "0 he hoped there would be stew for dinner turnip... \n", 534 | "1 stuff it into you his belly counseled him \n", 535 | "2 after early nightfall the yellow lamps would l... \n", 536 | "3 hello bertie any good in your mind \n", 537 | "4 number 10 fresh nelly is waiting on you good n... \n", 538 | "... ... \n", 539 | "2615 0 to shoot my soul is full meaning into future... \n", 540 | "2616 then i long tried by natural ills received the... \n", 541 | "2617 i love thee freely as men strive for right i l... \n", 542 | "2618 i love thee with the passion put to use in my ... \n", 543 | "2619 i love thee with a love i seemed to lose with ... 
\n", 544 | "\n", 545 | "[2620 rows x 4 columns]" 546 | ] 547 | }, 548 | "execution_count": 10, 549 | "metadata": {}, 550 | "output_type": "execute_result" 551 | } 552 | ], 553 | "source": [ 554 | "data[\"hypothesis_clean\"] = [normalizer(text) for text in data[\"hypothesis\"]]\n", 555 | "data[\"reference_clean\"] = [normalizer(text) for text in data[\"reference\"]]\n", 556 | "data" 557 | ] 558 | }, 559 | { 560 | "cell_type": "code", 561 | "execution_count": 11, 562 | "metadata": { 563 | "colab": { 564 | "base_uri": "https://localhost:8080/" 565 | }, 566 | "id": "EBGSITeBYPTT", 567 | "outputId": "7b3dbe7c-a37e-4a07-a50a-b27d5f88b68f" 568 | }, 569 | "outputs": [ 570 | { 571 | "name": "stdout", 572 | "output_type": "stream", 573 | "text": [ 574 | "WER: 4.26 %\n" 575 | ] 576 | } 577 | ], 578 | "source": [ 579 | "wer = jiwer.wer(list(data[\"reference_clean\"]), list(data[\"hypothesis_clean\"]))\n", 580 | "\n", 581 | "print(f\"WER: {wer * 100:.2f} %\")" 582 | ] 583 | } 584 | ], 585 | "metadata": { 586 | "accelerator": "GPU", 587 | "colab": { 588 | "collapsed_sections": [], 589 | "provenance": [] 590 | }, 591 | "gpuClass": "standard", 592 | "kernelspec": { 593 | "display_name": "Python 3 (ipykernel)", 594 | "language": "python", 595 | "name": "python3" 596 | }, 597 | "language_info": { 598 | "codemirror_mode": { 599 | "name": "ipython", 600 | "version": 3 601 | }, 602 | "file_extension": ".py", 603 | "mimetype": "text/x-python", 604 | "name": "python", 605 | "nbconvert_exporter": "python", 606 | "pygments_lexer": "ipython3", 607 | "version": "3.9.9" 608 | }, 609 | "widgets": { 610 | "application/vnd.jupyter.widget-state+json": { 611 | "039b53f2702c4179af7e0548018d0588": { 612 | "model_module": "@jupyter-widgets/controls", 613 | "model_module_version": "1.5.0", 614 | "model_name": "DescriptionStyleModel", 615 | "state": { 616 | "_model_module": "@jupyter-widgets/controls", 617 | "_model_module_version": "1.5.0", 618 | "_model_name": "DescriptionStyleModel", 619 | "_view_count": null, 620 | "_view_module": "@jupyter-widgets/base", 621 | "_view_module_version": "1.2.0", 622 | "_view_name": "StyleView", 623 | "description_width": "" 624 | } 625 | }, 626 | "06b9aa5f49fa44ba8c93b647dc7db224": { 627 | "model_module": "@jupyter-widgets/controls", 628 | "model_module_version": "1.5.0", 629 | "model_name": "FloatProgressModel", 630 | "state": { 631 | "_dom_classes": [], 632 | "_model_module": "@jupyter-widgets/controls", 633 | "_model_module_version": "1.5.0", 634 | "_model_name": "FloatProgressModel", 635 | "_view_count": null, 636 | "_view_module": "@jupyter-widgets/controls", 637 | "_view_module_version": "1.5.0", 638 | "_view_name": "ProgressView", 639 | "bar_style": "success", 640 | "description": "", 641 | "description_tooltip": null, 642 | "layout": "IPY_MODEL_a0d10a42c753453283e5219c22239337", 643 | "max": 164, 644 | "min": 0, 645 | "orientation": "horizontal", 646 | "style": "IPY_MODEL_09f4cb79ff86465aaf48b0de24869af9", 647 | "value": 164 648 | } 649 | }, 650 | "09a29a91f58d4462942505a3cc415801": { 651 | "model_module": "@jupyter-widgets/controls", 652 | "model_module_version": "1.5.0", 653 | "model_name": "HBoxModel", 654 | "state": { 655 | "_dom_classes": [], 656 | "_model_module": "@jupyter-widgets/controls", 657 | "_model_module_version": "1.5.0", 658 | "_model_name": "HBoxModel", 659 | "_view_count": null, 660 | "_view_module": "@jupyter-widgets/controls", 661 | "_view_module_version": "1.5.0", 662 | "_view_name": "HBoxView", 663 | "box_style": "", 664 | "children": [ 665 | 
"IPY_MODEL_83391f98a240490987c397048fc1a0d4", 666 | "IPY_MODEL_06b9aa5f49fa44ba8c93b647dc7db224", 667 | "IPY_MODEL_da9c231ee67047fb89073c95326b72a5" 668 | ], 669 | "layout": "IPY_MODEL_48da931ebe7f4fd299f8c98c7d2460ff" 670 | } 671 | }, 672 | "09f4cb79ff86465aaf48b0de24869af9": { 673 | "model_module": "@jupyter-widgets/controls", 674 | "model_module_version": "1.5.0", 675 | "model_name": "ProgressStyleModel", 676 | "state": { 677 | "_model_module": "@jupyter-widgets/controls", 678 | "_model_module_version": "1.5.0", 679 | "_model_name": "ProgressStyleModel", 680 | "_view_count": null, 681 | "_view_module": "@jupyter-widgets/base", 682 | "_view_module_version": "1.2.0", 683 | "_view_name": "StyleView", 684 | "bar_color": null, 685 | "description_width": "" 686 | } 687 | }, 688 | "1b9cecf5b3584fba8258a81d4279a25b": { 689 | "model_module": "@jupyter-widgets/base", 690 | "model_module_version": "1.2.0", 691 | "model_name": "LayoutModel", 692 | "state": { 693 | "_model_module": "@jupyter-widgets/base", 694 | "_model_module_version": "1.2.0", 695 | "_model_name": "LayoutModel", 696 | "_view_count": null, 697 | "_view_module": "@jupyter-widgets/base", 698 | "_view_module_version": "1.2.0", 699 | "_view_name": "LayoutView", 700 | "align_content": null, 701 | "align_items": null, 702 | "align_self": null, 703 | "border": null, 704 | "bottom": null, 705 | "display": null, 706 | "flex": null, 707 | "flex_flow": null, 708 | "grid_area": null, 709 | "grid_auto_columns": null, 710 | "grid_auto_flow": null, 711 | "grid_auto_rows": null, 712 | "grid_column": null, 713 | "grid_gap": null, 714 | "grid_row": null, 715 | "grid_template_areas": null, 716 | "grid_template_columns": null, 717 | "grid_template_rows": null, 718 | "height": null, 719 | "justify_content": null, 720 | "justify_items": null, 721 | "left": null, 722 | "margin": null, 723 | "max_height": null, 724 | "max_width": null, 725 | "min_height": null, 726 | "min_width": null, 727 | "object_fit": null, 728 | "object_position": null, 729 | "order": null, 730 | "overflow": null, 731 | "overflow_x": null, 732 | "overflow_y": null, 733 | "padding": null, 734 | "right": null, 735 | "top": null, 736 | "visibility": null, 737 | "width": null 738 | } 739 | }, 740 | "39f5a6ae8ba74c8598f9c6d5b8ad2d65": { 741 | "model_module": "@jupyter-widgets/controls", 742 | "model_module_version": "1.5.0", 743 | "model_name": "DescriptionStyleModel", 744 | "state": { 745 | "_model_module": "@jupyter-widgets/controls", 746 | "_model_module_version": "1.5.0", 747 | "_model_name": "DescriptionStyleModel", 748 | "_view_count": null, 749 | "_view_module": "@jupyter-widgets/base", 750 | "_view_module_version": "1.2.0", 751 | "_view_name": "StyleView", 752 | "description_width": "" 753 | } 754 | }, 755 | "48da931ebe7f4fd299f8c98c7d2460ff": { 756 | "model_module": "@jupyter-widgets/base", 757 | "model_module_version": "1.2.0", 758 | "model_name": "LayoutModel", 759 | "state": { 760 | "_model_module": "@jupyter-widgets/base", 761 | "_model_module_version": "1.2.0", 762 | "_model_name": "LayoutModel", 763 | "_view_count": null, 764 | "_view_module": "@jupyter-widgets/base", 765 | "_view_module_version": "1.2.0", 766 | "_view_name": "LayoutView", 767 | "align_content": null, 768 | "align_items": null, 769 | "align_self": null, 770 | "border": null, 771 | "bottom": null, 772 | "display": null, 773 | "flex": null, 774 | "flex_flow": null, 775 | "grid_area": null, 776 | "grid_auto_columns": null, 777 | "grid_auto_flow": null, 778 | "grid_auto_rows": null, 779 | "grid_column": null, 
780 | "grid_gap": null, 781 | "grid_row": null, 782 | "grid_template_areas": null, 783 | "grid_template_columns": null, 784 | "grid_template_rows": null, 785 | "height": null, 786 | "justify_content": null, 787 | "justify_items": null, 788 | "left": null, 789 | "margin": null, 790 | "max_height": null, 791 | "max_width": null, 792 | "min_height": null, 793 | "min_width": null, 794 | "object_fit": null, 795 | "object_position": null, 796 | "order": null, 797 | "overflow": null, 798 | "overflow_x": null, 799 | "overflow_y": null, 800 | "padding": null, 801 | "right": null, 802 | "top": null, 803 | "visibility": null, 804 | "width": null 805 | } 806 | }, 807 | "7a901f447c1d477bb49f954e0feacedd": { 808 | "model_module": "@jupyter-widgets/base", 809 | "model_module_version": "1.2.0", 810 | "model_name": "LayoutModel", 811 | "state": { 812 | "_model_module": "@jupyter-widgets/base", 813 | "_model_module_version": "1.2.0", 814 | "_model_name": "LayoutModel", 815 | "_view_count": null, 816 | "_view_module": "@jupyter-widgets/base", 817 | "_view_module_version": "1.2.0", 818 | "_view_name": "LayoutView", 819 | "align_content": null, 820 | "align_items": null, 821 | "align_self": null, 822 | "border": null, 823 | "bottom": null, 824 | "display": null, 825 | "flex": null, 826 | "flex_flow": null, 827 | "grid_area": null, 828 | "grid_auto_columns": null, 829 | "grid_auto_flow": null, 830 | "grid_auto_rows": null, 831 | "grid_column": null, 832 | "grid_gap": null, 833 | "grid_row": null, 834 | "grid_template_areas": null, 835 | "grid_template_columns": null, 836 | "grid_template_rows": null, 837 | "height": null, 838 | "justify_content": null, 839 | "justify_items": null, 840 | "left": null, 841 | "margin": null, 842 | "max_height": null, 843 | "max_width": null, 844 | "min_height": null, 845 | "min_width": null, 846 | "object_fit": null, 847 | "object_position": null, 848 | "order": null, 849 | "overflow": null, 850 | "overflow_x": null, 851 | "overflow_y": null, 852 | "padding": null, 853 | "right": null, 854 | "top": null, 855 | "visibility": null, 856 | "width": null 857 | } 858 | }, 859 | "83391f98a240490987c397048fc1a0d4": { 860 | "model_module": "@jupyter-widgets/controls", 861 | "model_module_version": "1.5.0", 862 | "model_name": "HTMLModel", 863 | "state": { 864 | "_dom_classes": [], 865 | "_model_module": "@jupyter-widgets/controls", 866 | "_model_module_version": "1.5.0", 867 | "_model_name": "HTMLModel", 868 | "_view_count": null, 869 | "_view_module": "@jupyter-widgets/controls", 870 | "_view_module_version": "1.5.0", 871 | "_view_name": "HTMLView", 872 | "description": "", 873 | "description_tooltip": null, 874 | "layout": "IPY_MODEL_7a901f447c1d477bb49f954e0feacedd", 875 | "placeholder": "​", 876 | "style": "IPY_MODEL_39f5a6ae8ba74c8598f9c6d5b8ad2d65", 877 | "value": "100%" 878 | } 879 | }, 880 | "a0d10a42c753453283e5219c22239337": { 881 | "model_module": "@jupyter-widgets/base", 882 | "model_module_version": "1.2.0", 883 | "model_name": "LayoutModel", 884 | "state": { 885 | "_model_module": "@jupyter-widgets/base", 886 | "_model_module_version": "1.2.0", 887 | "_model_name": "LayoutModel", 888 | "_view_count": null, 889 | "_view_module": "@jupyter-widgets/base", 890 | "_view_module_version": "1.2.0", 891 | "_view_name": "LayoutView", 892 | "align_content": null, 893 | "align_items": null, 894 | "align_self": null, 895 | "border": null, 896 | "bottom": null, 897 | "display": null, 898 | "flex": null, 899 | "flex_flow": null, 900 | "grid_area": null, 901 | "grid_auto_columns": null, 902 
| "grid_auto_flow": null, 903 | "grid_auto_rows": null, 904 | "grid_column": null, 905 | "grid_gap": null, 906 | "grid_row": null, 907 | "grid_template_areas": null, 908 | "grid_template_columns": null, 909 | "grid_template_rows": null, 910 | "height": null, 911 | "justify_content": null, 912 | "justify_items": null, 913 | "left": null, 914 | "margin": null, 915 | "max_height": null, 916 | "max_width": null, 917 | "min_height": null, 918 | "min_width": null, 919 | "object_fit": null, 920 | "object_position": null, 921 | "order": null, 922 | "overflow": null, 923 | "overflow_x": null, 924 | "overflow_y": null, 925 | "padding": null, 926 | "right": null, 927 | "top": null, 928 | "visibility": null, 929 | "width": null 930 | } 931 | }, 932 | "da9c231ee67047fb89073c95326b72a5": { 933 | "model_module": "@jupyter-widgets/controls", 934 | "model_module_version": "1.5.0", 935 | "model_name": "HTMLModel", 936 | "state": { 937 | "_dom_classes": [], 938 | "_model_module": "@jupyter-widgets/controls", 939 | "_model_module_version": "1.5.0", 940 | "_model_name": "HTMLModel", 941 | "_view_count": null, 942 | "_view_module": "@jupyter-widgets/controls", 943 | "_view_module_version": "1.5.0", 944 | "_view_name": "HTMLView", 945 | "description": "", 946 | "description_tooltip": null, 947 | "layout": "IPY_MODEL_1b9cecf5b3584fba8258a81d4279a25b", 948 | "placeholder": "​", 949 | "style": "IPY_MODEL_039b53f2702c4179af7e0548018d0588", 950 | "value": " 164/164 [05:08<00:00, 1.86s/it]" 951 | } 952 | } 953 | } 954 | } 955 | }, 956 | "nbformat": 4, 957 | "nbformat_minor": 1 958 | } 959 | -------------------------------------------------------------------------------- /notebooks/efficient_whisper.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "327184b6-d10f-4d26-ada1-f9d3c6f1ccba", 6 | "metadata": {}, 7 | "source": [ 8 | "# Create test data" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": null, 14 | "id": "3be3b26c-c4f6-489f-bc32-f5c4f9fa9664", 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "import librosa\n", 19 | "from datasets import load_dataset\n", 20 | "\n", 21 | "\n", 22 | "common_voice = load_dataset(\"common_voice\", \"ja\")\n", 23 | "audio_data_list = [\n", 24 | " librosa.resample(\n", 25 | " common_voice['train'][i]['audio']['array'], orig_sr=48000, target_sr=16000\n", 26 | " ) for i in range(10)]" 27 | ] 28 | }, 29 | { 30 | "cell_type": "markdown", 31 | "id": "e8c00c8b-934f-4c75-92cb-f68d0d5ec638", 32 | "metadata": {}, 33 | "source": [ 34 | "# Official Whisper" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": 2, 40 | "id": "2ef5d03e-d1c2-4125-9c63-07138556d610", 41 | "metadata": {}, 42 | "outputs": [], 43 | "source": [ 44 | "import whisper\n", 45 | "\n", 46 | "model = whisper.load_model(\"large\")" 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": 4, 52 | "id": "b3b84c56-f46e-4e25-82e7-41a0f2b9dd69", 53 | "metadata": { 54 | "scrolled": true, 55 | "tags": [] 56 | }, 57 | "outputs": [ 58 | { 59 | "name": "stdout", 60 | "output_type": "stream", 61 | "text": [ 62 | "[00:00.000 --> 00:30.000] 予想外の事態に電力会社がちょっぴり困惑切りだ\n", 63 | "[00:00.000 --> 00:30.000] 町域にあった峰山藩は長岡藩に米100票を送ったことで有名。\n", 64 | "[00:00.000 --> 00:30.000] 週末 友達と山に登ります\n", 65 | "[00:00.000 --> 00:30.000] 後で図書館へ本を返しに行きます。\n", 66 | "[00:00.000 --> 00:30.000] 55歳だって嬉しい時が嬉しいのだ\n", 67 | "[00:00.000 --> 00:30.000] 私はパンもご飯も好きです。\n", 68 | "[00:00.000 --> 00:30.000] 
デパートやスーパーで買い物をします\n", 69 | "[00:00.000 --> 00:30.000] 用紙に書いてある番号を覚えます。\n", 70 | "[00:00.000 --> 00:30.000] 明日 友達と 映画を 見に行きます。\n", 71 | "[00:00.000 --> 00:30.000] あの男の人は背が高くて足が長いです。\n", 72 | "CPU times: user 26.9 s, sys: 216 ms, total: 27.1 s\n", 73 | "Wall time: 13.6 s\n" 74 | ] 75 | } 76 | ], 77 | "source": [ 78 | "%%time\n", 79 | "\n", 80 | "for audio_data in audio_data_list:\n", 81 | " result = model.transcribe(\n", 82 | " audio_data,\n", 83 | " verbose=True,\n", 84 | " language='japanese',\n", 85 | " beam_size=5,\n", 86 | " fp16=True,\n", 87 | " without_timestamps=True\n", 88 | " )" 89 | ] 90 | }, 91 | { 92 | "cell_type": "markdown", 93 | "id": "ff0dbac2-c456-43cf-a58c-2d130c123f5e", 94 | "metadata": {}, 95 | "source": [ 96 | "# Official Whisper with model.half()" 97 | ] 98 | }, 99 | { 100 | "cell_type": "markdown", 101 | "id": "04045816-ca87-444d-84e6-19bf3ef7bbab", 102 | "metadata": {}, 103 | "source": [ 104 | "### We will get a little faster and a large memory improvement (12G -> 6G)" 105 | ] 106 | }, 107 | { 108 | "cell_type": "code", 109 | "execution_count": 2, 110 | "id": "6e902c8c-255f-46ce-9524-c1939812edb7", 111 | "metadata": {}, 112 | "outputs": [], 113 | "source": [ 114 | "import whisper\n", 115 | "\n", 116 | "\n", 117 | "model = whisper.load_model(\"large\", device=\"cpu\")\n", 118 | "_ = model.half()\n", 119 | "_ = model.cuda()\n", 120 | "\n", 121 | "# exception without following code\n", 122 | "# reason : model.py -> line 31 -> super().forward(x.float()).type(x.dtype)\n", 123 | "for m in model.modules():\n", 124 | " if isinstance(m, whisper.model.LayerNorm):\n", 125 | " m.float()" 126 | ] 127 | }, 128 | { 129 | "cell_type": "code", 130 | "execution_count": 4, 131 | "id": "2516d80b-bd44-40b7-81db-9c00e94d570d", 132 | "metadata": {}, 133 | "outputs": [ 134 | { 135 | "name": "stdout", 136 | "output_type": "stream", 137 | "text": [ 138 | "[00:00.000 --> 00:30.000] 予想外の事態に電力会社がちょっぴり困惑切りだ\n", 139 | "[00:00.000 --> 00:30.000] 町域にあった峰山藩は長岡藩に米100票を送ったことで有名。\n", 140 | "[00:00.000 --> 00:30.000] 週末 友達と山に登ります\n", 141 | "[00:00.000 --> 00:30.000] 後で図書館へ本を返しに行きます。\n", 142 | "[00:00.000 --> 00:30.000] 55歳だって嬉しい時が嬉しいのだ\n", 143 | "[00:00.000 --> 00:30.000] 私はパンもご飯も好きです。\n", 144 | "[00:00.000 --> 00:30.000] デパートやスーパーで買い物をします\n", 145 | "[00:00.000 --> 00:30.000] 用紙に書いてある番号を覚えます。\n", 146 | "[00:00.000 --> 00:30.000] 明日 友達と 映画を 見に行きます。\n", 147 | "[00:00.000 --> 00:30.000] あの男の人は背が高くて足が長いです。\n", 148 | "CPU times: user 25.1 s, sys: 130 ms, total: 25.3 s\n", 149 | "Wall time: 12.5 s\n" 150 | ] 151 | } 152 | ], 153 | "source": [ 154 | "%%time\n", 155 | "\n", 156 | "for audio_data in audio_data_list:\n", 157 | " result = model.transcribe(\n", 158 | " audio_data,\n", 159 | " verbose=True,\n", 160 | " language='japanese',\n", 161 | " beam_size=5,\n", 162 | " fp16=True,\n", 163 | " without_timestamps=True\n", 164 | " )" 165 | ] 166 | }, 167 | { 168 | "cell_type": "markdown", 169 | "id": "8148420d-8ee7-4762-9558-9364259b08b7", 170 | "metadata": {}, 171 | "source": [ 172 | "# Whisper with TorchScript" 173 | ] 174 | }, 175 | { 176 | "cell_type": "code", 177 | "execution_count": 2, 178 | "id": "e7991148-03fd-42a0-ad89-50cf6bfbb163", 179 | "metadata": {}, 180 | "outputs": [], 181 | "source": [ 182 | "import torch\n", 183 | "import efficient_whisper as whisper" 184 | ] 185 | }, 186 | { 187 | "cell_type": "code", 188 | "execution_count": 3, 189 | "id": "d1d14332-53b8-442e-a14c-f51b2f48032b", 190 | "metadata": {}, 191 | "outputs": [], 192 | "source": [ 193 | "model = whisper.load_model(\"large\", 
device=\"cpu\")\n", 194 | "model.encoder = torch.jit.script(model.encoder)\n", 195 | "model.decoder = torch.jit.script(model.decoder)\n", 196 | "_ = model.half()\n", 197 | "_ = model.cuda()" 198 | ] 199 | }, 200 | { 201 | "cell_type": "code", 202 | "execution_count": 5, 203 | "id": "68fcc2f7-ffe7-47d9-9f3b-cadd0e7b50f0", 204 | "metadata": {}, 205 | "outputs": [ 206 | { 207 | "name": "stdout", 208 | "output_type": "stream", 209 | "text": [ 210 | "[00:00.000 --> 00:30.000] 予想外の事態に電力会社がちょっぴり困惑切りだ\n", 211 | "[00:00.000 --> 00:30.000] 町域にあった峰山藩は長岡藩に米100票を送ったことで有名。\n", 212 | "[00:00.000 --> 00:30.000] 週末 友達と山に登ります\n", 213 | "[00:00.000 --> 00:30.000] 後で図書館へ本を返しに行きます。\n", 214 | "[00:00.000 --> 00:30.000] 55歳だって嬉しい時が嬉しいのだ\n", 215 | "[00:00.000 --> 00:30.000] 私はパンもご飯も好きです。\n", 216 | "[00:00.000 --> 00:30.000] デパートやスーパーで買い物をします\n", 217 | "[00:00.000 --> 00:30.000] 用紙に書いてある番号を覚えます。\n", 218 | "[00:00.000 --> 00:30.000] 明日 友達と 映画を 見に行きます。\n", 219 | "[00:00.000 --> 00:30.000] あの男の人は背が高くて足が長いです。\n", 220 | "CPU times: user 20.8 s, sys: 425 ms, total: 21.2 s\n", 221 | "Wall time: 8.59 s\n" 222 | ] 223 | } 224 | ], 225 | "source": [ 226 | "%%time\n", 227 | "\n", 228 | "for audio_data in audio_data_list:\n", 229 | " result = model.transcribe(\n", 230 | " audio_data,\n", 231 | " verbose=True,\n", 232 | " language='japanese',\n", 233 | " beam_size=5,\n", 234 | " fp16=True,\n", 235 | " without_timestamps=True\n", 236 | " )" 237 | ] 238 | }, 239 | { 240 | "cell_type": "markdown", 241 | "id": "517e19a2-834c-4f63-bfa4-67e4b00d9a48", 242 | "metadata": {}, 243 | "source": [ 244 | "# Whisper with TorchScript & pad_or_trim (30s -> 10s)" 245 | ] 246 | }, 247 | { 248 | "cell_type": "markdown", 249 | "id": "5e217caf-bbd5-47a0-b9fe-02c182b28cb3", 250 | "metadata": {}, 251 | "source": [ 252 | "Fix `CHUNK_LENGTH` in audio.py: `30 -> 10`" 253 | ] 254 | }, 255 | { 256 | "cell_type": "code", 257 | "execution_count": 2, 258 | "id": "86a66958-7e68-4d61-ae7f-7defb7cb3e7e", 259 | "metadata": {}, 260 | "outputs": [], 261 | "source": [ 262 | "import torch\n", 263 | "import efficient_whisper as whisper" 264 | ] 265 | }, 266 | { 267 | "cell_type": "code", 268 | "execution_count": 3, 269 | "id": "f0829855-2fa1-45fd-a6f0-a86cb3053bcc", 270 | "metadata": {}, 271 | "outputs": [], 272 | "source": [ 273 | "checkpoint = torch.load('/home/ubuntu/.cache/whisper/large.pt', map_location='cpu')\n", 274 | "dims = whisper.model.ModelDimensions(**checkpoint[\"dims\"])\n", 275 | "dims.n_audio_ctx = 500 # 10s\n", 276 | "\n", 277 | "model = whisper.model.Whisper(dims)\n", 278 | "for k, p in model.state_dict().items():\n", 279 | " p.copy_(checkpoint[\"model_state_dict\"][k])\n", 280 | "\n", 281 | "model.encoder = torch.jit.script(model.encoder)\n", 282 | "model.decoder = torch.jit.script(model.decoder)\n", 283 | "_ = model.half()\n", 284 | "_ = model.cuda()" 285 | ] 286 | }, 287 | { 288 | "cell_type": "code", 289 | "execution_count": 5, 290 | "id": "8766750c-01fe-4515-8ffa-ae408ecd15b5", 291 | "metadata": {}, 292 | "outputs": [ 293 | { 294 | "name": "stdout", 295 | "output_type": "stream", 296 | "text": [ 297 | "[00:00.000 --> 00:10.000] 予想外の事態に電力会社がちょっぴり手を巻く気味が\n", 298 | "[00:00.000 --> 00:10.000] 町域にあった峰山藩は、長岡藩に米100票を送ったことで有名。\n", 299 | "[00:00.000 --> 00:10.000] 週末友達と山に登ります。\n", 300 | "[00:00.000 --> 00:10.000] 後で図書館へ本を返しに行きます。\n", 301 | "[00:00.000 --> 00:10.000] 55歳だって嬉しい時が嬉しいのだ\n", 302 | "[00:00.000 --> 00:10.000] 私はパンもご飯も好きです。\n", 303 | "[00:00.000 --> 00:10.000] デパートやスーパーで買い物をします。\n", 304 | "[00:00.000 --> 00:10.000] 用紙に書いてある番号を覚えます。\n", 305 | 
"[00:00.000 --> 00:10.000] 明日、友達と映画を見に行きます。\n", 306 | "[00:00.000 --> 00:10.000] あの男の人は背が高くて足が長いです。\n", 307 | "CPU times: user 20.3 s, sys: 3.61 ms, total: 20.3 s\n", 308 | "Wall time: 6.77 s\n" 309 | ] 310 | } 311 | ], 312 | "source": [ 313 | "%%time\n", 314 | "\n", 315 | "for audio_data in audio_data_list:\n", 316 | " result = model.transcribe(\n", 317 | " audio_data,\n", 318 | " verbose=True,\n", 319 | " language='japanese',\n", 320 | " beam_size=5,\n", 321 | " fp16=True,\n", 322 | " without_timestamps=True\n", 323 | " )" 324 | ] 325 | }, 326 | { 327 | "cell_type": "code", 328 | "execution_count": null, 329 | "id": "b83a4a26-5d28-4403-88c7-0415d9775ca0", 330 | "metadata": {}, 331 | "outputs": [], 332 | "source": [] 333 | } 334 | ], 335 | "metadata": { 336 | "kernelspec": { 337 | "display_name": "Python 3 (ipykernel)", 338 | "language": "python", 339 | "name": "python3" 340 | }, 341 | "language_info": { 342 | "codemirror_mode": { 343 | "name": "ipython", 344 | "version": 3 345 | }, 346 | "file_extension": ".py", 347 | "mimetype": "text/x-python", 348 | "name": "python", 349 | "nbconvert_exporter": "python", 350 | "pygments_lexer": "ipython3", 351 | "version": "3.8.6" 352 | } 353 | }, 354 | "nbformat": 4, 355 | "nbformat_minor": 5 356 | } 357 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | torch 3 | tqdm 4 | more-itertools 5 | transformers>=4.19.0 6 | ffmpeg-python==0.2.0 7 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pkg_resources 4 | from setuptools import setup, find_packages 5 | 6 | setup( 7 | name="whisper", 8 | py_modules=["whisper"], 9 | version="1.0", 10 | description="Robust Speech Recognition via Large-Scale Weak Supervision", 11 | readme="README.md", 12 | python_requires=">=3.7", 13 | author="OpenAI", 14 | url="https://github.com/openai/whisper", 15 | license="MIT", 16 | packages=find_packages(exclude=["tests*"]), 17 | install_requires=[ 18 | str(r) 19 | for r in pkg_resources.parse_requirements( 20 | open(os.path.join(os.path.dirname(__file__), "requirements.txt")) 21 | ) 22 | ], 23 | entry_points = { 24 | 'console_scripts': ['whisper=whisper.transcribe:cli'], 25 | }, 26 | include_package_data=True, 27 | extras_require={'dev': ['pytest']}, 28 | ) 29 | -------------------------------------------------------------------------------- /tests/jfk.flac: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/projectlucas/efficient_whisper/ed0c6cefbab7cad208ca33c91facd5674e1101a7/tests/jfk.flac -------------------------------------------------------------------------------- /tests/test_audio.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | 3 | import numpy as np 4 | 5 | from whisper.audio import load_audio, log_mel_spectrogram, SAMPLE_RATE 6 | 7 | 8 | def test_audio(): 9 | audio_path = os.path.join(os.path.dirname(__file__), "jfk.flac") 10 | audio = load_audio(audio_path) 11 | assert audio.ndim == 1 12 | assert SAMPLE_RATE * 10 < audio.shape[0] < SAMPLE_RATE * 12 13 | assert 0 < audio.std() < 1 14 | 15 | mel_from_audio = log_mel_spectrogram(audio) 16 | mel_from_file = log_mel_spectrogram(audio_path) 17 | 18 | assert np.allclose(mel_from_audio, mel_from_file) 19 | 
assert mel_from_audio.max() - mel_from_audio.min() <= 2.0 20 | -------------------------------------------------------------------------------- /tests/test_normalizer.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from whisper.normalizers import EnglishTextNormalizer 4 | from whisper.normalizers.english import EnglishNumberNormalizer, EnglishSpellingNormalizer 5 | 6 | 7 | @pytest.mark.parametrize("std", [EnglishNumberNormalizer(), EnglishTextNormalizer()]) 8 | def test_number_normalizer(std): 9 | assert std("two") == "2" 10 | assert std("thirty one") == "31" 11 | assert std("five twenty four") == "524" 12 | assert std("nineteen ninety nine") == "1999" 13 | assert std("twenty nineteen") == "2019" 14 | 15 | assert std("two point five million") == "2500000" 16 | assert std("four point two billions") == "4200000000s" 17 | assert std("200 thousand") == "200000" 18 | assert std("200 thousand dollars") == "$200000" 19 | assert std("$20 million") == "$20000000" 20 | assert std("€52.4 million") == "€52400000" 21 | assert std("£77 thousands") == "£77000s" 22 | 23 | assert std("two double o eight") == "2008" 24 | 25 | assert std("three thousand twenty nine") == "3029" 26 | assert std("forty three thousand two hundred sixty") == "43260" 27 | assert std("forty three thousand two hundred and sixty") == "43260" 28 | 29 | assert std("nineteen fifties") == "1950s" 30 | assert std("thirty first") == "31st" 31 | assert std("thirty three thousand and three hundred and thirty third") == "33333rd" 32 | 33 | assert std("three billion") == "3000000000" 34 | assert std("millions") == "1000000s" 35 | 36 | assert std("july third twenty twenty") == "july 3rd 2020" 37 | assert std("august twenty sixth twenty twenty one") == "august 26th 2021" 38 | assert std("3 14") == "3 14" 39 | assert std("3.14") == "3.14" 40 | assert std("3 point 2") == "3.2" 41 | assert std("3 point 14") == "3.14" 42 | assert std("fourteen point 4") == "14.4" 43 | assert std("two point two five dollars") == "$2.25" 44 | assert std("two hundred million dollars") == "$200000000" 45 | assert std("$20.1 million") == "$20100000" 46 | 47 | assert std("ninety percent") == "90%" 48 | assert std("seventy six per cent") == "76%" 49 | 50 | assert std("double oh seven") == "007" 51 | assert std("double zero seven") == "007" 52 | assert std("nine one one") == "911" 53 | assert std("nine double one") == "911" 54 | assert std("one triple oh one") == "10001" 55 | 56 | assert std("two thousandth") == "2000th" 57 | assert std("thirty two thousandth") == "32000th" 58 | 59 | assert std("minus 500") == "-500" 60 | assert std("positive twenty thousand") == "+20000" 61 | 62 | assert std("two dollars and seventy cents") == "$2.70" 63 | assert std("3 cents") == "¢3" 64 | assert std("$0.36") == "¢36" 65 | assert std("three euros and sixty five cents") == "€3.65" 66 | 67 | assert std("three and a half million") == "3500000" 68 | assert std("forty eight and a half dollars") == "$48.5" 69 | assert std("b747") == "b 747" 70 | assert std("10 th") == "10th" 71 | assert std("10th") == "10th" 72 | 73 | 74 | def test_spelling_normalizer(): 75 | std = EnglishSpellingNormalizer() 76 | 77 | assert std("mobilisation") == "mobilization" 78 | assert std("cancelation") == "cancellation" 79 | 80 | 81 | def test_text_normalizer(): 82 | std = EnglishTextNormalizer() 83 | assert std("Let's") == "let us" 84 | assert std("he's like") == "he is like" 85 | assert std("she's been like") == "she has been like" 86 | assert std("10km") 
== "10 km" 87 | assert std("RC232") == "rc 232" 88 | 89 | assert ( 90 | std("Mr. Park visited Assoc. Prof. Kim Jr.") 91 | == "mister park visited associate professor kim junior" 92 | ) 93 | -------------------------------------------------------------------------------- /tests/test_tokenizer.py: -------------------------------------------------------------------------------- 1 | from whisper.tokenizer import get_tokenizer 2 | 3 | 4 | def test_tokenizer(): 5 | gpt2_tokenizer = get_tokenizer(multilingual=False) 6 | multilingual_tokenizer = get_tokenizer(multilingual=True) 7 | 8 | text = "다람쥐 헌 쳇바퀴에 타고파" 9 | gpt2_tokens = gpt2_tokenizer.encode(text) 10 | multilingual_tokens = multilingual_tokenizer.encode(text) 11 | 12 | assert gpt2_tokenizer.decode(gpt2_tokens) == text 13 | assert multilingual_tokenizer.decode(multilingual_tokens) == text 14 | assert len(gpt2_tokens) > len(multilingual_tokens) 15 | -------------------------------------------------------------------------------- /tests/test_transcribe.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pytest 4 | 5 | import whisper 6 | 7 | 8 | @pytest.mark.parametrize('model_name', whisper.available_models()) 9 | def test_transcribe(model_name: str): 10 | model = whisper.load_model(model_name).cuda() 11 | audio_path = os.path.join(os.path.dirname(__file__), "jfk.flac") 12 | 13 | language = "en" if model_name.endswith(".en") else None 14 | result = model.transcribe(audio_path, language=language, temperature=0.0) 15 | assert result["language"] == "en" 16 | 17 | transcription = result["text"].lower() 18 | assert "my fellow americans" in transcription 19 | assert "your country" in transcription 20 | assert "do for you" in transcription 21 | -------------------------------------------------------------------------------- /whisper/__init__.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import io 3 | import os 4 | import urllib 5 | import warnings 6 | from typing import List, Optional, Union 7 | 8 | import torch 9 | from tqdm import tqdm 10 | 11 | from .audio import load_audio, log_mel_spectrogram, pad_or_trim 12 | from .decoding import DecodingOptions, DecodingResult, decode, detect_language 13 | from .model import Whisper, ModelDimensions 14 | from .transcribe import transcribe 15 | 16 | 17 | _MODELS = { 18 | "tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt", 19 | "tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt", 20 | "base.en": "https://openaipublic.azureedge.net/main/whisper/models/25a8566e1d0c1e2231d1c762132cd20e0f96a85d16145c3a00adf5d1ac670ead/base.en.pt", 21 | "base": "https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt", 22 | "small.en": "https://openaipublic.azureedge.net/main/whisper/models/f953ad0fd29cacd07d5a9eda5624af0f6bcf2258be67c92b79389873d91e0872/small.en.pt", 23 | "small": "https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt", 24 | "medium.en": "https://openaipublic.azureedge.net/main/whisper/models/d7440d1dc186f76616474e0ff0b3b6b879abc9d1a4926b7adfa41db2d497ab4f/medium.en.pt", 25 | "medium": 
"https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt", 26 | "large": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large.pt", 27 | } 28 | 29 | 30 | def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]: 31 | os.makedirs(root, exist_ok=True) 32 | 33 | expected_sha256 = url.split("/")[-2] 34 | download_target = os.path.join(root, os.path.basename(url)) 35 | 36 | if os.path.exists(download_target) and not os.path.isfile(download_target): 37 | raise RuntimeError(f"{download_target} exists and is not a regular file") 38 | 39 | if os.path.isfile(download_target): 40 | model_bytes = open(download_target, "rb").read() 41 | if hashlib.sha256(model_bytes).hexdigest() == expected_sha256: 42 | return model_bytes if in_memory else download_target 43 | else: 44 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") 45 | 46 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 47 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop: 48 | while True: 49 | buffer = source.read(8192) 50 | if not buffer: 51 | break 52 | 53 | output.write(buffer) 54 | loop.update(len(buffer)) 55 | 56 | model_bytes = open(download_target, "rb").read() 57 | if hashlib.sha256(model_bytes).hexdigest() != expected_sha256: 58 | raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model.") 59 | 60 | return model_bytes if in_memory else download_target 61 | 62 | 63 | def available_models() -> List[str]: 64 | """Returns the names of available models""" 65 | return list(_MODELS.keys()) 66 | 67 | 68 | def load_model(name: str, device: Optional[Union[str, torch.device]] = None, download_root: str = None, in_memory: bool = False) -> Whisper: 69 | """ 70 | Load a Whisper ASR model 71 | 72 | Parameters 73 | ---------- 74 | name : str 75 | one of the official model names listed by `whisper.available_models()`, or 76 | path to a model checkpoint containing the model dimensions and the model state_dict. 
77 | device : Union[str, torch.device] 78 | the PyTorch device to put the model into 79 | download_root: str 80 | path to download the model files; by default, it uses "~/.cache/whisper" 81 | in_memory: bool 82 | whether to preload the model weights into host memory 83 | 84 | Returns 85 | ------- 86 | model : Whisper 87 | The Whisper ASR model instance 88 | """ 89 | 90 | if device is None: 91 | device = "cuda" if torch.cuda.is_available() else "cpu" 92 | if download_root is None: 93 | download_root = os.getenv( 94 | "XDG_CACHE_HOME", 95 | os.path.join(os.path.expanduser("~"), ".cache", "whisper") 96 | ) 97 | 98 | if name in _MODELS: 99 | checkpoint_file = _download(_MODELS[name], download_root, in_memory) 100 | elif os.path.isfile(name): 101 | checkpoint_file = open(name, "rb").read() if in_memory else name 102 | else: 103 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}") 104 | 105 | with (io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")) as fp: 106 | checkpoint = torch.load(fp, map_location=device) 107 | del checkpoint_file 108 | 109 | dims = ModelDimensions(**checkpoint["dims"]) 110 | model = Whisper(dims) 111 | model.load_state_dict(checkpoint["model_state_dict"]) 112 | 113 | return model.to(device) 114 | -------------------------------------------------------------------------------- /whisper/__main__.py: -------------------------------------------------------------------------------- 1 | from .transcribe import cli 2 | 3 | 4 | cli() 5 | -------------------------------------------------------------------------------- /whisper/assets/gpt2/special_tokens_map.json: -------------------------------------------------------------------------------- 1 | {"bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "unk_token": "<|endoftext|>"} -------------------------------------------------------------------------------- /whisper/assets/gpt2/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | {"unk_token": "<|endoftext|>", "bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "add_prefix_space": false, "model_max_length": 1024, "special_tokens_map_file": null, "name_or_path": "gpt2", "tokenizer_class": "GPT2Tokenizer"} -------------------------------------------------------------------------------- /whisper/assets/mel_filters.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/projectlucas/efficient_whisper/ed0c6cefbab7cad208ca33c91facd5674e1101a7/whisper/assets/mel_filters.npz -------------------------------------------------------------------------------- /whisper/assets/multilingual/added_tokens.json: -------------------------------------------------------------------------------- 1 | {"<|endoftext|>": 50257} 2 | -------------------------------------------------------------------------------- /whisper/assets/multilingual/special_tokens_map.json: -------------------------------------------------------------------------------- 1 | {"bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "unk_token": "<|endoftext|>"} -------------------------------------------------------------------------------- /whisper/assets/multilingual/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | {"unk_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "bos_token": {"content": 
"<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "eos_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "add_prefix_space": false, "model_max_length": 1024, "special_tokens_map_file": null, "name_or_path": "multilingual", "errors": "replace", "tokenizer_class": "GPT2Tokenizer"} -------------------------------------------------------------------------------- /whisper/audio.py: -------------------------------------------------------------------------------- 1 | import os 2 | from functools import lru_cache 3 | from typing import Union 4 | 5 | import ffmpeg 6 | import numpy as np 7 | import torch 8 | import torch.nn.functional as F 9 | 10 | from .utils import exact_div 11 | 12 | # hard-coded audio hyperparameters 13 | SAMPLE_RATE = 16000 14 | N_FFT = 400 15 | N_MELS = 80 16 | HOP_LENGTH = 160 17 | CHUNK_LENGTH = 30 18 | N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000: number of samples in a chunk 19 | N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH) # 3000: number of frames in a mel spectrogram input 20 | 21 | 22 | def load_audio(file: str, sr: int = SAMPLE_RATE): 23 | """ 24 | Open an audio file and read as mono waveform, resampling as necessary 25 | 26 | Parameters 27 | ---------- 28 | file: str 29 | The audio file to open 30 | 31 | sr: int 32 | The sample rate to resample the audio if necessary 33 | 34 | Returns 35 | ------- 36 | A NumPy array containing the audio waveform, in float32 dtype. 37 | """ 38 | try: 39 | # This launches a subprocess to decode audio while down-mixing and resampling as necessary. 40 | # Requires the ffmpeg CLI and `ffmpeg-python` package to be installed. 41 | out, _ = ( 42 | ffmpeg.input(file, threads=0) 43 | .output("-", format="s16le", acodec="pcm_s16le", ac=1, ar=sr) 44 | .run(cmd=["ffmpeg", "-nostdin"], capture_stdout=True, capture_stderr=True) 45 | ) 46 | except ffmpeg.Error as e: 47 | raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e 48 | 49 | return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0 50 | 51 | 52 | def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1): 53 | """ 54 | Pad or trim the audio array to N_SAMPLES, as expected by the encoder. 55 | """ 56 | if torch.is_tensor(array): 57 | if array.shape[axis] > length: 58 | array = array.index_select(dim=axis, index=torch.arange(length, device=array.device)) 59 | 60 | if array.shape[axis] < length: 61 | pad_widths = [(0, 0)] * array.ndim 62 | pad_widths[axis] = (0, length - array.shape[axis]) 63 | array = F.pad(array, [pad for sizes in pad_widths[::-1] for pad in sizes]) 64 | else: 65 | if array.shape[axis] > length: 66 | array = array.take(indices=range(length), axis=axis) 67 | 68 | if array.shape[axis] < length: 69 | pad_widths = [(0, 0)] * array.ndim 70 | pad_widths[axis] = (0, length - array.shape[axis]) 71 | array = np.pad(array, pad_widths) 72 | 73 | return array 74 | 75 | 76 | @lru_cache(maxsize=None) 77 | def mel_filters(device, n_mels: int = N_MELS) -> torch.Tensor: 78 | """ 79 | load the mel filterbank matrix for projecting STFT into a Mel spectrogram. 
80 | Allows decoupling librosa dependency; saved using: 81 | 82 | np.savez_compressed( 83 | "mel_filters.npz", 84 | mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80), 85 | ) 86 | """ 87 | assert n_mels == 80, f"Unsupported n_mels: {n_mels}" 88 | with np.load(os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")) as f: 89 | return torch.from_numpy(f[f"mel_{n_mels}"]).to(device) 90 | 91 | 92 | def log_mel_spectrogram(audio: Union[str, np.ndarray, torch.Tensor], n_mels: int = N_MELS): 93 | """ 94 | Compute the log-Mel spectrogram of 95 | 96 | Parameters 97 | ---------- 98 | audio: Union[str, np.ndarray, torch.Tensor], shape = (*) 99 | The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz 100 | 101 | n_mels: int 102 | The number of Mel-frequency filters, only 80 is supported 103 | 104 | Returns 105 | ------- 106 | torch.Tensor, shape = (80, n_frames) 107 | A Tensor that contains the Mel spectrogram 108 | """ 109 | if not torch.is_tensor(audio): 110 | if isinstance(audio, str): 111 | audio = load_audio(audio) 112 | audio = torch.from_numpy(audio) 113 | 114 | window = torch.hann_window(N_FFT).to(audio.device) 115 | stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True) 116 | magnitudes = stft[:, :-1].abs() ** 2 117 | 118 | filters = mel_filters(audio.device, n_mels) 119 | mel_spec = filters @ magnitudes 120 | 121 | log_spec = torch.clamp(mel_spec, min=1e-10).log10() 122 | log_spec = torch.maximum(log_spec, log_spec.max() - 8.0) 123 | log_spec = (log_spec + 4.0) / 4.0 124 | return log_spec 125 | -------------------------------------------------------------------------------- /whisper/decoding.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import Dict, List, Tuple, Iterable, Optional, Sequence, Union, TYPE_CHECKING 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | from torch import Tensor 8 | from torch.distributions import Categorical 9 | 10 | from .audio import CHUNK_LENGTH 11 | from .tokenizer import Tokenizer, get_tokenizer 12 | from .utils import compression_ratio 13 | 14 | if TYPE_CHECKING: 15 | from .model import Whisper 16 | 17 | 18 | @torch.no_grad() 19 | def detect_language(model: "Whisper", mel: Tensor, tokenizer: Tokenizer = None) -> Tuple[Tensor, List[dict]]: 20 | """ 21 | Detect the spoken language in the audio, and return them as list of strings, along with the ids 22 | of the most probable language tokens and the probability distribution over all language tokens. 23 | This is performed outside the main decode loop in order to not interfere with kv-caching. 24 | 25 | Returns 26 | ------- 27 | language_tokens : Tensor, shape = (n_audio,) 28 | ids of the most probable language tokens, which appears after the startoftranscript token. 29 | language_probs : List[Dict[str, float]], length = n_audio 30 | list of dictionaries containing the probability distribution over all languages. 
31 | """ 32 | if tokenizer is None: 33 | tokenizer = get_tokenizer(model.is_multilingual) 34 | if tokenizer.language is None or tokenizer.language_token not in tokenizer.sot_sequence: 35 | raise ValueError(f"This model doesn't have language tokens so it can't perform lang id") 36 | 37 | single = mel.ndim == 2 38 | if single: 39 | mel = mel.unsqueeze(0) 40 | 41 | # skip encoder forward pass if already-encoded audio features were given 42 | if mel.shape[-2:] != (model.dims.n_audio_ctx, model.dims.n_audio_state): 43 | mel = model.encoder(mel) 44 | 45 | # forward pass using a single token, startoftranscript 46 | n_audio = mel.shape[0] 47 | x = torch.tensor([[tokenizer.sot]] * n_audio).to(mel.device) # [n_audio, 1] 48 | logits = model.logits(x, mel)[:, 0] 49 | 50 | # collect detected languages; suppress all non-language tokens 51 | mask = torch.ones(logits.shape[-1], dtype=torch.bool) 52 | mask[list(tokenizer.all_language_tokens)] = False 53 | logits[:, mask] = -np.inf 54 | language_tokens = logits.argmax(dim=-1) 55 | language_token_probs = logits.softmax(dim=-1).cpu() 56 | language_probs = [ 57 | { 58 | c: language_token_probs[i, j].item() 59 | for j, c in zip(tokenizer.all_language_tokens, tokenizer.all_language_codes) 60 | } 61 | for i in range(n_audio) 62 | ] 63 | 64 | if single: 65 | language_tokens = language_tokens[0] 66 | language_probs = language_probs[0] 67 | 68 | return language_tokens, language_probs 69 | 70 | 71 | @dataclass(frozen=True) 72 | class DecodingOptions: 73 | task: str = "transcribe" # whether to perform X->X "transcribe" or X->English "translate" 74 | language: Optional[str] = None # language that the audio is in; uses detected language if None 75 | 76 | # sampling-related options 77 | temperature: float = 0.0 78 | sample_len: Optional[int] = None # maximum number of tokens to sample 79 | best_of: Optional[int] = None # number of independent samples to collect, when t > 0 80 | beam_size: Optional[int] = None # number of beams in beam search, when t == 0 81 | patience: Optional[float] = None # patience in beam search (https://arxiv.org/abs/2204.05424) 82 | 83 | # options for ranking generations (either beams or best-of-N samples) 84 | length_penalty: Optional[float] = None # "alpha" in Google NMT, None defaults to length norm 85 | 86 | # prompt, prefix, and token suppression 87 | prompt: Optional[Union[str, List[int]]] = None # text or tokens for the previous context 88 | prefix: Optional[Union[str, List[int]]] = None # text or tokens to prefix the current context 89 | suppress_blank: bool = True # this will suppress blank outputs 90 | 91 | # list of tokens ids (or comma-separated token ids) to suppress 92 | # "-1" will suppress a set of symbols as defined in `tokenizer.non_speech_tokens()` 93 | suppress_tokens: Optional[Union[str, Iterable[int]]] = "-1" 94 | 95 | # timestamp sampling options 96 | without_timestamps: bool = False # use <|notimestamps|> to sample text tokens only 97 | max_initial_timestamp: Optional[float] = 1.0 # the initial timestamp cannot be later than this 98 | 99 | # implementation details 100 | fp16: bool = True # use fp16 for most of the calculation 101 | 102 | 103 | @dataclass(frozen=True) 104 | class DecodingResult: 105 | audio_features: Tensor 106 | language: str 107 | language_probs: Optional[Dict[str, float]] = None 108 | tokens: List[int] = field(default_factory=list) 109 | text: str = "" 110 | avg_logprob: float = np.nan 111 | no_speech_prob: float = np.nan 112 | temperature: float = np.nan 113 | compression_ratio: float = np.nan 114 | 115 
| 116 | class Inference: 117 | def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor: 118 | """Perform a forward pass on the decoder and return per-token logits""" 119 | raise NotImplementedError 120 | 121 | def rearrange_kv_cache(self, source_indices) -> None: 122 | """Update the key-value cache according to the updated beams""" 123 | raise NotImplementedError 124 | 125 | def cleanup_caching(self) -> None: 126 | """Clean up any resources or hooks after decoding is finished""" 127 | pass 128 | 129 | 130 | class PyTorchInference(Inference): 131 | def __init__(self, model: "Whisper", initial_token_length: int): 132 | self.model: "Whisper" = model 133 | self.initial_token_length = initial_token_length 134 | self.kv_cache = {} 135 | 136 | def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor: 137 | if tokens.shape[-1] > self.initial_token_length: 138 | # only need to use the last token except in the first forward pass 139 | tokens = tokens[:, -1:] 140 | 141 | if len(self.kv_cache) == 0: 142 | dummy_cache = torch.zeros([ 143 | tokens.size(0), 144 | self.model.dims.n_text_layer, 145 | 0, 146 | self.model.dims.n_text_state 147 | ], dtype=audio_features.dtype, device=tokens.device) 148 | self.kv_cache['k_cache'] = dummy_cache 149 | self.kv_cache['v_cache'] = dummy_cache 150 | self.kv_cache['xa_k_cache'] = dummy_cache 151 | self.kv_cache['xa_v_cache'] = dummy_cache 152 | 153 | outputs, k_cache, v_cache, xa_k_cache, xa_v_cache = self.model.decoder( 154 | tokens, 155 | audio_features, 156 | self.kv_cache['k_cache'], 157 | self.kv_cache['v_cache'], 158 | self.kv_cache['xa_k_cache'], 159 | self.kv_cache['xa_v_cache'] 160 | ) 161 | self.kv_cache['k_cache'] = k_cache 162 | self.kv_cache['v_cache'] = v_cache 163 | self.kv_cache['xa_k_cache'] = xa_k_cache 164 | self.kv_cache['xa_v_cache'] = xa_v_cache 165 | 166 | return outputs 167 | 168 | def cleanup_caching(self): 169 | self.kv_cache = {} 170 | 171 | def rearrange_kv_cache(self, source_indices): 172 | for module, tensor in self.kv_cache.items(): 173 | # update the key/value cache to contain the selected sequences 174 | self.kv_cache[module] = tensor[source_indices].detach() 175 | 176 | 177 | class SequenceRanker: 178 | def rank(self, tokens: List[List[Tensor]], sum_logprobs: List[List[float]]) -> List[int]: 179 | """ 180 | Given a list of groups of samples and their cumulative log probabilities, 181 | return the indices of the samples in each group to select as the final result 182 | """ 183 | raise NotImplementedError 184 | 185 | 186 | class MaximumLikelihoodRanker(SequenceRanker): 187 | """ 188 | Select the sample with the highest log probabilities, penalized using either 189 | a simple length normalization or Google NMT paper's length penalty 190 | """ 191 | 192 | def __init__(self, length_penalty: Optional[float]): 193 | self.length_penalty = length_penalty 194 | 195 | def rank(self, tokens: List[List[Tensor]], sum_logprobs: List[List[float]]): 196 | def scores(logprobs, lengths): 197 | result = [] 198 | for logprob, length in zip(logprobs, lengths): 199 | if self.length_penalty is None: 200 | penalty = length 201 | else: 202 | # from the Google NMT paper 203 | penalty = ((5 + length) / 6) ** self.length_penalty 204 | result.append(logprob / penalty) 205 | return result 206 | 207 | # get the sequence with the highest score 208 | lengths = [[len(t) for t in s] for s in tokens] 209 | return [np.argmax(scores(p, l)) for p, l in zip(sum_logprobs, lengths)] 210 | 211 | 212 | class TokenDecoder: 213 | def reset(self): 214 | 
"""Initialize any stateful variables for decoding a new sequence""" 215 | 216 | def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]: 217 | """Specify how to select the next token, based on the current trace and logits 218 | 219 | Parameters 220 | ---------- 221 | tokens : Tensor, shape = (n_batch, current_sequence_length) 222 | all tokens in the context so far, including the prefix and sot_sequence tokens 223 | 224 | logits : Tensor, shape = (n_batch, vocab_size) 225 | per-token logits of the probability distribution at the current step 226 | 227 | sum_logprobs : Tensor, shape = (n_batch) 228 | cumulative log probabilities for each sequence 229 | 230 | Returns 231 | ------- 232 | tokens : Tensor, shape = (n_batch, current_sequence_length + 1) 233 | the tokens, appended with the selected next token 234 | 235 | completed : bool 236 | True if all sequences has reached the end of text 237 | 238 | """ 239 | raise NotImplementedError 240 | 241 | def finalize( 242 | self, tokens: Tensor, sum_logprobs: Tensor 243 | ) -> Tuple[Sequence[Sequence[Tensor]], List[List[float]]]: 244 | """Finalize search and return the final candidate sequences 245 | 246 | Parameters 247 | ---------- 248 | tokens : Tensor, shape = (n_audio, n_group, current_sequence_length) 249 | all tokens in the context so far, including the prefix and sot_sequence 250 | 251 | sum_logprobs : Tensor, shape = (n_audio, n_group) 252 | cumulative log probabilities for each sequence 253 | 254 | Returns 255 | ------- 256 | tokens : Sequence[Sequence[Tensor]], length = n_audio 257 | sequence of Tensors containing candidate token sequences, for each audio input 258 | 259 | sum_logprobs : List[List[float]], length = n_audio 260 | sequence of cumulative log probabilities corresponding to the above 261 | 262 | """ 263 | raise NotImplementedError 264 | 265 | 266 | class GreedyDecoder(TokenDecoder): 267 | def __init__(self, temperature: float, eot: int): 268 | self.temperature = temperature 269 | self.eot = eot 270 | 271 | def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]: 272 | temperature = self.temperature 273 | if temperature == 0: 274 | next_tokens = logits.argmax(dim=-1) 275 | else: 276 | next_tokens = Categorical(logits=logits / temperature).sample() 277 | 278 | logprobs = F.log_softmax(logits.float(), dim=-1) 279 | current_logprobs = logprobs[torch.arange(logprobs.shape[0]), next_tokens] 280 | sum_logprobs += current_logprobs * (tokens[:, -1] != self.eot) 281 | 282 | next_tokens[tokens[:, -1] == self.eot] = self.eot 283 | tokens = torch.cat([tokens, next_tokens[:, None]], dim=-1) 284 | 285 | completed = (tokens[:, -1] == self.eot).all() 286 | return tokens, completed 287 | 288 | def finalize(self, tokens: Tensor, sum_logprobs: Tensor): 289 | # make sure each sequence has at least one EOT token at the end 290 | tokens = F.pad(tokens, (0, 1), value=self.eot) 291 | return tokens, sum_logprobs.tolist() 292 | 293 | 294 | class BeamSearchDecoder(TokenDecoder): 295 | def __init__(self, beam_size: int, eot: int, inference: Inference, patience: Optional[float] = None): 296 | self.beam_size = beam_size 297 | self.eot = eot 298 | self.inference = inference 299 | self.patience = patience or 1.0 300 | self.max_candidates: int = round(beam_size * self.patience) 301 | self.finished_sequences = None 302 | 303 | assert self.max_candidates > 0, f"Invalid beam size ({beam_size}) or patience ({patience})" 304 | 305 | def reset(self): 306 | self.finished_sequences = None 307 | 
308 | def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]: 309 | if tokens.shape[0] % self.beam_size != 0: 310 | raise ValueError(f"{tokens.shape}[0] % {self.beam_size} != 0") 311 | 312 | n_audio = tokens.shape[0] // self.beam_size 313 | if self.finished_sequences is None: # for the first update 314 | self.finished_sequences = [{} for _ in range(n_audio)] 315 | 316 | logprobs = F.log_softmax(logits.float(), dim=-1) 317 | next_tokens, source_indices, finished_sequences = [], [], [] 318 | for i in range(n_audio): 319 | scores, sources, finished = {}, {}, {} 320 | 321 | # STEP 1: calculate the cumulative log probabilities for possible candidates 322 | for j in range(self.beam_size): 323 | idx = i * self.beam_size + j 324 | prefix = tokens[idx].tolist() 325 | for logprob, token in zip(*logprobs[idx].topk(self.beam_size + 1)): 326 | new_logprob = (sum_logprobs[idx] + logprob).item() 327 | sequence = tuple(prefix + [token.item()]) 328 | scores[sequence] = new_logprob 329 | sources[sequence] = idx 330 | 331 | # STEP 2: rank the candidates and keep the top beam_size sequences for each audio 332 | saved = 0 333 | for sequence in sorted(scores, key=scores.get, reverse=True): 334 | if sequence[-1] == self.eot: 335 | finished[sequence] = scores[sequence] 336 | else: 337 | sum_logprobs[len(next_tokens)] = scores[sequence] 338 | next_tokens.append(sequence) 339 | source_indices.append(sources[sequence]) 340 | 341 | saved += 1 342 | if saved == self.beam_size: 343 | break 344 | 345 | finished_sequences.append(finished) 346 | 347 | tokens = torch.tensor(next_tokens, device=tokens.device) 348 | self.inference.rearrange_kv_cache(source_indices) 349 | 350 | # add newly finished sequences to self.finished_sequences 351 | assert len(self.finished_sequences) == len(finished_sequences) 352 | for previously_finished, newly_finished in zip(self.finished_sequences, finished_sequences): 353 | for seq in sorted(newly_finished, key=newly_finished.get, reverse=True): 354 | if len(previously_finished) >= self.max_candidates: 355 | break # the candidate list is full 356 | previously_finished[seq] = newly_finished[seq] 357 | 358 | # mark as completed if all audio has enough number of samples 359 | completed = all( 360 | len(sequences) >= self.max_candidates for sequences in self.finished_sequences 361 | ) 362 | return tokens, completed 363 | 364 | def finalize(self, preceding_tokens: Tensor, sum_logprobs: Tensor): 365 | # collect all finished sequences, including patience, and add unfinished ones if not enough 366 | sum_logprobs = sum_logprobs.cpu() 367 | for i, sequences in enumerate(self.finished_sequences): 368 | if len(sequences) < self.beam_size: # when not enough sequences are finished 369 | for j in list(np.argsort(sum_logprobs[i]))[::-1]: 370 | sequence = preceding_tokens[i, j].tolist() + [self.eot] 371 | sequences[tuple(sequence)] = sum_logprobs[i][j].item() 372 | if len(sequences) >= self.beam_size: 373 | break 374 | 375 | tokens: List[List[Tensor]] = [ 376 | [torch.tensor(seq) for seq in sequences.keys()] for sequences in self.finished_sequences 377 | ] 378 | sum_logprobs: List[List[float]] = [ 379 | list(sequences.values()) for sequences in self.finished_sequences 380 | ] 381 | return tokens, sum_logprobs 382 | 383 | 384 | class LogitFilter: 385 | def apply(self, logits: Tensor, tokens: Tensor) -> None: 386 | """Apply any filtering or masking to logits in-place 387 | 388 | Parameters 389 | ---------- 390 | logits : Tensor, shape = (n_batch, vocab_size) 391 | per-token 
logits of the probability distribution at the current step 392 | 393 | tokens : Tensor, shape = (n_batch, current_sequence_length) 394 | all tokens in the context so far, including the prefix and sot_sequence tokens 395 | 396 | """ 397 | raise NotImplementedError 398 | 399 | 400 | class SuppressBlank(LogitFilter): 401 | def __init__(self, tokenizer: Tokenizer, sample_begin: int): 402 | self.tokenizer = tokenizer 403 | self.sample_begin = sample_begin 404 | 405 | def apply(self, logits: Tensor, tokens: Tensor): 406 | if tokens.shape[1] == self.sample_begin: 407 | logits[:, self.tokenizer.encode(" ") + [self.tokenizer.eot]] = -np.inf 408 | 409 | 410 | class SuppressTokens(LogitFilter): 411 | def __init__(self, suppress_tokens: Sequence[int]): 412 | self.suppress_tokens = list(suppress_tokens) 413 | 414 | def apply(self, logits: Tensor, tokens: Tensor): 415 | logits[:, self.suppress_tokens] = -np.inf 416 | 417 | 418 | class ApplyTimestampRules(LogitFilter): 419 | def __init__( 420 | self, tokenizer: Tokenizer, sample_begin: int, max_initial_timestamp_index: Optional[int] 421 | ): 422 | self.tokenizer = tokenizer 423 | self.sample_begin = sample_begin 424 | self.max_initial_timestamp_index = max_initial_timestamp_index 425 | 426 | def apply(self, logits: Tensor, tokens: Tensor): 427 | # suppress <|notimestamps|> which is handled by without_timestamps 428 | if self.tokenizer.no_timestamps is not None: 429 | logits[:, self.tokenizer.no_timestamps] = -np.inf 430 | 431 | # timestamps have to appear in pairs, except directly before EOT; mask logits accordingly 432 | for k in range(tokens.shape[0]): 433 | seq = [t for t in tokens[k, self.sample_begin :].tolist()] 434 | last_was_timestamp = len(seq) >= 1 and seq[-1] >= self.tokenizer.timestamp_begin 435 | penultimate_was_timestamp = len(seq) < 2 or seq[-2] >= self.tokenizer.timestamp_begin 436 | 437 | if last_was_timestamp: 438 | if penultimate_was_timestamp: # has to be non-timestamp 439 | logits[k, self.tokenizer.timestamp_begin :] = -np.inf 440 | else: # cannot be normal text tokens 441 | logits[k, : self.tokenizer.eot] = -np.inf 442 | 443 | if tokens.shape[1] == self.sample_begin: 444 | # suppress generating non-timestamp tokens at the beginning 445 | logits[:, : self.tokenizer.timestamp_begin] = -np.inf 446 | 447 | # apply the `max_initial_timestamp` option 448 | if self.max_initial_timestamp_index is not None: 449 | last_allowed = self.tokenizer.timestamp_begin + self.max_initial_timestamp_index 450 | logits[:, last_allowed + 1 :] = -np.inf 451 | 452 | # if sum of probability over timestamps is above any other token, sample timestamp 453 | logprobs = F.log_softmax(logits.float(), dim=-1) 454 | for k in range(tokens.shape[0]): 455 | timestamp_logprob = logprobs[k, self.tokenizer.timestamp_begin :].logsumexp(dim=-1) 456 | max_text_token_logprob = logprobs[k, : self.tokenizer.timestamp_begin].max() 457 | if timestamp_logprob > max_text_token_logprob: 458 | logits[k, : self.tokenizer.timestamp_begin] = -np.inf 459 | 460 | 461 | class DecodingTask: 462 | inference: Inference 463 | sequence_ranker: SequenceRanker 464 | decoder: TokenDecoder 465 | logit_filters: List[LogitFilter] 466 | 467 | def __init__(self, model: "Whisper", options: DecodingOptions): 468 | self.model = model 469 | 470 | language = options.language or "en" 471 | tokenizer = get_tokenizer(model.is_multilingual, language=language, task=options.task) 472 | self.tokenizer: Tokenizer = tokenizer 473 | self.options: DecodingOptions = self._verify_options(options) 474 | 475 | 
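        # n_group: number of candidate sequences tracked per audio input
        # (the beam size under beam search, or best_of under sampling)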
self.n_group: int = options.beam_size or options.best_of or 1 476 | self.n_ctx: int = model.dims.n_text_ctx 477 | self.sample_len: int = options.sample_len or model.dims.n_text_ctx // 2 478 | 479 | self.sot_sequence: Tuple[int] = tokenizer.sot_sequence 480 | if self.options.without_timestamps: 481 | self.sot_sequence = tokenizer.sot_sequence_including_notimestamps 482 | 483 | self.initial_tokens: Tuple[int] = self._get_initial_tokens() 484 | self.sample_begin: int = len(self.initial_tokens) 485 | self.sot_index: int = self.initial_tokens.index(tokenizer.sot) 486 | 487 | # inference: implements the forward pass through the decoder, including kv caching 488 | self.inference = PyTorchInference(model, len(self.initial_tokens)) 489 | 490 | # sequence ranker: implements how to rank a group of sampled sequences 491 | self.sequence_ranker = MaximumLikelihoodRanker(options.length_penalty) 492 | 493 | # decoder: implements how to select the next tokens, given the autoregressive distribution 494 | if options.beam_size is not None: 495 | self.decoder = BeamSearchDecoder( 496 | options.beam_size, tokenizer.eot, self.inference, options.patience 497 | ) 498 | else: 499 | self.decoder = GreedyDecoder(options.temperature, tokenizer.eot) 500 | 501 | # logit filters: applies various rules to suppress or penalize certain tokens 502 | self.logit_filters = [] 503 | if self.options.suppress_blank: 504 | self.logit_filters.append(SuppressBlank(self.tokenizer, self.sample_begin)) 505 | if self.options.suppress_tokens: 506 | self.logit_filters.append(SuppressTokens(self._get_suppress_tokens())) 507 | if not options.without_timestamps: 508 | precision = CHUNK_LENGTH / model.dims.n_audio_ctx # usually 0.02 seconds 509 | max_initial_timestamp_index = None 510 | if options.max_initial_timestamp: 511 | max_initial_timestamp_index = round(self.options.max_initial_timestamp / precision) 512 | self.logit_filters.append( 513 | ApplyTimestampRules(tokenizer, self.sample_begin, max_initial_timestamp_index) 514 | ) 515 | 516 | def _verify_options(self, options: DecodingOptions) -> DecodingOptions: 517 | if options.beam_size is not None and options.best_of is not None: 518 | raise ValueError("beam_size and best_of can't be given together") 519 | if options.temperature == 0: 520 | if options.best_of is not None: 521 | raise ValueError("best_of with greedy sampling (T=0) is not compatible") 522 | if options.patience is not None and options.beam_size is None: 523 | raise ValueError("patience requires beam_size to be given") 524 | if options.length_penalty is not None and not (0 <= options.length_penalty <= 1): 525 | raise ValueError("length_penalty (alpha) should be a value between 0 and 1") 526 | 527 | return options 528 | 529 | def _get_initial_tokens(self) -> Tuple[int]: 530 | tokens = list(self.sot_sequence) 531 | prefix = self.options.prefix 532 | prompt = self.options.prompt 533 | 534 | if prefix: 535 | prefix_tokens = ( 536 | self.tokenizer.encode(" " + prefix.strip()) if isinstance(prefix, str) else prefix 537 | ) 538 | if self.sample_len is not None: 539 | max_prefix_len = self.n_ctx // 2 - self.sample_len 540 | prefix_tokens = prefix_tokens[-max_prefix_len:] 541 | tokens = tokens + prefix_tokens 542 | 543 | if prompt: 544 | prompt_tokens = ( 545 | self.tokenizer.encode(" " + prompt.strip()) if isinstance(prompt, str) else prompt 546 | ) 547 | tokens = [self.tokenizer.sot_prev] + prompt_tokens[-(self.n_ctx // 2 - 1) :] + tokens 548 | 549 | return tuple(tokens) 550 | 551 | def _get_suppress_tokens(self) -> Tuple[int]: 552 
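        # resolve the `suppress_tokens` option into a sorted tuple of token ids;
        # "-1" expands to the tokenizer's non-speech symbols, and the special
        # sot / sot_prev / sot_lm (and no_speech) tokens are always added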
| suppress_tokens = self.options.suppress_tokens 553 | 554 | if isinstance(suppress_tokens, str): 555 | suppress_tokens = [int(t) for t in suppress_tokens.split(",")] 556 | 557 | if -1 in suppress_tokens: 558 | suppress_tokens = [t for t in suppress_tokens if t >= 0] 559 | suppress_tokens.extend(self.tokenizer.non_speech_tokens) 560 | elif suppress_tokens is None or len(suppress_tokens) == 0: 561 | suppress_tokens = [] # interpret empty string as an empty list 562 | else: 563 | assert isinstance(suppress_tokens, list), "suppress_tokens must be a list" 564 | 565 | suppress_tokens.extend( 566 | [self.tokenizer.sot, self.tokenizer.sot_prev, self.tokenizer.sot_lm] 567 | ) 568 | if self.tokenizer.no_speech is not None: 569 | # no-speech probability is collected separately 570 | suppress_tokens.append(self.tokenizer.no_speech) 571 | 572 | return tuple(sorted(set(suppress_tokens))) 573 | 574 | def _get_audio_features(self, mel: Tensor): 575 | if self.options.fp16: 576 | mel = mel.half() 577 | 578 | if mel.shape[-2:] == (self.model.dims.n_audio_ctx, self.model.dims.n_audio_state): 579 | # encoded audio features are given; skip audio encoding 580 | audio_features = mel 581 | else: 582 | audio_features = self.model.encoder(mel) 583 | 584 | if audio_features.dtype != (torch.float16 if self.options.fp16 else torch.float32): 585 | return TypeError(f"audio_features has an incorrect dtype: {audio_features.dtype}") 586 | 587 | return audio_features 588 | 589 | def _detect_language(self, audio_features: Tensor, tokens: Tensor): 590 | languages = [self.options.language] * audio_features.shape[0] 591 | lang_probs = None 592 | 593 | if self.options.language is None or self.options.task == "lang_id": 594 | lang_tokens, lang_probs = self.model.detect_language(audio_features, self.tokenizer) 595 | languages = [max(probs, key=probs.get) for probs in lang_probs] 596 | if self.options.language is None: 597 | tokens[:, self.sot_index + 1] = lang_tokens # write language tokens 598 | 599 | return languages, lang_probs 600 | 601 | def _main_loop(self, audio_features: Tensor, tokens: Tensor): 602 | assert audio_features.shape[0] == tokens.shape[0] 603 | n_batch = tokens.shape[0] 604 | sum_logprobs: Tensor = torch.zeros(n_batch, device=audio_features.device) 605 | no_speech_probs = [np.nan] * n_batch 606 | 607 | try: 608 | for i in range(self.sample_len): 609 | logits = self.inference.logits(tokens, audio_features) 610 | 611 | if i == 0 and self.tokenizer.no_speech is not None: # save no_speech_probs 612 | probs_at_sot = logits[:, self.sot_index].float().softmax(dim=-1) 613 | no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist() 614 | 615 | # now we need to consider the logits at the last token only 616 | logits = logits[:, -1] 617 | 618 | # apply the logit filters, e.g. 
for suppressing or applying penalty to 619 | for logit_filter in self.logit_filters: 620 | logit_filter.apply(logits, tokens) 621 | 622 | # expand the tokens tensor with the selected next tokens 623 | tokens, completed = self.decoder.update(tokens, logits, sum_logprobs) 624 | 625 | if completed or tokens.shape[-1] > self.n_ctx: 626 | break 627 | finally: 628 | self.inference.cleanup_caching() 629 | 630 | return tokens, sum_logprobs, no_speech_probs 631 | 632 | @torch.no_grad() 633 | def run(self, mel: Tensor) -> List[DecodingResult]: 634 | self.decoder.reset() 635 | tokenizer: Tokenizer = self.tokenizer 636 | n_audio: int = mel.shape[0] 637 | 638 | audio_features: Tensor = self._get_audio_features(mel) # encoder forward pass 639 | tokens: Tensor = torch.tensor([self.initial_tokens]).repeat(n_audio, 1) 640 | 641 | # detect language if requested, overwriting the language token 642 | languages, language_probs = self._detect_language(audio_features, tokens) 643 | if self.options.task == "lang_id": 644 | return [ 645 | DecodingResult(audio_features=features, language=language, language_probs=probs) 646 | for features, language, probs in zip(audio_features, languages, language_probs) 647 | ] 648 | 649 | # repeat the audio & text tensors by the group size, for beam search or best-of-n sampling 650 | audio_features = audio_features.repeat_interleave(self.n_group, dim=0) 651 | tokens = tokens.repeat_interleave(self.n_group, dim=0).to(audio_features.device) 652 | 653 | # call the main sampling loop 654 | tokens, sum_logprobs, no_speech_probs = self._main_loop(audio_features, tokens) 655 | 656 | # reshape the tensors to have (n_audio, n_group) as the first two dimensions 657 | audio_features = audio_features[:: self.n_group] 658 | no_speech_probs = no_speech_probs[:: self.n_group] 659 | assert audio_features.shape[0] == len(no_speech_probs) == n_audio 660 | 661 | tokens = tokens.reshape(n_audio, self.n_group, -1) 662 | sum_logprobs = sum_logprobs.reshape(n_audio, self.n_group) 663 | 664 | # get the final candidates for each group, and slice between the first sampled token and EOT 665 | tokens, sum_logprobs = self.decoder.finalize(tokens, sum_logprobs) 666 | tokens: List[List[Tensor]] = [ 667 | [t[self.sample_begin : (t == tokenizer.eot).nonzero()[0, 0]] for t in s] for s in tokens 668 | ] 669 | 670 | # select the top-ranked sample in each group 671 | selected = self.sequence_ranker.rank(tokens, sum_logprobs) 672 | tokens: List[List[int]] = [t[i].tolist() for i, t in zip(selected, tokens)] 673 | texts: List[str] = [tokenizer.decode(t).strip() for t in tokens] 674 | 675 | sum_logprobs: List[float] = [lp[i] for i, lp in zip(selected, sum_logprobs)] 676 | avg_logprobs: List[float] = [lp / (len(t) + 1) for t, lp in zip(tokens, sum_logprobs)] 677 | 678 | fields = (texts, languages, tokens, audio_features, avg_logprobs, no_speech_probs) 679 | if len(set(map(len, fields))) != 1: 680 | raise RuntimeError(f"inconsistent result lengths: {list(map(len, fields))}") 681 | 682 | return [ 683 | DecodingResult( 684 | audio_features=features, 685 | language=language, 686 | tokens=tokens, 687 | text=text, 688 | avg_logprob=avg_logprob, 689 | no_speech_prob=no_speech_prob, 690 | temperature=self.options.temperature, 691 | compression_ratio=compression_ratio(text), 692 | ) 693 | for text, language, tokens, features, avg_logprob, no_speech_prob in zip(*fields) 694 | ] 695 | 696 | 697 | @torch.no_grad() 698 | def decode(model: "Whisper", mel: Tensor, options: DecodingOptions = DecodingOptions()) -> Union[DecodingResult, 
List[DecodingResult]]: 699 | """ 700 | Performs decoding of 30-second audio segment(s), provided as Mel spectrogram(s). 701 | 702 | Parameters 703 | ---------- 704 | model: Whisper 705 | the Whisper model instance 706 | 707 | mel: torch.Tensor, shape = (80, 3000) or (*, 80, 3000) 708 | A tensor containing the Mel spectrogram(s) 709 | 710 | options: DecodingOptions 711 | A dataclass that contains all necessary options for decoding 30-second segments 712 | 713 | Returns 714 | ------- 715 | result: Union[DecodingResult, List[DecodingResult]] 716 | The result(s) of decoding contained in `DecodingResult` dataclass instance(s) 717 | """ 718 | single = mel.ndim == 2 719 | if single: 720 | mel = mel.unsqueeze(0) 721 | 722 | result = DecodingTask(model, options).run(mel) 723 | 724 | if single: 725 | result = result[0] 726 | 727 | return result 728 | -------------------------------------------------------------------------------- /whisper/model.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Dict 3 | from typing import Iterable, Optional 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn.functional as F 8 | from torch import Tensor 9 | from torch import nn 10 | 11 | from .transcribe import transcribe as transcribe_function 12 | from .decoding import detect_language as detect_language_function, decode as decode_function 13 | 14 | 15 | @dataclass 16 | class ModelDimensions: 17 | n_mels: int 18 | n_audio_ctx: int 19 | n_audio_state: int 20 | n_audio_head: int 21 | n_audio_layer: int 22 | n_vocab: int 23 | n_text_ctx: int 24 | n_text_state: int 25 | n_text_head: int 26 | n_text_layer: int 27 | 28 | 29 | def sinusoids(length, channels, max_timescale=10000): 30 | """Returns sinusoids for positional embedding""" 31 | assert channels % 2 == 0 32 | log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1) 33 | inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2)) 34 | scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :] 35 | return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1) 36 | 37 | 38 | class MultiHeadAttention(nn.Module): 39 | def __init__(self, n_state: int, n_head: int): 40 | super().__init__() 41 | self.n_head = n_head 42 | self.query = nn.Linear(n_state, n_state) 43 | self.key = nn.Linear(n_state, n_state, bias=False) 44 | self.value = nn.Linear(n_state, n_state) 45 | self.out = nn.Linear(n_state, n_state) 46 | 47 | def forward( 48 | self, 49 | x: Tensor, 50 | mask: Tensor, 51 | k_cache: Tensor, 52 | v_cache: Tensor 53 | ): 54 | q = self.query(x) 55 | k = self.key(x) 56 | v = self.value(x) 57 | 58 | k = torch.cat([k_cache, k], dim=1) 59 | v = torch.cat([v_cache, v], dim=1) 60 | 61 | wv = self.qkv_attention(q, k, v, mask) 62 | return self.out(wv), k, v 63 | 64 | def qkv_attention(self, q: Tensor, k: Tensor, v: Tensor, mask: Tensor = None): 65 | n_batch, n_ctx, n_state = q.shape 66 | scale = (n_state // self.n_head) ** -0.25 67 | q = q.view(q.shape[0], q.shape[1], self.n_head, -1).permute(0, 2, 1, 3) * scale 68 | k = k.view(k.shape[0], k.shape[1], self.n_head, -1).permute(0, 2, 3, 1) * scale 69 | v = v.view(v.shape[0], v.shape[1], self.n_head, -1).permute(0, 2, 1, 3) 70 | 71 | qk = q @ k 72 | qk = qk + mask[:n_ctx, :n_ctx] 73 | 74 | w = F.softmax(qk.float(), dim=-1).to(q.dtype) 75 | return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2) 76 | 77 | 78 | class MultiHeadCrossAttention(nn.Module): 79 | def 
__init__(self, n_state: int, n_head: int): 80 | super().__init__() 81 | self.n_head = n_head 82 | self.query = nn.Linear(n_state, n_state) 83 | self.key = nn.Linear(n_state, n_state, bias=False) 84 | self.value = nn.Linear(n_state, n_state) 85 | self.out = nn.Linear(n_state, n_state) 86 | 87 | def forward( 88 | self, 89 | x: Tensor, 90 | xa: Tensor, 91 | k_cache: Tensor, 92 | v_cache: Tensor 93 | ): 94 | q = self.query(x) 95 | 96 | if k_cache.size(1) == 0: 97 | k = self.key(xa) 98 | v = self.value(xa) 99 | else: 100 | k = k_cache 101 | v = v_cache 102 | 103 | wv = self.qkv_attention(q, k, v) 104 | return self.out(wv), k, v 105 | 106 | def qkv_attention(self, q: Tensor, k: Tensor, v: Tensor): 107 | n_batch, n_ctx, n_state = q.shape 108 | scale = (n_state // self.n_head) ** -0.25 109 | q = q.view(q.shape[0], q.shape[1], self.n_head, -1).permute(0, 2, 1, 3) * scale 110 | k = k.view(k.shape[0], k.shape[1], self.n_head, -1).permute(0, 2, 3, 1) * scale 111 | v = v.view(v.shape[0], v.shape[1], self.n_head, -1).permute(0, 2, 1, 3) 112 | 113 | qk = q @ k 114 | w = F.softmax(qk.float(), dim=-1).to(q.dtype) 115 | return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2) 116 | 117 | 118 | class ResidualAttentionBlock(nn.Module): 119 | def __init__(self, n_state: int, n_head: int): 120 | super().__init__() 121 | 122 | self.attn = MultiHeadAttention(n_state, n_head) 123 | self.attn_ln = nn.LayerNorm(n_state) 124 | 125 | n_mlp = n_state * 4 126 | self.mlp = nn.Sequential( 127 | nn.Linear(n_state, n_mlp), 128 | nn.GELU(), 129 | nn.Linear(n_mlp, n_state) 130 | ) 131 | self.mlp_ln = nn.LayerNorm(n_state) 132 | 133 | def forward( 134 | self, 135 | x: Tensor, 136 | mask: Tensor, 137 | ): 138 | dummy_cache = torch.zeros( 139 | [x.size(0), 0, x.size(-1)], dtype=x.dtype, device=x.device) 140 | y, _, _ = self.attn( 141 | self.attn_ln(x), 142 | mask, 143 | dummy_cache, 144 | dummy_cache 145 | ) 146 | x = x + y 147 | x = x + self.mlp(self.mlp_ln(x)) 148 | return x 149 | 150 | 151 | class ResidualCrossAttentionBlock(nn.Module): 152 | def __init__(self, n_state: int, n_head: int): 153 | super().__init__() 154 | 155 | self.attn = MultiHeadAttention(n_state, n_head) 156 | self.attn_ln = nn.LayerNorm(n_state) 157 | 158 | self.cross_attn = MultiHeadCrossAttention(n_state, n_head) 159 | self.cross_attn_ln = nn.LayerNorm(n_state) 160 | 161 | n_mlp = n_state * 4 162 | self.mlp = nn.Sequential( 163 | nn.Linear(n_state, n_mlp), 164 | nn.GELU(), 165 | nn.Linear(n_mlp, n_state) 166 | ) 167 | self.mlp_ln = nn.LayerNorm(n_state) 168 | 169 | def forward( 170 | self, 171 | x: Tensor, 172 | xa: Tensor, 173 | mask: Tensor, 174 | k_cache: Tensor, 175 | v_cache: Tensor, 176 | xa_k_cache: Tensor, 177 | xa_v_cache: Tensor 178 | ): 179 | y, k_cache, v_cache = self.attn(self.attn_ln(x), mask, k_cache, v_cache) 180 | x = x + y 181 | y, xa_k_cache, xa_v_cache = self.cross_attn(self.cross_attn_ln(x), xa, xa_k_cache, xa_v_cache) 182 | x = x + y 183 | x = x + self.mlp(self.mlp_ln(x)) 184 | return x, k_cache, v_cache, xa_k_cache, xa_v_cache 185 | 186 | 187 | class AudioEncoder(nn.Module): 188 | def __init__(self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int): 189 | super().__init__() 190 | self.conv1 = nn.Conv1d(n_mels, n_state, kernel_size=3, padding=1) 191 | self.conv2 = nn.Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1) 192 | self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state)) 193 | 194 | self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList( 195 | [ResidualAttentionBlock(n_state, 
n_head) for _ in range(n_layer)] 196 | ) 197 | self.ln_post = nn.LayerNorm(n_state) 198 | 199 | mask = torch.zeros(n_ctx, n_ctx) 200 | self.register_buffer("mask", mask, persistent=False) 201 | 202 | def forward(self, x: Tensor): 203 | x = F.gelu(self.conv1(x)) 204 | x = F.gelu(self.conv2(x)) 205 | x = x.permute(0, 2, 1) 206 | 207 | assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape" 208 | x = (x + self.positional_embedding).to(x.dtype) 209 | 210 | for block in self.blocks: 211 | x = block(x, self.mask) 212 | 213 | x = self.ln_post(x) 214 | return x 215 | 216 | 217 | class TextDecoder(nn.Module): 218 | def __init__(self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int): 219 | super().__init__() 220 | 221 | self.token_embedding = nn.Embedding(n_vocab, n_state) 222 | self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state)) 223 | 224 | self.blocks: Iterable[ResidualCrossAttentionBlock] = nn.ModuleList( 225 | [ResidualCrossAttentionBlock(n_state, n_head) for _ in range(n_layer)] 226 | ) 227 | self.ln = nn.LayerNorm(n_state) 228 | 229 | mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1) 230 | self.register_buffer("mask", mask, persistent=False) 231 | 232 | def forward( 233 | self, 234 | x: Tensor, 235 | xa: Tensor, 236 | k_cache: Tensor, 237 | v_cache: Tensor, 238 | xa_k_cache: Tensor, 239 | xa_v_cache: Tensor 240 | ): 241 | offset = k_cache.shape[2] 242 | x = self.token_embedding(x) + self.positional_embedding[offset : offset + x.shape[-1]] 243 | x = x.to(xa.dtype) 244 | 245 | k_cache_list, v_cache_list = [], [] 246 | xa_k_cache_list, xa_v_cache_list = [], [] 247 | for i, block in enumerate(self.blocks): 248 | x, new_k_cache, new_v_cache, new_xa_k_cache, new_xa_v_cache = block( 249 | x, 250 | xa, 251 | self.mask, 252 | k_cache[:, i, :, :], 253 | v_cache[:, i, :, :], 254 | xa_k_cache[:, i, :, :], 255 | xa_v_cache[:, i, :, :] 256 | ) 257 | k_cache_list.append(new_k_cache) 258 | v_cache_list.append(new_v_cache) 259 | xa_k_cache_list.append(new_xa_k_cache) 260 | xa_v_cache_list.append(new_xa_v_cache) 261 | 262 | x = self.ln(x) 263 | logits = (x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)).float() 264 | 265 | return ( 266 | logits, 267 | torch.stack(k_cache_list, dim=1), 268 | torch.stack(v_cache_list, dim=1), 269 | torch.stack(xa_k_cache_list, dim=1), 270 | torch.stack(xa_v_cache_list, dim=1) 271 | ) 272 | 273 | 274 | class Whisper(nn.Module): 275 | def __init__(self, dims: ModelDimensions): 276 | super().__init__() 277 | self.dims = dims 278 | self.encoder = AudioEncoder( 279 | self.dims.n_mels, 280 | self.dims.n_audio_ctx, 281 | self.dims.n_audio_state, 282 | self.dims.n_audio_head, 283 | self.dims.n_audio_layer, 284 | ) 285 | self.decoder = TextDecoder( 286 | self.dims.n_vocab, 287 | self.dims.n_text_ctx, 288 | self.dims.n_text_state, 289 | self.dims.n_text_head, 290 | self.dims.n_text_layer, 291 | ) 292 | 293 | def embed_audio(self, mel: torch.Tensor): 294 | return self.encoder(mel) 295 | 296 | def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor): 297 | dummy_cache = torch.zeros([ 298 | tokens.size(0), self.dims.n_text_layer, 0, self.dims.n_text_state 299 | ], dtype=audio_features.dtype, device=tokens.device) 300 | outputs, _, _, _, _ = self.decoder( 301 | tokens, 302 | audio_features, 303 | dummy_cache, 304 | dummy_cache, 305 | dummy_cache, 306 | dummy_cache 307 | ) 308 | return outputs 309 | 310 | def forward(self, mel: torch.Tensor, tokens: torch.Tensor): 311 | dummy_cache = torch.zeros([ 
312 | tokens.size(0), self.dims.n_text_layer, 0, self.dims.n_text_state 313 | ], dtype=mel.dtype, device=tokens.device) 314 | outputs, _, _, _, _ = self.decoder( 315 | tokens, 316 | self.encoder(mel), 317 | dummy_cache, 318 | dummy_cache, 319 | dummy_cache, 320 | dummy_cache 321 | ) 322 | return outputs 323 | 324 | @property 325 | def device(self): 326 | return next(self.parameters()).device 327 | 328 | @property 329 | def is_multilingual(self): 330 | return self.dims.n_vocab == 51865 331 | 332 | detect_language = detect_language_function 333 | transcribe = transcribe_function 334 | decode = decode_function 335 | -------------------------------------------------------------------------------- /whisper/normalizers/__init__.py: -------------------------------------------------------------------------------- 1 | from .basic import BasicTextNormalizer 2 | from .english import EnglishTextNormalizer 3 | -------------------------------------------------------------------------------- /whisper/normalizers/basic.py: -------------------------------------------------------------------------------- 1 | import re 2 | import unicodedata 3 | 4 | import regex 5 | 6 | # non-ASCII letters that are not separated by "NFKD" normalization 7 | ADDITIONAL_DIACRITICS = { 8 | "œ": "oe", 9 | "Œ": "OE", 10 | "ø": "o", 11 | "Ø": "O", 12 | "æ": "ae", 13 | "Æ": "AE", 14 | "ß": "ss", 15 | "ẞ": "SS", 16 | "đ": "d", 17 | "Đ": "D", 18 | "ð": "d", 19 | "Ð": "D", 20 | "þ": "th", 21 | "Þ": "th", 22 | "ł": "l", 23 | "Ł": "L", 24 | } 25 | 26 | 27 | def remove_symbols_and_diacritics(s: str, keep=""): 28 | """ 29 | Replace any other markers, symbols, and punctuations with a space, 30 | and drop any diacritics (category 'Mn' and some manual mappings) 31 | """ 32 | return "".join( 33 | c 34 | if c in keep 35 | else ADDITIONAL_DIACRITICS[c] 36 | if c in ADDITIONAL_DIACRITICS 37 | else "" 38 | if unicodedata.category(c) == "Mn" 39 | else " " 40 | if unicodedata.category(c)[0] in "MSP" 41 | else c 42 | for c in unicodedata.normalize("NFKD", s) 43 | ) 44 | 45 | 46 | def remove_symbols(s: str): 47 | """ 48 | Replace any other markers, symbols, punctuations with a space, keeping diacritics 49 | """ 50 | return "".join( 51 | " " if unicodedata.category(c)[0] in "MSP" else c for c in unicodedata.normalize("NFKC", s) 52 | ) 53 | 54 | 55 | class BasicTextNormalizer: 56 | def __init__(self, remove_diacritics: bool = False, split_letters: bool = False): 57 | self.clean = remove_symbols_and_diacritics if remove_diacritics else remove_symbols 58 | self.split_letters = split_letters 59 | 60 | def __call__(self, s: str): 61 | s = s.lower() 62 | s = re.sub(r"[<\[][^>\]]*[>\]]", "", s) # remove words between brackets 63 | s = re.sub(r"\(([^)]+?)\)", "", s) # remove words between parenthesis 64 | s = self.clean(s).lower() 65 | 66 | if self.split_letters: 67 | s = " ".join(regex.findall(r"\X", s, regex.U)) 68 | 69 | s = re.sub(r"\s+", " ", s) # replace any successive whitespace characters with a space 70 | 71 | return s 72 | -------------------------------------------------------------------------------- /whisper/normalizers/english.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import re 4 | from fractions import Fraction 5 | from typing import Iterator, List, Match, Optional, Union 6 | 7 | from more_itertools import windowed 8 | 9 | from .basic import remove_symbols_and_diacritics 10 | 11 | 12 | class EnglishNumberNormalizer: 13 | """ 14 | Convert any spelled-out numbers into 
arabic numbers, while handling: 15 | 16 | - remove any commas 17 | - keep the suffixes such as: `1960s`, `274th`, `32nd`, etc. 18 | - spell out currency symbols after the number. e.g. `$20 million` -> `20000000 dollars` 19 | - spell out `one` and `ones` 20 | - interpret successive single-digit numbers as nominal: `one oh one` -> `101` 21 | """ 22 | 23 | def __init__(self): 24 | super().__init__() 25 | 26 | self.zeros = {"o", "oh", "zero"} 27 | self.ones = { 28 | name: i 29 | for i, name in enumerate( 30 | [ 31 | "one", 32 | "two", 33 | "three", 34 | "four", 35 | "five", 36 | "six", 37 | "seven", 38 | "eight", 39 | "nine", 40 | "ten", 41 | "eleven", 42 | "twelve", 43 | "thirteen", 44 | "fourteen", 45 | "fifteen", 46 | "sixteen", 47 | "seventeen", 48 | "eighteen", 49 | "nineteen", 50 | ], 51 | start=1, 52 | ) 53 | } 54 | self.ones_plural = { 55 | "sixes" if name == "six" else name + "s": (value, "s") 56 | for name, value in self.ones.items() 57 | } 58 | self.ones_ordinal = { 59 | "zeroth": (0, "th"), 60 | "first": (1, "st"), 61 | "second": (2, "nd"), 62 | "third": (3, "rd"), 63 | "fifth": (5, "th"), 64 | "twelfth": (12, "th"), 65 | **{ 66 | name + ("h" if name.endswith("t") else "th"): (value, "th") 67 | for name, value in self.ones.items() 68 | if value > 3 and value != 5 and value != 12 69 | }, 70 | } 71 | self.ones_suffixed = {**self.ones_plural, **self.ones_ordinal} 72 | 73 | self.tens = { 74 | "twenty": 20, 75 | "thirty": 30, 76 | "forty": 40, 77 | "fifty": 50, 78 | "sixty": 60, 79 | "seventy": 70, 80 | "eighty": 80, 81 | "ninety": 90, 82 | } 83 | self.tens_plural = { 84 | name.replace("y", "ies"): (value, "s") for name, value in self.tens.items() 85 | } 86 | self.tens_ordinal = { 87 | name.replace("y", "ieth"): (value, "th") for name, value in self.tens.items() 88 | } 89 | self.tens_suffixed = {**self.tens_plural, **self.tens_ordinal} 90 | 91 | self.multipliers = { 92 | "hundred": 100, 93 | "thousand": 1_000, 94 | "million": 1_000_000, 95 | "billion": 1_000_000_000, 96 | "trillion": 1_000_000_000_000, 97 | "quadrillion": 1_000_000_000_000_000, 98 | "quintillion": 1_000_000_000_000_000_000, 99 | "sextillion": 1_000_000_000_000_000_000_000, 100 | "septillion": 1_000_000_000_000_000_000_000_000, 101 | "octillion": 1_000_000_000_000_000_000_000_000_000, 102 | "nonillion": 1_000_000_000_000_000_000_000_000_000_000, 103 | "decillion": 1_000_000_000_000_000_000_000_000_000_000_000, 104 | } 105 | self.multipliers_plural = { 106 | name + "s": (value, "s") for name, value in self.multipliers.items() 107 | } 108 | self.multipliers_ordinal = { 109 | name + "th": (value, "th") for name, value in self.multipliers.items() 110 | } 111 | self.multipliers_suffixed = {**self.multipliers_plural, **self.multipliers_ordinal} 112 | self.decimals = {*self.ones, *self.tens, *self.zeros} 113 | 114 | self.preceding_prefixers = { 115 | "minus": "-", 116 | "negative": "-", 117 | "plus": "+", 118 | "positive": "+", 119 | } 120 | self.following_prefixers = { 121 | "pound": "£", 122 | "pounds": "£", 123 | "euro": "€", 124 | "euros": "€", 125 | "dollar": "$", 126 | "dollars": "$", 127 | "cent": "¢", 128 | "cents": "¢", 129 | } 130 | self.prefixes = set( 131 | list(self.preceding_prefixers.values()) + list(self.following_prefixers.values()) 132 | ) 133 | self.suffixers = { 134 | "per": {"cent": "%"}, 135 | "percent": "%", 136 | } 137 | self.specials = {"and", "double", "triple", "point"} 138 | 139 | self.words = set( 140 | [ 141 | key 142 | for mapping in [ 143 | self.zeros, 144 | self.ones, 145 | self.ones_suffixed, 
146 | self.tens, 147 | self.tens_suffixed, 148 | self.multipliers, 149 | self.multipliers_suffixed, 150 | self.preceding_prefixers, 151 | self.following_prefixers, 152 | self.suffixers, 153 | self.specials, 154 | ] 155 | for key in mapping 156 | ] 157 | ) 158 | self.literal_words = {"one", "ones"} 159 | 160 | def process_words(self, words: List[str]) -> Iterator[str]: 161 | prefix: Optional[str] = None 162 | value: Optional[Union[str, int]] = None 163 | skip = False 164 | 165 | def to_fraction(s: str): 166 | try: 167 | return Fraction(s) 168 | except ValueError: 169 | return None 170 | 171 | def output(result: Union[str, int]): 172 | nonlocal prefix, value 173 | result = str(result) 174 | if prefix is not None: 175 | result = prefix + result 176 | value = None 177 | prefix = None 178 | return result 179 | 180 | if len(words) == 0: 181 | return 182 | 183 | for prev, current, next in windowed([None] + words + [None], 3): 184 | if skip: 185 | skip = False 186 | continue 187 | 188 | next_is_numeric = next is not None and re.match(r"^\d+(\.\d+)?$", next) 189 | has_prefix = current[0] in self.prefixes 190 | current_without_prefix = current[1:] if has_prefix else current 191 | if re.match(r"^\d+(\.\d+)?$", current_without_prefix): 192 | # arabic numbers (potentially with signs and fractions) 193 | f = to_fraction(current_without_prefix) 194 | assert f is not None 195 | if value is not None: 196 | if isinstance(value, str) and value.endswith("."): 197 | # concatenate decimals / ip address components 198 | value = str(value) + str(current) 199 | continue 200 | else: 201 | yield output(value) 202 | 203 | prefix = current[0] if has_prefix else prefix 204 | if f.denominator == 1: 205 | value = f.numerator # store integers as int 206 | else: 207 | value = current_without_prefix 208 | elif current not in self.words: 209 | # non-numeric words 210 | if value is not None: 211 | yield output(value) 212 | yield output(current) 213 | elif current in self.zeros: 214 | value = str(value or "") + "0" 215 | elif current in self.ones: 216 | ones = self.ones[current] 217 | 218 | if value is None: 219 | value = ones 220 | elif isinstance(value, str) or prev in self.ones: 221 | if prev in self.tens and ones < 10: # replace the last zero with the digit 222 | assert value[-1] == "0" 223 | value = value[:-1] + str(ones) 224 | else: 225 | value = str(value) + str(ones) 226 | elif ones < 10: 227 | if value % 10 == 0: 228 | value += ones 229 | else: 230 | value = str(value) + str(ones) 231 | else: # eleven to nineteen 232 | if value % 100 == 0: 233 | value += ones 234 | else: 235 | value = str(value) + str(ones) 236 | elif current in self.ones_suffixed: 237 | # ordinal or cardinal; yield the number right away 238 | ones, suffix = self.ones_suffixed[current] 239 | if value is None: 240 | yield output(str(ones) + suffix) 241 | elif isinstance(value, str) or prev in self.ones: 242 | if prev in self.tens and ones < 10: 243 | assert value[-1] == "0" 244 | yield output(value[:-1] + str(ones) + suffix) 245 | else: 246 | yield output(str(value) + str(ones) + suffix) 247 | elif ones < 10: 248 | if value % 10 == 0: 249 | yield output(str(value + ones) + suffix) 250 | else: 251 | yield output(str(value) + str(ones) + suffix) 252 | else: # eleven to nineteen 253 | if value % 100 == 0: 254 | yield output(str(value + ones) + suffix) 255 | else: 256 | yield output(str(value) + str(ones) + suffix) 257 | value = None 258 | elif current in self.tens: 259 | tens = self.tens[current] 260 | if value is None: 261 | value = tens 262 | elif 
isinstance(value, str): 263 | value = str(value) + str(tens) 264 | else: 265 | if value % 100 == 0: 266 | value += tens 267 | else: 268 | value = str(value) + str(tens) 269 | elif current in self.tens_suffixed: 270 | # ordinal or cardinal; yield the number right away 271 | tens, suffix = self.tens_suffixed[current] 272 | if value is None: 273 | yield output(str(tens) + suffix) 274 | elif isinstance(value, str): 275 | yield output(str(value) + str(tens) + suffix) 276 | else: 277 | if value % 100 == 0: 278 | yield output(str(value + tens) + suffix) 279 | else: 280 | yield output(str(value) + str(tens) + suffix) 281 | elif current in self.multipliers: 282 | multiplier = self.multipliers[current] 283 | if value is None: 284 | value = multiplier 285 | elif isinstance(value, str) or value == 0: 286 | f = to_fraction(value) 287 | p = f * multiplier if f is not None else None 288 | if f is not None and p.denominator == 1: 289 | value = p.numerator 290 | else: 291 | yield output(value) 292 | value = multiplier 293 | else: 294 | before = value // 1000 * 1000 295 | residual = value % 1000 296 | value = before + residual * multiplier 297 | elif current in self.multipliers_suffixed: 298 | multiplier, suffix = self.multipliers_suffixed[current] 299 | if value is None: 300 | yield output(str(multiplier) + suffix) 301 | elif isinstance(value, str): 302 | f = to_fraction(value) 303 | p = f * multiplier if f is not None else None 304 | if f is not None and p.denominator == 1: 305 | yield output(str(p.numerator) + suffix) 306 | else: 307 | yield output(value) 308 | yield output(str(multiplier) + suffix) 309 | else: # int 310 | before = value // 1000 * 1000 311 | residual = value % 1000 312 | value = before + residual * multiplier 313 | yield output(str(value) + suffix) 314 | value = None 315 | elif current in self.preceding_prefixers: 316 | # apply prefix (positive, minus, etc.) if it precedes a number 317 | if value is not None: 318 | yield output(value) 319 | 320 | if next in self.words or next_is_numeric: 321 | prefix = self.preceding_prefixers[current] 322 | else: 323 | yield output(current) 324 | elif current in self.following_prefixers: 325 | # apply prefix (dollars, cents, etc.) only after a number 326 | if value is not None: 327 | prefix = self.following_prefixers[current] 328 | yield output(value) 329 | else: 330 | yield output(current) 331 | elif current in self.suffixers: 332 | # apply suffix symbols (percent -> '%') 333 | if value is not None: 334 | suffix = self.suffixers[current] 335 | if isinstance(suffix, dict): 336 | if next in suffix: 337 | yield output(str(value) + suffix[next]) 338 | skip = True 339 | else: 340 | yield output(value) 341 | yield output(current) 342 | else: 343 | yield output(str(value) + suffix) 344 | else: 345 | yield output(current) 346 | elif current in self.specials: 347 | if next not in self.words and not next_is_numeric: 348 | # apply special handling only if the next word can be numeric 349 | if value is not None: 350 | yield output(value) 351 | yield output(current) 352 | elif current == "and": 353 | # ignore "and" after hundreds, thousands, etc. 
354 | if prev not in self.multipliers: 355 | if value is not None: 356 | yield output(value) 357 | yield output(current) 358 | elif current == "double" or current == "triple": 359 | if next in self.ones or next in self.zeros: 360 | repeats = 2 if current == "double" else 3 361 | ones = self.ones.get(next, 0) 362 | value = str(value or "") + str(ones) * repeats 363 | skip = True 364 | else: 365 | if value is not None: 366 | yield output(value) 367 | yield output(current) 368 | elif current == "point": 369 | if next in self.decimals or next_is_numeric: 370 | value = str(value or "") + "." 371 | else: 372 | # should all have been covered at this point 373 | raise ValueError(f"Unexpected token: {current}") 374 | else: 375 | # all should have been covered at this point 376 | raise ValueError(f"Unexpected token: {current}") 377 | 378 | if value is not None: 379 | yield output(value) 380 | 381 | def preprocess(self, s: str): 382 | # replace " and a half" with " point five" 383 | results = [] 384 | 385 | segments = re.split(r"\band\s+a\s+half\b", s) 386 | for i, segment in enumerate(segments): 387 | if len(segment.strip()) == 0: 388 | continue 389 | if i == len(segments) - 1: 390 | results.append(segment) 391 | else: 392 | results.append(segment) 393 | last_word = segment.rsplit(maxsplit=2)[-1] 394 | if last_word in self.decimals or last_word in self.multipliers: 395 | results.append("point five") 396 | else: 397 | results.append("and a half") 398 | 399 | s = " ".join(results) 400 | 401 | # put a space at number/letter boundary 402 | s = re.sub(r"([a-z])([0-9])", r"\1 \2", s) 403 | s = re.sub(r"([0-9])([a-z])", r"\1 \2", s) 404 | 405 | # but remove spaces which could be a suffix 406 | s = re.sub(r"([0-9])\s+(st|nd|rd|th|s)\b", r"\1\2", s) 407 | 408 | return s 409 | 410 | def postprocess(self, s: str): 411 | def combine_cents(m: Match): 412 | try: 413 | currency = m.group(1) 414 | integer = m.group(2) 415 | cents = int(m.group(3)) 416 | return f"{currency}{integer}.{cents:02d}" 417 | except ValueError: 418 | return m.string 419 | 420 | def extract_cents(m: Match): 421 | try: 422 | return f"¢{int(m.group(1))}" 423 | except ValueError: 424 | return m.string 425 | 426 | # apply currency postprocessing; "$2 and ¢7" -> "$2.07" 427 | s = re.sub(r"([€£$])([0-9]+) (?:and )?¢([0-9]{1,2})\b", combine_cents, s) 428 | s = re.sub(r"[€£$]0.([0-9]{1,2})\b", extract_cents, s) 429 | 430 | # write "one(s)" instead of "1(s)", just for the readability 431 | s = re.sub(r"\b1(s?)\b", r"one\1", s) 432 | 433 | return s 434 | 435 | def __call__(self, s: str): 436 | s = self.preprocess(s) 437 | s = " ".join(word for word in self.process_words(s.split()) if word is not None) 438 | s = self.postprocess(s) 439 | 440 | return s 441 | 442 | 443 | class EnglishSpellingNormalizer: 444 | """ 445 | Applies British-American spelling mappings as listed in [1]. 
446 | 447 | [1] https://www.tysto.com/uk-us-spelling-list.html 448 | """ 449 | 450 | def __init__(self): 451 | mapping_path = os.path.join(os.path.dirname(__file__), "english.json") 452 | self.mapping = json.load(open(mapping_path)) 453 | 454 | def __call__(self, s: str): 455 | return " ".join(self.mapping.get(word, word) for word in s.split()) 456 | 457 | 458 | class EnglishTextNormalizer: 459 | def __init__(self): 460 | self.ignore_patterns = r"\b(hmm|mm|mhm|mmm|uh|um)\b" 461 | self.replacers = { 462 | # common contractions 463 | r"\bwon't\b": "will not", 464 | r"\bcan't\b": "can not", 465 | r"\blet's\b": "let us", 466 | r"\bain't\b": "aint", 467 | r"\by'all\b": "you all", 468 | r"\bwanna\b": "want to", 469 | r"\bgotta\b": "got to", 470 | r"\bgonna\b": "going to", 471 | r"\bi'ma\b": "i am going to", 472 | r"\bimma\b": "i am going to", 473 | r"\bwoulda\b": "would have", 474 | r"\bcoulda\b": "could have", 475 | r"\bshoulda\b": "should have", 476 | r"\bma'am\b": "madam", 477 | # contractions in titles/prefixes 478 | r"\bmr\b": "mister ", 479 | r"\bmrs\b": "missus ", 480 | r"\bst\b": "saint ", 481 | r"\bdr\b": "doctor ", 482 | r"\bprof\b": "professor ", 483 | r"\bcapt\b": "captain ", 484 | r"\bgov\b": "governor ", 485 | r"\bald\b": "alderman ", 486 | r"\bgen\b": "general ", 487 | r"\bsen\b": "senator ", 488 | r"\brep\b": "representative ", 489 | r"\bpres\b": "president ", 490 | r"\brev\b": "reverend ", 491 | r"\bhon\b": "honorable ", 492 | r"\basst\b": "assistant ", 493 | r"\bassoc\b": "associate ", 494 | r"\blt\b": "lieutenant ", 495 | r"\bcol\b": "colonel ", 496 | r"\bjr\b": "junior ", 497 | r"\bsr\b": "senior ", 498 | r"\besq\b": "esquire ", 499 | # prefect tenses, ideally it should be any past participles, but it's harder.. 500 | r"'d been\b": " had been", 501 | r"'s been\b": " has been", 502 | r"'d gone\b": " had gone", 503 | r"'s gone\b": " has gone", 504 | r"'d done\b": " had done", # "'s done" is ambiguous 505 | r"'s got\b": " has got", 506 | # general contractions 507 | r"n't\b": " not", 508 | r"'re\b": " are", 509 | r"'s\b": " is", 510 | r"'d\b": " would", 511 | r"'ll\b": " will", 512 | r"'t\b": " not", 513 | r"'ve\b": " have", 514 | r"'m\b": " am", 515 | } 516 | self.standardize_numbers = EnglishNumberNormalizer() 517 | self.standardize_spellings = EnglishSpellingNormalizer() 518 | 519 | def __call__(self, s: str): 520 | s = s.lower() 521 | 522 | s = re.sub(r"[<\[][^>\]]*[>\]]", "", s) # remove words between brackets 523 | s = re.sub(r"\(([^)]+?)\)", "", s) # remove words between parenthesis 524 | s = re.sub(self.ignore_patterns, "", s) 525 | s = re.sub(r"\s+'", "'", s) # standardize when there's a space before an apostrophe 526 | 527 | for pattern, replacement in self.replacers.items(): 528 | s = re.sub(pattern, replacement, s) 529 | 530 | s = re.sub(r"(\d),(\d)", r"\1\2", s) # remove commas between digits 531 | s = re.sub(r"\.([^0-9]|$)", r" \1", s) # remove periods not followed by numbers 532 | s = remove_symbols_and_diacritics(s, keep=".%$¢€£") # keep some symbols for numerics 533 | 534 | s = self.standardize_numbers(s) 535 | s = self.standardize_spellings(s) 536 | 537 | # now remove prefix/suffix symbols that are not preceded/followed by numbers 538 | s = re.sub(r"[.$¢€£]([^0-9])", r" \1", s) 539 | s = re.sub(r"([^0-9])%", r"\1 ", s) 540 | 541 | s = re.sub(r"\s+", " ", s) # replace any successive whitespace characters with a space 542 | 543 | return s 544 | -------------------------------------------------------------------------------- /whisper/tokenizer.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import dataclass 3 | from functools import lru_cache 4 | from typing import List, Optional, Tuple, Union 5 | 6 | import numpy as np 7 | import torch 8 | from transformers import GPT2TokenizerFast 9 | 10 | LANGUAGES = { 11 | "en": "english", 12 | "zh": "chinese", 13 | "de": "german", 14 | "es": "spanish", 15 | "ru": "russian", 16 | "ko": "korean", 17 | "fr": "french", 18 | "ja": "japanese", 19 | "pt": "portuguese", 20 | "tr": "turkish", 21 | "pl": "polish", 22 | "ca": "catalan", 23 | "nl": "dutch", 24 | "ar": "arabic", 25 | "sv": "swedish", 26 | "it": "italian", 27 | "id": "indonesian", 28 | "hi": "hindi", 29 | "fi": "finnish", 30 | "vi": "vietnamese", 31 | "iw": "hebrew", 32 | "uk": "ukrainian", 33 | "el": "greek", 34 | "ms": "malay", 35 | "cs": "czech", 36 | "ro": "romanian", 37 | "da": "danish", 38 | "hu": "hungarian", 39 | "ta": "tamil", 40 | "no": "norwegian", 41 | "th": "thai", 42 | "ur": "urdu", 43 | "hr": "croatian", 44 | "bg": "bulgarian", 45 | "lt": "lithuanian", 46 | "la": "latin", 47 | "mi": "maori", 48 | "ml": "malayalam", 49 | "cy": "welsh", 50 | "sk": "slovak", 51 | "te": "telugu", 52 | "fa": "persian", 53 | "lv": "latvian", 54 | "bn": "bengali", 55 | "sr": "serbian", 56 | "az": "azerbaijani", 57 | "sl": "slovenian", 58 | "kn": "kannada", 59 | "et": "estonian", 60 | "mk": "macedonian", 61 | "br": "breton", 62 | "eu": "basque", 63 | "is": "icelandic", 64 | "hy": "armenian", 65 | "ne": "nepali", 66 | "mn": "mongolian", 67 | "bs": "bosnian", 68 | "kk": "kazakh", 69 | "sq": "albanian", 70 | "sw": "swahili", 71 | "gl": "galician", 72 | "mr": "marathi", 73 | "pa": "punjabi", 74 | "si": "sinhala", 75 | "km": "khmer", 76 | "sn": "shona", 77 | "yo": "yoruba", 78 | "so": "somali", 79 | "af": "afrikaans", 80 | "oc": "occitan", 81 | "ka": "georgian", 82 | "be": "belarusian", 83 | "tg": "tajik", 84 | "sd": "sindhi", 85 | "gu": "gujarati", 86 | "am": "amharic", 87 | "yi": "yiddish", 88 | "lo": "lao", 89 | "uz": "uzbek", 90 | "fo": "faroese", 91 | "ht": "haitian creole", 92 | "ps": "pashto", 93 | "tk": "turkmen", 94 | "nn": "nynorsk", 95 | "mt": "maltese", 96 | "sa": "sanskrit", 97 | "lb": "luxembourgish", 98 | "my": "myanmar", 99 | "bo": "tibetan", 100 | "tl": "tagalog", 101 | "mg": "malagasy", 102 | "as": "assamese", 103 | "tt": "tatar", 104 | "haw": "hawaiian", 105 | "ln": "lingala", 106 | "ha": "hausa", 107 | "ba": "bashkir", 108 | "jw": "javanese", 109 | "su": "sundanese", 110 | } 111 | 112 | # language code lookup by name, with a few language aliases 113 | TO_LANGUAGE_CODE = { 114 | **{language: code for code, language in LANGUAGES.items()}, 115 | "burmese": "my", 116 | "valencian": "ca", 117 | "flemish": "nl", 118 | "haitian": "ht", 119 | "letzeburgesch": "lb", 120 | "pushto": "ps", 121 | "panjabi": "pa", 122 | "moldavian": "ro", 123 | "moldovan": "ro", 124 | "sinhalese": "si", 125 | "castilian": "es", 126 | } 127 | 128 | 129 | @dataclass(frozen=True) 130 | class Tokenizer: 131 | """A thin wrapper around `GPT2TokenizerFast` providing quick access to special tokens""" 132 | 133 | tokenizer: "GPT2TokenizerFast" 134 | language: Optional[str] 135 | sot_sequence: Tuple[int] 136 | 137 | def encode(self, text, **kwargs): 138 | return self.tokenizer.encode(text, **kwargs) 139 | 140 | def decode(self, token_ids: Union[int, List[int], np.ndarray, torch.Tensor], **kwargs): 141 | return self.tokenizer.decode(token_ids, **kwargs) 142 | 143 | def decode_with_timestamps(self, tokens) 
-> str: 144 | """ 145 | Timestamp tokens are above the special tokens' id range and are ignored by `decode()`. 146 | This method decodes given tokens with timestamps tokens annotated, e.g. "<|1.08|>". 147 | """ 148 | outputs = [[]] 149 | for token in tokens: 150 | if token >= self.timestamp_begin: 151 | timestamp = f"<|{(token - self.timestamp_begin) * 0.02:.2f}|>" 152 | outputs.append(timestamp) 153 | outputs.append([]) 154 | else: 155 | outputs[-1].append(token) 156 | outputs = [s if isinstance(s, str) else self.tokenizer.decode(s) for s in outputs] 157 | return "".join(outputs) 158 | 159 | @property 160 | @lru_cache() 161 | def eot(self) -> int: 162 | return self.tokenizer.eos_token_id 163 | 164 | @property 165 | @lru_cache() 166 | def sot(self) -> int: 167 | return self._get_single_token_id("<|startoftranscript|>") 168 | 169 | @property 170 | @lru_cache() 171 | def sot_lm(self) -> int: 172 | return self._get_single_token_id("<|startoflm|>") 173 | 174 | @property 175 | @lru_cache() 176 | def sot_prev(self) -> int: 177 | return self._get_single_token_id("<|startofprev|>") 178 | 179 | @property 180 | @lru_cache() 181 | def no_speech(self) -> int: 182 | return self._get_single_token_id("<|nospeech|>") 183 | 184 | @property 185 | @lru_cache() 186 | def no_timestamps(self) -> int: 187 | return self._get_single_token_id("<|notimestamps|>") 188 | 189 | @property 190 | @lru_cache() 191 | def timestamp_begin(self) -> int: 192 | return self.tokenizer.all_special_ids[-1] + 1 193 | 194 | @property 195 | @lru_cache() 196 | def language_token(self) -> int: 197 | """Returns the token id corresponding to the value of the `language` field""" 198 | if self.language is None: 199 | raise ValueError(f"This tokenizer does not have language token configured") 200 | 201 | additional_tokens = dict( 202 | zip( 203 | self.tokenizer.additional_special_tokens, 204 | self.tokenizer.additional_special_tokens_ids, 205 | ) 206 | ) 207 | candidate = f"<|{self.language}|>" 208 | if candidate in additional_tokens: 209 | return additional_tokens[candidate] 210 | 211 | raise KeyError(f"Language {self.language} not found in tokenizer.") 212 | 213 | @property 214 | @lru_cache() 215 | def all_language_tokens(self) -> Tuple[int]: 216 | result = [] 217 | for token, token_id in zip( 218 | self.tokenizer.additional_special_tokens, 219 | self.tokenizer.additional_special_tokens_ids, 220 | ): 221 | if token.strip("<|>") in LANGUAGES: 222 | result.append(token_id) 223 | return tuple(result) 224 | 225 | @property 226 | @lru_cache() 227 | def all_language_codes(self) -> Tuple[str]: 228 | return tuple(self.decode([l]).strip("<|>") for l in self.all_language_tokens) 229 | 230 | @property 231 | @lru_cache() 232 | def sot_sequence_including_notimestamps(self) -> Tuple[int]: 233 | return tuple(list(self.sot_sequence) + [self.no_timestamps]) 234 | 235 | @property 236 | @lru_cache() 237 | def non_speech_tokens(self) -> Tuple[int]: 238 | """ 239 | Returns the list of tokens to suppress in order to avoid any speaker tags or non-speech 240 | annotations, to prevent sampling texts that are not actually spoken in the audio, e.g. 241 | 242 | - ♪♪♪ 243 | - ( SPEAKING FOREIGN LANGUAGE ) 244 | - [DAVID] Hey there, 245 | 246 | keeping basic punctuations like commas, periods, question marks, exclamation points, etc. 
247 | """ 248 | symbols = list("\"#()*+/:;<=>@[\\]^_`{|}~「」『』") 249 | symbols += "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split() 250 | 251 | # symbols that may be a single token or multiple tokens depending on the tokenizer. 252 | # In case they're multiple tokens, suppress the first token, which is safe because: 253 | # These are between U+2640 and U+267F miscellaneous symbols that are okay to suppress 254 | # in generations, and in the 3-byte UTF-8 representation they share the first two bytes. 255 | miscellaneous = set("♩♪♫♬♭♮♯") 256 | assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous) 257 | 258 | # allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word 259 | result = {self.tokenizer.encode(" -")[0], self.tokenizer.encode(" '")[0]} 260 | for symbol in symbols + list(miscellaneous): 261 | for tokens in [self.tokenizer.encode(symbol), self.tokenizer.encode(" " + symbol)]: 262 | if len(tokens) == 1 or symbol in miscellaneous: 263 | result.add(tokens[0]) 264 | 265 | return tuple(sorted(result)) 266 | 267 | def _get_single_token_id(self, text) -> int: 268 | tokens = self.tokenizer.encode(text) 269 | assert len(tokens) == 1, f"{text} is not encoded as a single token" 270 | return tokens[0] 271 | 272 | 273 | @lru_cache(maxsize=None) 274 | def build_tokenizer(name: str = "gpt2"): 275 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 276 | path = os.path.join(os.path.dirname(__file__), "assets", name) 277 | tokenizer = GPT2TokenizerFast.from_pretrained(path) 278 | 279 | specials = [ 280 | "<|startoftranscript|>", 281 | *[f"<|{lang}|>" for lang in LANGUAGES.keys()], 282 | "<|translate|>", 283 | "<|transcribe|>", 284 | "<|startoflm|>", 285 | "<|startofprev|>", 286 | "<|nospeech|>", 287 | "<|notimestamps|>", 288 | ] 289 | 290 | tokenizer.add_special_tokens(dict(additional_special_tokens=specials)) 291 | return tokenizer 292 | 293 | 294 | @lru_cache(maxsize=None) 295 | def get_tokenizer( 296 | multilingual: bool, 297 | *, 298 | task: Optional[str] = None, # Literal["transcribe", "translate", None] 299 | language: Optional[str] = None, 300 | ) -> Tokenizer: 301 | if language is not None: 302 | language = language.lower() 303 | if language not in LANGUAGES: 304 | if language in TO_LANGUAGE_CODE: 305 | language = TO_LANGUAGE_CODE[language] 306 | else: 307 | raise ValueError(f"Unsupported language: {language}") 308 | 309 | if multilingual: 310 | tokenizer_name = "multilingual" 311 | task = task or "transcribe" 312 | language = language or "en" 313 | else: 314 | tokenizer_name = "gpt2" 315 | task = None 316 | language = None 317 | 318 | tokenizer = build_tokenizer(name=tokenizer_name) 319 | all_special_ids: List[int] = tokenizer.all_special_ids 320 | sot: int = all_special_ids[1] 321 | translate: int = all_special_ids[-6] 322 | transcribe: int = all_special_ids[-5] 323 | 324 | langs = tuple(LANGUAGES.keys()) 325 | sot_sequence = [sot] 326 | if language is not None: 327 | sot_sequence.append(sot + 1 + langs.index(language)) 328 | if task is not None: 329 | sot_sequence.append(transcribe if task == "transcribe" else translate) 330 | 331 | return Tokenizer(tokenizer=tokenizer, language=language, sot_sequence=tuple(sot_sequence)) 332 | -------------------------------------------------------------------------------- /whisper/transcribe.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import warnings 4 | from typing import List, Optional, Tuple, Union, 
TYPE_CHECKING 5 | 6 | import numpy as np 7 | import torch 8 | import tqdm 9 | 10 | from .audio import SAMPLE_RATE, N_FRAMES, HOP_LENGTH, pad_or_trim, log_mel_spectrogram 11 | from .decoding import DecodingOptions, DecodingResult 12 | from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer 13 | from .utils import exact_div, format_timestamp, optional_int, optional_float, str2bool, write_txt, write_vtt, write_srt 14 | 15 | if TYPE_CHECKING: 16 | from .model import Whisper 17 | 18 | 19 | def transcribe( 20 | model: "Whisper", 21 | audio: Union[str, np.ndarray, torch.Tensor], 22 | *, 23 | verbose: Optional[bool] = None, 24 | temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0), 25 | compression_ratio_threshold: Optional[float] = 2.4, 26 | logprob_threshold: Optional[float] = -1.0, 27 | no_speech_threshold: Optional[float] = 0.6, 28 | condition_on_previous_text: bool = True, 29 | **decode_options, 30 | ): 31 | """ 32 | Transcribe an audio file using Whisper 33 | 34 | Parameters 35 | ---------- 36 | model: Whisper 37 | The Whisper model instance 38 | 39 | audio: Union[str, np.ndarray, torch.Tensor] 40 | The path to the audio file to open, or the audio waveform 41 | 42 | verbose: bool 43 | Whether to display the text being decoded to the console. If True, displays all the details. 44 | If False, displays minimal details. If None, does not display anything. 45 | 46 | temperature: Union[float, Tuple[float, ...]] 47 | Temperature for sampling. It can be a tuple of temperatures, which will be successively used 48 | upon failures according to either `compression_ratio_threshold` or `logprob_threshold`. 49 | 50 | compression_ratio_threshold: float 51 | If the gzip compression ratio is above this value, treat as failed 52 | 53 | logprob_threshold: float 54 | If the average log probability over sampled tokens is below this value, treat as failed 55 | 56 | no_speech_threshold: float 57 | If the no_speech probability is higher than this value AND the average log probability 58 | over sampled tokens is below `logprob_threshold`, consider the segment as silent 59 | 60 | condition_on_previous_text: bool 61 | if True, the previous output of the model is provided as a prompt for the next window; 62 | disabling may make the text inconsistent across windows, but the model becomes less prone to 63 | getting stuck in a failure loop, such as repetition looping or timestamps going out of sync. 64 | 65 | decode_options: dict 66 | Keyword arguments to construct `DecodingOptions` instances 67 | 68 | Returns 69 | ------- 70 | A dictionary containing the resulting text ("text") and segment-level details ("segments"), and 71 | the spoken language ("language"), which is detected when `decode_options["language"]` is None. 72 | """ 73 | dtype = torch.float16 if decode_options.get("fp16", True) else torch.float32 74 | if model.device == torch.device("cpu"): 75 | if torch.cuda.is_available(): 76 | warnings.warn("Performing inference on CPU when CUDA is available") 77 | if dtype == torch.float16: 78 | warnings.warn("FP16 is not supported on CPU; using FP32 instead") 79 | dtype = torch.float32 80 | 81 | if dtype == torch.float32: 82 | decode_options["fp16"] = False 83 | 84 | mel = log_mel_spectrogram(audio) 85 | 86 | if decode_options.get("language", None) is None: 87 | if not model.is_multilingual: 88 | decode_options["language"] = "en" 89 | else: 90 | if verbose: 91 | print("Detecting language using up to the first 30 seconds. 
Use `--language` to specify the language") 92 | segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype) 93 | _, probs = model.detect_language(segment) 94 | decode_options["language"] = max(probs, key=probs.get) 95 | if verbose is not None: 96 | print(f"Detected language: {LANGUAGES[decode_options['language']].title()}") 97 | 98 | language = decode_options["language"] 99 | task = decode_options.get("task", "transcribe") 100 | tokenizer = get_tokenizer(model.is_multilingual, language=language, task=task) 101 | 102 | def decode_with_fallback(segment: torch.Tensor) -> DecodingResult: 103 | temperatures = [temperature] if isinstance(temperature, (int, float)) else temperature 104 | decode_result = None 105 | 106 | for t in temperatures: 107 | kwargs = {**decode_options} 108 | if t > 0: 109 | # disable beam_size and patience when t > 0 110 | kwargs.pop("beam_size", None) 111 | kwargs.pop("patience", None) 112 | else: 113 | # disable best_of when t == 0 114 | kwargs.pop("best_of", None) 115 | 116 | options = DecodingOptions(**kwargs, temperature=t) 117 | decode_result = model.decode(segment, options) 118 | 119 | needs_fallback = False 120 | if compression_ratio_threshold is not None and decode_result.compression_ratio > compression_ratio_threshold: 121 | needs_fallback = True # too repetitive 122 | if logprob_threshold is not None and decode_result.avg_logprob < logprob_threshold: 123 | needs_fallback = True # average log probability is too low 124 | 125 | if not needs_fallback: 126 | break 127 | 128 | return decode_result 129 | 130 | seek = 0 131 | input_stride = exact_div( 132 | N_FRAMES, model.dims.n_audio_ctx 133 | ) # mel frames per output token: 2 134 | time_precision = ( 135 | input_stride * HOP_LENGTH / SAMPLE_RATE 136 | ) # time per output token: 0.02 (seconds) 137 | all_tokens = [] 138 | all_segments = [] 139 | prompt_reset_since = 0 140 | 141 | initial_prompt = decode_options.pop("initial_prompt", None) or [] 142 | if initial_prompt: 143 | initial_prompt = tokenizer.encode(" " + initial_prompt.strip()) 144 | all_tokens.extend(initial_prompt) 145 | 146 | def add_segment( 147 | *, start: float, end: float, text_tokens: torch.Tensor, result: DecodingResult 148 | ): 149 | text = tokenizer.decode([token for token in text_tokens if token < tokenizer.eot]) 150 | if len(text.strip()) == 0: # skip empty text output 151 | return 152 | 153 | all_segments.append( 154 | { 155 | "id": len(all_segments), 156 | "seek": seek, 157 | "start": start, 158 | "end": end, 159 | "text": text, 160 | "tokens": text_tokens.tolist(), 161 | "temperature": result.temperature, 162 | "avg_logprob": result.avg_logprob, 163 | "compression_ratio": result.compression_ratio, 164 | "no_speech_prob": result.no_speech_prob, 165 | } 166 | ) 167 | if verbose: 168 | print(f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}") 169 | 170 | # show the progress bar when verbose is False (otherwise the transcribed text will be printed) 171 | num_frames = mel.shape[-1] 172 | previous_seek_value = seek 173 | 174 | with tqdm.tqdm(total=num_frames, unit='frames', disable=verbose is not False) as pbar: 175 | while seek < num_frames: 176 | timestamp_offset = float(seek * HOP_LENGTH / SAMPLE_RATE) 177 | segment = pad_or_trim(mel[:, seek:], N_FRAMES).to(model.device).to(dtype) 178 | segment_duration = segment.shape[-1] * HOP_LENGTH / SAMPLE_RATE 179 | 180 | decode_options["prompt"] = all_tokens[prompt_reset_since:] 181 | result: DecodingResult = decode_with_fallback(segment) 182 | tokens = torch.tensor(result.tokens) 183 
| 184 | if no_speech_threshold is not None: 185 | # no voice activity check 186 | should_skip = result.no_speech_prob > no_speech_threshold 187 | if logprob_threshold is not None and result.avg_logprob > logprob_threshold: 188 | # don't skip if the logprob is high enough, despite the no_speech_prob 189 | should_skip = False 190 | 191 | if should_skip: 192 | seek += segment.shape[-1] # fast-forward to the next segment boundary 193 | continue 194 | 195 | timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin) 196 | consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0].add_(1) 197 | if len(consecutive) > 0: # if the output contains two consecutive timestamp tokens 198 | last_slice = 0 199 | for current_slice in consecutive: 200 | sliced_tokens = tokens[last_slice:current_slice] 201 | start_timestamp_position = ( 202 | sliced_tokens[0].item() - tokenizer.timestamp_begin 203 | ) 204 | end_timestamp_position = ( 205 | sliced_tokens[-1].item() - tokenizer.timestamp_begin 206 | ) 207 | add_segment( 208 | start=timestamp_offset + start_timestamp_position * time_precision, 209 | end=timestamp_offset + end_timestamp_position * time_precision, 210 | text_tokens=sliced_tokens[1:-1], 211 | result=result, 212 | ) 213 | last_slice = current_slice 214 | last_timestamp_position = ( 215 | tokens[last_slice - 1].item() - tokenizer.timestamp_begin 216 | ) 217 | seek += last_timestamp_position * input_stride 218 | all_tokens.extend(tokens[: last_slice + 1].tolist()) 219 | else: 220 | duration = segment_duration 221 | timestamps = tokens[timestamp_tokens.nonzero().flatten()] 222 | if len(timestamps) > 0 and timestamps[-1].item() != tokenizer.timestamp_begin: 223 | # no consecutive timestamps but it has a timestamp; use the last one. 224 | # single timestamp at the end means no speech after the last timestamp. 225 | last_timestamp_position = timestamps[-1].item() - tokenizer.timestamp_begin 226 | duration = last_timestamp_position * time_precision 227 | 228 | add_segment( 229 | start=timestamp_offset, 230 | end=timestamp_offset + duration, 231 | text_tokens=tokens, 232 | result=result, 233 | ) 234 | 235 | seek += segment.shape[-1] 236 | all_tokens.extend(tokens.tolist()) 237 | 238 | if not condition_on_previous_text or result.temperature > 0.5: 239 | # do not feed the prompt tokens if a high temperature was used 240 | prompt_reset_since = len(all_tokens) 241 | 242 | # update progress bar 243 | pbar.update(min(num_frames, seek) - previous_seek_value) 244 | previous_seek_value = seek 245 | 246 | return dict(text=tokenizer.decode(all_tokens[len(initial_prompt):]), segments=all_segments, language=language) 247 | 248 | 249 | def cli(): 250 | from . 
import available_models 251 | 252 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 253 | parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe") 254 | parser.add_argument("--model", default="small", choices=available_models(), help="name of the Whisper model to use") 255 | parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default") 256 | parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference") 257 | parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs") 258 | parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages") 259 | 260 | parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')") 261 | parser.add_argument("--language", type=str, default=None, choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]), help="language spoken in the audio, specify None to perform language detection") 262 | 263 | parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling") 264 | parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature") 265 | parser.add_argument("--beam_size", type=optional_int, default=5, help="number of beams in beam search, only applicable when temperature is zero") 266 | parser.add_argument("--patience", type=float, default=None, help="optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search") 267 | parser.add_argument("--length_penalty", type=float, default=None, help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple length normalization by default") 268 | 269 | parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations") 270 | parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.") 271 | parser.add_argument("--condition_on_previous_text", type=str2bool, default=True, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop") 272 | parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default") 273 | 274 | parser.add_argument("--temperature_increment_on_fallback", type=optional_float, default=0.2, help="temperature to increase when falling back when the decoding fails to meet either of the thresholds below") 275 | parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, help="if the gzip compression ratio is higher than this value, treat the decoding as failed") 276 | parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed") 277 | 
parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence") 278 | parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS") 279 | 280 | args = parser.parse_args().__dict__ 281 | model_name: str = args.pop("model") 282 | model_dir: str = args.pop("model_dir") 283 | output_dir: str = args.pop("output_dir") 284 | device: str = args.pop("device") 285 | os.makedirs(output_dir, exist_ok=True) 286 | 287 | if model_name.endswith(".en") and args["language"] not in {"en", "English"}: 288 | if args["language"] is not None: 289 | warnings.warn(f"{model_name} is an English-only model but receipted '{args['language']}'; using English instead.") 290 | args["language"] = "en" 291 | 292 | temperature = args.pop("temperature") 293 | temperature_increment_on_fallback = args.pop("temperature_increment_on_fallback") 294 | if temperature_increment_on_fallback is not None: 295 | temperature = tuple(np.arange(temperature, 1.0 + 1e-6, temperature_increment_on_fallback)) 296 | else: 297 | temperature = [temperature] 298 | 299 | threads = args.pop("threads") 300 | if threads > 0: 301 | torch.set_num_threads(threads) 302 | 303 | from . import load_model 304 | model = load_model(model_name, device=device, download_root=model_dir) 305 | 306 | for audio_path in args.pop("audio"): 307 | result = transcribe(model, audio_path, temperature=temperature, **args) 308 | 309 | audio_basename = os.path.basename(audio_path) 310 | 311 | # save TXT 312 | with open(os.path.join(output_dir, audio_basename + ".txt"), "w", encoding="utf-8") as txt: 313 | write_txt(result["segments"], file=txt) 314 | 315 | # save VTT 316 | with open(os.path.join(output_dir, audio_basename + ".vtt"), "w", encoding="utf-8") as vtt: 317 | write_vtt(result["segments"], file=vtt) 318 | 319 | # save SRT 320 | with open(os.path.join(output_dir, audio_basename + ".srt"), "w", encoding="utf-8") as srt: 321 | write_srt(result["segments"], file=srt) 322 | 323 | 324 | if __name__ == '__main__': 325 | cli() 326 | -------------------------------------------------------------------------------- /whisper/utils.py: -------------------------------------------------------------------------------- 1 | import zlib 2 | from typing import Iterator, TextIO 3 | 4 | 5 | def exact_div(x, y): 6 | assert x % y == 0 7 | return x // y 8 | 9 | 10 | def str2bool(string): 11 | str2val = {"True": True, "False": False} 12 | if string in str2val: 13 | return str2val[string] 14 | else: 15 | raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}") 16 | 17 | 18 | def optional_int(string): 19 | return None if string == "None" else int(string) 20 | 21 | 22 | def optional_float(string): 23 | return None if string == "None" else float(string) 24 | 25 | 26 | def compression_ratio(text) -> float: 27 | return len(text) / len(zlib.compress(text.encode("utf-8"))) 28 | 29 | 30 | def format_timestamp(seconds: float, always_include_hours: bool = False, decimal_marker: str = '.'): 31 | assert seconds >= 0, "non-negative timestamp expected" 32 | milliseconds = round(seconds * 1000.0) 33 | 34 | hours = milliseconds // 3_600_000 35 | milliseconds -= hours * 3_600_000 36 | 37 | minutes = milliseconds // 60_000 38 | milliseconds -= minutes * 60_000 39 | 40 | seconds = milliseconds // 1_000 41 | milliseconds 
-= seconds * 1_000 42 | 43 | hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else "" 44 | return f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}" 45 | 46 | 47 | def write_txt(transcript: Iterator[dict], file: TextIO): 48 | for segment in transcript: 49 | print(segment['text'].strip(), file=file, flush=True) 50 | 51 | 52 | def write_vtt(transcript: Iterator[dict], file: TextIO): 53 | print("WEBVTT\n", file=file) 54 | for segment in transcript: 55 | print( 56 | f"{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}\n" 57 | f"{segment['text'].strip().replace('-->', '->')}\n", 58 | file=file, 59 | flush=True, 60 | ) 61 | 62 | 63 | def write_srt(transcript: Iterator[dict], file: TextIO): 64 | """ 65 | Write a transcript to a file in SRT format. 66 | 67 | Example usage: 68 | from pathlib import Path 69 | from whisper.utils import write_srt 70 | 71 | result = transcribe(model, audio_path, temperature=temperature, **args) 72 | 73 | # save SRT 74 | audio_basename = Path(audio_path).stem 75 | with open(Path(output_dir) / (audio_basename + ".srt"), "w", encoding="utf-8") as srt: 76 | write_srt(result["segments"], file=srt) 77 | """ 78 | for i, segment in enumerate(transcript, start=1): 79 | # write srt lines 80 | print( 81 | f"{i}\n" 82 | f"{format_timestamp(segment['start'], always_include_hours=True, decimal_marker=',')} --> " 83 | f"{format_timestamp(segment['end'], always_include_hours=True, decimal_marker=',')}\n" 84 | f"{segment['text'].strip().replace('-->', '->')}\n", 85 | file=file, 86 | flush=True, 87 | ) 88 | --------------------------------------------------------------------------------
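
Taken together, the files above cover the whole pipeline: `audio.py` produces the log-Mel spectrograms, `model.py` defines the encoder/decoder with explicit key/value cache tensors, `decoding.py` and `transcribe.py` expose `decode()` and `transcribe()`, and `utils.py` formats timestamps and writes the TXT/VTT/SRT output. The snippet below is not part of the repository; it is a minimal usage sketch assuming the package is importable as `whisper`, that a local file `audio.flac` exists (a hypothetical path), that `"base"` is one of the available model names, and that the `DecodingResult` returned by `decode()` exposes the decoded text as `.text`, as in upstream Whisper.

```python
import whisper
from whisper.audio import N_FRAMES, log_mel_spectrogram, pad_or_trim
from whisper.decoding import DecodingOptions

# Load a checkpoint; the model name here is illustrative.
model = whisper.load_model("base")

# High-level API: transcribe() splits the audio into 30-second windows,
# detects the language when none is given, and returns text, segments, and language.
result = model.transcribe("audio.flac")  # hypothetical input file
print(result["text"])

# Lower-level API: decode() operates on a single padded/trimmed 30-second log-Mel segment.
mel = log_mel_spectrogram("audio.flac")                # shape (80, n_frames)
segment = pad_or_trim(mel, N_FRAMES).to(model.device)  # shape (80, 3000)
options = DecodingOptions(task="transcribe", fp16=False)
decoding = model.decode(segment, options)
print(decoding.text)  # assumed field, mirroring the upstream DecodingResult dataclass
```

The `cli()` entry point in `transcribe.py` wraps the same calls and then writes the TXT, VTT, and SRT files using the helpers in `utils.py`.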