├── .flake8 ├── .gitattributes ├── .github └── workflows │ ├── python-publish.yml │ └── test.yml ├── .gitignore ├── CHANGELOG.md ├── LICENSE ├── MANIFEST.in ├── README.md ├── approach.png ├── data ├── README.md └── meanwhile.json ├── landing-page-metrics.png ├── language-breakdown.svg ├── model-card.md ├── notebooks ├── Demo_YouTube.ipynb ├── LibriSpeech.ipynb └── Multilingual_ASR.ipynb ├── pyproject.toml ├── requirements.txt ├── setup.py ├── tdrz_dev ├── README.md ├── barplots.png ├── extra_requirements.txt ├── notebooks │ ├── analysis.ipynb │ ├── analysis_utils.py │ └── duplicate-serialization.png ├── score.py ├── score_fstalign.sh └── scripts │ ├── diarize_post_sr.py │ ├── diarize_pre_sr.py │ ├── fetch_earnings21_calls.sh │ └── run_pipelines.py ├── tests ├── conftest.py ├── jfk.flac ├── test_audio.py ├── test_normalizer.py ├── test_timing.py ├── test_tokenizer.py └── test_transcribe.py ├── trim-tinydiarize.gif └── whisper ├── __init__.py ├── __main__.py ├── assets ├── gpt2 │ ├── merges.txt │ ├── special_tokens_map.json │ ├── tokenizer_config.json │ └── vocab.json ├── mel_filters.npz └── multilingual │ ├── added_tokens.json │ ├── merges.txt │ ├── special_tokens_map.json │ ├── tokenizer_config.json │ └── vocab.json ├── audio.py ├── decoding.py ├── model.py ├── normalizers ├── __init__.py ├── basic.py ├── english.json └── english.py ├── timing.py ├── tokenizer.py ├── transcribe.py ├── triton_ops.py ├── utils.py └── version.py /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | per-file-ignores = 3 | */__init__.py: F401 4 | 5 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | # Override jupyter in Github language stats for more accurate estimate of repo code languages 2 | # reference: https://github.com/github/linguist/blob/master/docs/overrides.md#generated-code 3 | *.ipynb 
linguist-generated 4 | 5 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | name: Release 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | jobs: 8 | deploy: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - uses: actions/checkout@v2 12 | - uses: actions-ecosystem/action-regex-match@v2 13 | id: regex-match 14 | with: 15 | text: ${{ github.event.head_commit.message }} 16 | regex: '^Release ([^ ]+)' 17 | - name: Set up Python 18 | uses: actions/setup-python@v2 19 | with: 20 | python-version: '3.8' 21 | - name: Install dependencies 22 | run: | 23 | python -m pip install --upgrade pip 24 | pip install setuptools wheel twine 25 | - name: Release 26 | if: ${{ steps.regex-match.outputs.match != '' }} 27 | uses: softprops/action-gh-release@v1 28 | with: 29 | tag_name: v${{ steps.regex-match.outputs.group1 }} 30 | - name: Build and publish 31 | if: ${{ steps.regex-match.outputs.match != '' }} 32 | env: 33 | TWINE_USERNAME: __token__ 34 | TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }} 35 | run: | 36 | python setup.py sdist 37 | twine upload dist/* 38 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: test 2 | on: 3 | push: 4 | branches: 5 | - main 6 | pull_request: 7 | branches: 8 | - main 9 | jobs: 10 | whisper-test: 11 | runs-on: ubuntu-latest 12 | strategy: 13 | matrix: 14 | python-version: ['3.8', '3.9', '3.10'] 15 | pytorch-version: [1.10.2, 1.13.1] 16 | exclude: 17 | - python-version: '3.10' 18 | pytorch-version: 1.10.2 19 | steps: 20 | - uses: conda-incubator/setup-miniconda@v2 21 | - run: conda install -n test ffmpeg python=${{ matrix.python-version }} pytorch=${{ matrix.pytorch-version }} cpuonly -c pytorch 22 | - uses: actions/checkout@v2 23 | - run: echo 
"$CONDA/envs/test/bin" >> $GITHUB_PATH 24 | - run: pip install .["dev"] 25 | - run: black --check --diff -t py38 --include '(\.pyi?)$' . 26 | - run: isort --check --diff . 27 | - run: flake8 --ignore E203,W503,W504,E501,E731,E741 . 28 | - run: pytest --durations=0 -vv -k 'not test_transcribe or test_transcribe[tiny] or test_transcribe[tiny.en]' -m 'not requires_cuda' 29 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | *.py[cod] 3 | *$py.class 4 | *.egg-info 5 | .pytest_cache 6 | .ipynb_checkpoints 7 | 8 | thumbs.db 9 | .DS_Store 10 | .idea 11 | 12 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # CHANGELOG 2 | 3 | ## [v20230308](https://github.com/openai/whisper/releases/tag/v20230308) 4 | 5 | * kwargs in decode() for convenience ([#1061](https://github.com/openai/whisper/pull/1061)) 6 | * fix all_tokens handling that caused more repetitions and discrepancy in JSON ([#1060](https://github.com/openai/whisper/pull/1060)) 7 | * fix typo in CHANGELOG.md 8 | 9 | ## [v20230307](https://github.com/openai/whisper/releases/tag/v20230307) 10 | 11 | * Fix the repetition/hallucination issue identified in #1046 ([#1052](https://github.com/openai/whisper/pull/1052)) 12 | * Use triton==2.0.0 ([#1053](https://github.com/openai/whisper/pull/1053)) 13 | * Install triton in x86_64 linux only ([#1051](https://github.com/openai/whisper/pull/1051)) 14 | * update setup.py to specify python >= 3.8 requirement 15 | 16 | ## [v20230306](https://github.com/openai/whisper/releases/tag/v20230306) 17 | 18 | * remove auxiliary audio extension ([#1021](https://github.com/openai/whisper/pull/1021)) 19 | * apply formatting with `black`, `isort`, and `flake8` ([#1038](https://github.com/openai/whisper/pull/1038)) 20 | * 
word-level timestamps in `transcribe()` ([#869](https://github.com/openai/whisper/pull/869)) 21 | * Decoding improvements ([#1033](https://github.com/openai/whisper/pull/1033)) 22 | * Update README.md ([#894](https://github.com/openai/whisper/pull/894)) 23 | * Fix infinite loop caused by incorrect timestamp tokens prediction ([#914](https://github.com/openai/whisper/pull/914)) 24 | * drop python 3.7 support ([#889](https://github.com/openai/whisper/pull/889)) 25 | 26 | ## [v20230124](https://github.com/openai/whisper/releases/tag/v20230124) 27 | 28 | * handle printing even if sys.stdout.buffer is not available ([#887](https://github.com/openai/whisper/pull/887)) 29 | * Add TSV formatted output in transcript, using integer start/end time in milliseconds ([#228](https://github.com/openai/whisper/pull/228)) 30 | * Added `--output_format` option ([#333](https://github.com/openai/whisper/pull/333)) 31 | * Handle `XDG_CACHE_HOME` properly for `download_root` ([#864](https://github.com/openai/whisper/pull/864)) 32 | * use stdout for printing transcription progress ([#867](https://github.com/openai/whisper/pull/867)) 33 | * Fix bug where mm is mistakenly replaced with hmm in e.g. 20mm ([#659](https://github.com/openai/whisper/pull/659)) 34 | * print '?' 
if a letter can't be encoded using the system default encoding ([#859](https://github.com/openai/whisper/pull/859)) 35 | 36 | ## [v20230117](https://github.com/openai/whisper/releases/tag/v20230117) 37 | 38 | The first versioned release available on [PyPI](https://pypi.org/project/openai-whisper/) 39 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 OpenAI 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include requirements.txt 2 | include README.md 3 | include LICENSE 4 | include whisper/assets/* 5 | include whisper/assets/gpt2/* 6 | include whisper/assets/multilingual/* 7 | include whisper/normalizers/english.json 8 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # tinydiarize 🐥🗣️ 2 | 3 | - *Speaker diarization* labels who said what in a transcript (e.g. Speaker A, Speaker B …). It is essential for conversation transcripts like meetings or podcasts. 4 | - *tinydiarize* aims to be a minimal, interpretable extension of OpenAI's [Whisper](https://github.com/openai/whisper) models that adds speaker diarization with few extra dependencies (inspired by [minGPT](https://github.com/karpathy/minGPT)). 5 | - This uses a finetuned model that adds special tokens to mark speaker changes [[1,2,3,4]](#references). It can use *both voice and semantic context to tell speakers apart*, which is a unique benefit of this approach. 6 | - You can refer to [tdrz_dev](https://github.com/akashmjn/tinydiarize/tree/main/tdrz_dev) for a detailed analysis of performance. Note that this is intended to be a prototype/proof-of-concept. 7 | - Experimental support is also added to [whisper.cpp](https://github.com/ggerganov/whisper.cpp#speaker-segmentation-via-tinydiarize-experimental) so this can run on consumer hardware like MacBooks and iPhones. A tiny change is needed to original inference code (<50 lines), enabling simple and cheap speaker segmentation, compared with conventional approaches. 
8 | 9 | 10 | ## Demo 11 | 12 | https://user-images.githubusercontent.com/13268767/229617067-eca0f614-d334-480d-9801-7c30d88acdc6.mp4 13 | 14 | You can try it out on other such gems from YouTube using this notebook. [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/akashmjn/tinydiarize/blob/main/notebooks/Demo_YouTube.ipynb) 15 | 16 | 17 | ## Quickstart 18 | 19 | Install `ffmpeg` following the [original repo](https://github.com/openai/whisper#Setup), then run: 20 | 21 | ``` 22 | pip install -e . 23 | whisper --model small.en-tdrz AUDIO 24 | ``` 25 | 26 | The only change is the `small.en-tdrz` model instead of `small.en`. That's it! 🎉 27 | 28 | 29 | ## What's included? 30 | 31 | - Finetuned checkpoint for the `small.en-tdrz` model (located [here](whisper/__init__.py)) and example inference code (relevant edits in [[#4]](https://github.com/akashmjn/tinydiarize/pull/4) [[#11]](https://github.com/akashmjn/tinydiarize/pull/11)). This has the same dependencies as the original whisper repo. 32 | - Tools for comparison and analysis (under [/tdrz_dev](tdrz_dev)): 33 | - A scoring tool to measure and compare accuracy on your own data in an easy to interpret way. 34 | - A reference script to run and compare various diarization pipelines. 35 | - A Jupyter notebook to compare and understand performance in detail. 36 | - See [Roadmap](#roadmap) for more info. 37 | 38 | We aim to demonstrate a starting point enabling anyone (or even OpenAI themselves!) to improve performance and extend support (multilingual, speech translation etc.). 
39 | 40 | ## Performance 41 | 42 | |metric|small.en|small.en-tdrz| 43 | |:----|:----|:----| 44 | |spk_turn_precision|-|97.7| 45 | |spk_turn_recall|-|70.8| 46 | |wer_overall|11.0|10.3| 47 | |wer_speaker_switch|15.0|15.5| 48 | 49 | On a (tiny) benchmark set of 3 [earnings calls](https://github.com/revdotcom/speech-datasets/tree/main/earnings21), `tdrz` gets near-perfect speaker turn precision at fairly decent recall. A similar WER is retained as the original model. Not too shabby for a tiny finetuning setup, and <10% extra inference cost! 50 | 51 | Refer to [tdrz_dev](tdrz_dev/) for details on performance analysis and comparisons. 52 | 53 | ## More info 54 | - Whisper `small.en` checkpoints were finetuned on ~100hrs of [AMI meetings](https://groups.inf.ed.ac.uk/ami/corpus/) using HuggingFace [Transformers](https://github.com/huggingface/transformers) and [Datasets](https://github.com/huggingface/datasets). 55 | - With some tricks, this could be done relatively cheaply with just 30mins of 1 GPU training starting to produce decent results. Tiny indeed 😊. 56 | - We used helpful tools from [pyannote](https://github.com/pyannote/pyannote-core) (the OG open-source diarization toolkit) for finetuning data preparation and also analyze its performance. 57 | - We make use of the excellent open-source [revdotcom/fstalign](https://github.com/revdotcom/fstalign) tool for scoring and analysis. 58 | 59 | ## Gotchas 60 | 61 | Note that this is still an early proof-of-concept and there are a few things to be aware of: 62 | - Only the `small.en` English model has been finetuned. 63 | - Word-error-rate (WER) is close to original models, although not yet extensively tested. Ad-hoc inspection does show some differences in timestamp behavior (longer segments) or deletion errors. See the notebook under [tdrz_dev](tdrz_dev/) for details. 64 | - Given a pretty tiny finetuning setup, there's likely a lot of room for further accuracy improvements. 
65 | - Only local diarization (segmentation into speaker turns) is handled so far. Extension with global diarization (speaker clustering) is planned for later. 66 | - Stuff is still hacky and subject to change, so hold your horses just yet! 🐎 67 | 68 | ## Roadmap 69 | - [x] inference code & demo 70 | - [x] scoring and analysis tools 71 | - [x] [whisper.cpp integration](https://github.com/ggerganov/whisper.cpp/pull/1058) 72 | - [ ] *reproducible dataprep + finetuning\** 73 | - [ ] *blog post explainer\** 74 | - [ ] HuggingFace integration 75 | - [ ] better LoRa-based `small.en` checkpoint 76 | - [ ] possibly clustering with [NME-SC](https://github.com/tango4j/Auto-Tuning-Spectral-Clustering)? 77 | - [ ] possibly `large-v2` checkpoint? 78 | 79 | *\* is a pointer to the current state of the repo. Please see https://github.com/akashmjn/tinydiarize/issues/14 for an update on plans. TLDR; things have had to be put on pause :/* 80 | 81 | ## References 82 | 83 | [[1]](https://arxiv.org/abs/1907.05337) Joint Speech Recognition and Speaker Diarization via Sequence Transduction 84 | [[2]](https://arxiv.org/abs/2003.12687) Serialized Output Training for End-to-End Overlapped Speech Recognition 85 | [[3]](https://arxiv.org/abs/2109.11641) Turn-to-Diarize: Online Speaker Diarization Constrained by Transformer Transducer Speaker Turn Detection 86 | [[4]](https://arxiv.org/abs/2305.18747) Adapting Multi-Lingual ASR Models for Handling Multiple Talkers 87 | 88 | For information on the underlying Whisper model, please refer to the [original documentation (release: `20230308`)](https://github.com/openai/whisper/tree/v20230308) 89 | 90 | ## License 91 | 92 | Code and model weights are released under the MIT License. See [LICENSE](https://github.com/openai/whisper/blob/main/LICENSE) for further details. 
93 | 94 | ## Citation 95 | 96 | If you use this in your research, you can cite this work as 97 | ``` 98 | @software{mahajan2023tinydiarize, 99 | author = {Mahajan, Akash}, 100 | month = {08}, 101 | title = {tinydiarize: Minimal extension of Whisper for speaker segmentation with special tokens}, 102 | url = {https://github.com/akashmjn/tinyDiarize}, 103 | year = {2023} 104 | } 105 | ``` 106 | 107 | -------------------------------------------------------------------------------- /approach.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/akashmjn/tinydiarize/7cba47def707514fe68bbd6663f97663ea482158/approach.png -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | This directory supplements the paper with more details on how we prepared the data for evaluation, to help replicate our experiments. 2 | 3 | ## Short-form English-only datasets 4 | 5 | ### LibriSpeech 6 | 7 | We used the test-clean and test-other splits from the [LibriSpeech ASR corpus](https://www.openslr.org/12). 8 | 9 | ### TED-LIUM 3 10 | 11 | We used the test split of [TED-LIUM Release 3](https://www.openslr.org/51/), using the segmented manual transcripts included in the release. 12 | 13 | ### Common Voice 5.1 14 | 15 | We downloaded the English subset of Common Voice Corpus 5.1 from [the official website](https://commonvoice.mozilla.org/en/datasets) 16 | 17 | ### Artie 18 | 19 | We used the [Artie bias corpus](https://github.com/artie-inc/artie-bias-corpus). This is a subset of the Common Voice dataset. 
20 | 21 | ### CallHome & Switchboard 22 | 23 | We used the two corpora from [LDC2002S09](https://catalog.ldc.upenn.edu/LDC2002S09) and [LDC2002T43](https://catalog.ldc.upenn.edu/LDC2002T43) and followed the [eval2000_data_prep.sh](https://github.com/kaldi-asr/kaldi/blob/master/egs/fisher_swbd/s5/local/eval2000_data_prep.sh) script for preprocessing. The `wav.scp` files can be converted to WAV files with the following bash commands: 24 | 25 | ```bash 26 | mkdir -p wav 27 | while read name cmd; do 28 | echo $name 29 | echo ${cmd/\|/} wav/$name.wav | bash 30 | done < wav.scp 31 | ``` 32 | 33 | 34 | ### WSJ 35 | 36 | We used [LDC93S6B](https://catalog.ldc.upenn.edu/LDC93S6B) and [LDC94S13B](https://catalog.ldc.upenn.edu/LDC94S13B) and followed the [s5 recipe](https://github.com/kaldi-asr/kaldi/tree/master/egs/wsj/s5) to preprocess the dataset. 37 | 38 | ### CORAAL 39 | 40 | We used the 231 interviews from [CORAAL (v. 2021.07)](https://oraal.uoregon.edu/coraal) and used the segmentations from [the FairSpeech project](https://github.com/stanford-policylab/asr-disparities/blob/master/input/CORAAL_transcripts.csv). 41 | 42 | ### CHiME-6 43 | 44 | We downloaded the [CHiME-5 dataset](https://spandh.dcs.shef.ac.uk//chime_challenge/CHiME5/download.html) and followed the stage 0 of the [s5_track1 recipe](https://github.com/kaldi-asr/kaldi/tree/master/egs/chime6/s5_track1) to create the CHiME-6 dataset which fixes synchronization. We then used the binaural recordings (`*_P??.wav`) and the corresponding transcripts. 45 | 46 | ### AMI-IHM, AMI-SDM1 47 | 48 | We preprocessed the [AMI Corpus](https://groups.inf.ed.ac.uk/ami/corpus/overview.shtml) by following stages 0 and 2 of the [s5b recipe](https://github.com/kaldi-asr/kaldi/tree/master/egs/ami/s5b). 
49 | 50 | 51 | ## Long-form English-only datasets 52 | 53 | ### TED-LIUM 3 54 | 55 | To create a long-form transcription dataset from the [TED-LIUM3](https://www.openslr.org/51/) dataset, we sliced the audio between the beginning of the first labeled segment and the end of the last labeled segment of each talk, and we used the concatenated text as the label. Below are the timestamps used for slicing each of the 11 TED talks in the test split. 56 | 57 | | Filename | Begin time (s) | End time (s) | 58 | |---------------------|----------------|--------------| 59 | | DanBarber_2010 | 16.09 | 1116.24 | 60 | | JaneMcGonigal_2010 | 15.476 | 1187.61 | 61 | | BillGates_2010 | 15.861 | 1656.94 | 62 | | TomWujec_2010U | 16.26 | 402.17 | 63 | | GaryFlake_2010 | 16.06 | 367.14 | 64 | | EricMead_2009P | 18.434 | 536.44 | 65 | | MichaelSpecter_2010 | 16.11 | 979.312 | 66 | | DanielKahneman_2010 | 15.8 | 1199.44 | 67 | | AimeeMullins_2009P | 17.82 | 1296.59 | 68 | | JamesCameron_2010 | 16.75 | 1010.65 | 69 | | RobertGupta_2010U | 16.8 | 387.03 | 70 | 71 | ### Meanwhile 72 | 73 | This dataset consists of 64 segments from The Late Show with Stephen Colbert. The YouTube video ID, start and end timestamps, and the labels can be found in [meanwhile.json](meanwhile.json). The labels are collected from the closed-caption data for each video and corrected with manual inspection. 74 | 75 | ### Rev16 76 | 77 | We use a subset of 16 files from the 30 podcast episodes in [Rev.AI's Podcast Transcription Benchmark](https://www.rev.ai/blog/podcast-transcription-benchmark-part-1/), after finding that there are multiple cases where a significant portion of the audio and the labels did not match, mostly on the parts introducing the sponsors. 
We selected 16 episodes that do not have this error, whose "file number" are: 78 | 79 | 3 4 9 10 11 14 17 18 20 21 23 24 26 27 29 32 80 | 81 | ### Kincaid46 82 | 83 | This dataset consists of 46 audio files and the corresponding transcripts compiled in the blog article [Which automatic transcription service is the most accurate - 2018](https://medium.com/descript/which-automatic-transcription-service-is-the-most-accurate-2018-2e859b23ed19) by Jason Kincaid. We used the 46 audio files and reference transcripts from the Airtable widget in the article. 84 | 85 | For the human transcription benchmark in the paper, we use a subset of 25 examples from this data, whose "Ref ID" are: 86 | 87 | 2 4 5 8 9 10 12 13 14 16 19 21 23 25 26 28 29 30 33 35 36 37 42 43 45 88 | 89 | ### Earnings-21, Earnings-22 90 | 91 | For these datasets, we used the files available in [the speech-datasets repository](https://github.com/revdotcom/speech-datasets), as of their `202206` version. 92 | 93 | ### CORAAL 94 | 95 | We used the 231 interviews from [CORAAL (v. 2021.07)](https://oraal.uoregon.edu/coraal) and used the full-length interview files and transcripts. 96 | 97 | 98 | ## Multilingual datasets 99 | 100 | ### Multilingual LibriSpeech 101 | 102 | We used the test splits from each language in [the Multilingual LibriSpeech (MLS) corpus](https://www.openslr.org/94/). 103 | 104 | ### Fleurs 105 | 106 | We collected audio files and transcripts using the implementation available as [HuggingFace datasets](https://huggingface.co/datasets/google/fleurs/blob/main/fleurs.py). To use as a translation dataset, we matched the numerical utterance IDs to find the corresponding transcript in English. 107 | 108 | ### VoxPopuli 109 | 110 | We used the `get_asr_data.py` script from [the official repository](https://github.com/facebookresearch/voxpopuli) to collect the ASR data in 14 languages. 
111 | 112 | ### Common Voice 9 113 | 114 | We downloaded the Common Voice Corpus 9 from [the official website](https://commonvoice.mozilla.org/en/datasets) 115 | 116 | ### CoVOST 2 117 | 118 | We collected the `X into English` data using [the official repository](https://github.com/facebookresearch/covost). 119 | -------------------------------------------------------------------------------- /landing-page-metrics.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/akashmjn/tinydiarize/7cba47def707514fe68bbd6663f97663ea482158/landing-page-metrics.png -------------------------------------------------------------------------------- /model-card.md: -------------------------------------------------------------------------------- 1 | # Model Card: Whisper 2 | 3 | This is the official codebase for running the automatic speech recognition (ASR) models (Whisper models) trained and released by OpenAI. 4 | 5 | Following [Model Cards for Model Reporting (Mitchell et al.)](https://arxiv.org/abs/1810.03993), we're providing some information about the automatic speech recognition model. More information on how these models were trained and evaluated can be found [in the paper](https://arxiv.org/abs/2212.04356). 6 | 7 | 8 | ## Model Details 9 | 10 | The Whisper models are trained for speech recognition and translation tasks, capable of transcribing speech audio into the text in the language it is spoken (ASR) as well as translated into English (speech translation). Researchers at OpenAI developed the models to study the robustness of speech processing systems trained under large-scale weak supervision. There are 9 models of different sizes and capabilities, summarized in the following table. 
11 | 12 | | Size | Parameters | English-only model | Multilingual model | 13 | |:------:|:----------:|:------------------:|:------------------:| 14 | | tiny | 39 M | ✓ | ✓ | 15 | | base | 74 M | ✓ | ✓ | 16 | | small | 244 M | ✓ | ✓ | 17 | | medium | 769 M | ✓ | ✓ | 18 | | large | 1550 M | | ✓ | 19 | 20 | In December 2022, we [released an improved large model named `large-v2`](https://github.com/openai/whisper/discussions/661). 21 | 22 | 23 | ### Release date 24 | 25 | September 2022 (original series) and December 2022 (`large-v2`) 26 | 27 | ### Model type 28 | 29 | Sequence-to-sequence ASR (automatic speech recognition) and speech translation model 30 | 31 | ### Paper & samples 32 | 33 | [Paper](https://arxiv.org/abs/2212.04356) / [Blog](https://openai.com/blog/whisper) 34 | 35 | 36 | ## Model Use 37 | 38 | ### Evaluated Use 39 | 40 | The primary intended users of these models are AI researchers studying robustness, generalization, capabilities, biases, and constraints of the current model. However, Whisper is also potentially quite useful as an ASR solution for developers, especially for English speech recognition. We recognize that once models are released, it is impossible to restrict access to only “intended” uses or to draw reasonable guidelines around what is or is not research. 41 | 42 | The models are primarily trained and evaluated on ASR and speech translation to English tasks. They show strong ASR results in ~10 languages. They may exhibit additional capabilities, particularly if fine-tuned on certain tasks like voice activity detection, speaker classification, or speaker diarization but have not been robustly evaluated in these areas. We strongly recommend that users perform robust evaluations of the models in a particular context and domain before deploying them. 
43 | 44 | In particular, we caution against using Whisper models to transcribe recordings of individuals taken without their consent or purporting to use these models for any kind of subjective classification. We recommend against use in high-risk domains like decision-making contexts, where flaws in accuracy can lead to pronounced flaws in outcomes. The models are intended to transcribe and translate speech, use of the model for classification is not only not evaluated but also not appropriate, particularly to infer human attributes. 45 | 46 | 47 | ## Training Data 48 | 49 | The models are trained on 680,000 hours of audio and the corresponding transcripts collected from the internet. 65% of this data (or 438,000 hours) represents English-language audio and matched English transcripts, roughly 18% (or 126,000 hours) represents non-English audio and English transcripts, while the final 17% (or 117,000 hours) represents non-English audio and the corresponding transcript. This non-English data represents 98 different languages. 50 | 51 | As discussed in [the accompanying paper](https://arxiv.org/abs/2212.04356), we see that performance on transcription in a given language is directly correlated with the amount of training data we employ in that language. 52 | 53 | 54 | ## Performance and Limitations 55 | 56 | Our studies show that, over many existing ASR systems, the models exhibit improved robustness to accents, background noise, technical language, as well as zero shot translation from multiple languages into English; and that accuracy on speech recognition and translation is near the state-of-the-art level. 57 | 58 | However, because the models are trained in a weakly supervised manner using large-scale noisy data, the predictions may include texts that are not actually spoken in the audio input (i.e. hallucination). 
We hypothesize that this happens because, given their general knowledge of language, the models combine trying to predict the next word in audio with trying to transcribe the audio itself. 59 | 60 | Our models perform unevenly across languages, and we observe lower accuracy on low-resource and/or low-discoverability languages or languages where we have less training data. The models also exhibit disparate performance on different accents and dialects of particular languages, which may include higher word error rate across speakers of different genders, races, ages, or other demographic criteria. Our full evaluation results are presented in [the paper accompanying this release](https://arxiv.org/abs/2212.04356). 61 | 62 | In addition, the sequence-to-sequence architecture of the model makes it prone to generating repetitive texts, which can be mitigated to some degree by beam search and temperature scheduling but not perfectly. Further analysis on these limitations are provided in [the paper](https://arxiv.org/abs/2212.04356). It is likely that this behavior and hallucinations may be worse on lower-resource and/or lower-discoverability languages. 63 | 64 | 65 | ## Broader Implications 66 | 67 | We anticipate that Whisper models’ transcription capabilities may be used for improving accessibility tools. While Whisper models cannot be used for real-time transcription out of the box – their speed and size suggest that others may be able to build applications on top of them that allow for near-real-time speech recognition and translation. The real value of beneficial applications built on top of Whisper models suggests that the disparate performance of these models may have real economic implications. 68 | 69 | There are also potential dual use concerns that come with releasing Whisper. 
While we hope the technology will be used primarily for beneficial purposes, making ASR technology more accessible could enable more actors to build capable surveillance technologies or scale up existing surveillance efforts, as the speed and accuracy allow for affordable automatic transcription and translation of large volumes of audio communication. Moreover, these models may have some capabilities to recognize specific individuals out of the box, which in turn presents safety concerns related both to dual use and disparate performance. In practice, we expect that the cost of transcription is not the limiting factor of scaling up surveillance projects. 70 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | 3 | [tool.isort] 4 | profile = "black" 5 | include_trailing_comma = true 6 | line_length = 88 7 | multi_line_output = 3 8 | 9 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numba 2 | numpy 3 | torch 4 | tqdm 5 | more-itertools 6 | transformers>=4.19.0 7 | ffmpeg-python==0.2.0 8 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | import platform 3 | import sys 4 | 5 | import pkg_resources 6 | from setuptools import find_packages, setup 7 | 8 | 9 | def read_version(fname="whisper/version.py"): 10 | exec(compile(open(fname, encoding="utf-8").read(), fname, "exec")) 11 | return locals()["__version__"] 12 | 13 | 14 | requirements = [] 15 | if sys.platform.startswith("linux") and platform.machine() == "x86_64": 16 | requirements.append("triton==2.0.0") 17 | 18 | setup( 19 | name="openai-whisper", 20 | py_modules=["whisper"], 21 | 
version=read_version(), 22 | description="Robust Speech Recognition via Large-Scale Weak Supervision", 23 | long_description=open("README.md", encoding="utf-8").read(), 24 | long_description_content_type="text/markdown", 25 | readme="README.md", 26 | python_requires=">=3.8", 27 | author="OpenAI", 28 | url="https://github.com/openai/whisper", 29 | license="MIT", 30 | packages=find_packages(exclude=["tests*"]), 31 | install_requires=requirements 32 | + [ 33 | str(r) 34 | for r in pkg_resources.parse_requirements( 35 | open(os.path.join(os.path.dirname(__file__), "requirements.txt")) 36 | ) 37 | ], 38 | entry_points={ 39 | "console_scripts": ["whisper=whisper.transcribe:cli"], 40 | }, 41 | include_package_data=True, 42 | extras_require={"dev": ["pytest", "scipy", "black", "flake8", "isort"]}, 43 | ) 44 | -------------------------------------------------------------------------------- /tdrz_dev/README.md: -------------------------------------------------------------------------------- 1 | # tinydiarize dev 🔨📊 2 | 3 | This directory contains tools to aid development and analysis. This can be used to reproduce results and take a closer look at analysis from the blog post. Contents: 4 | - [score.py](score.py) to measure and compare accuracy on your own data with easy to interpret metrics (WER, speaker turn precision/recall). 5 | - [run_pipelines.py](scripts/run_pipelines.py) shows how to run and compare various diarization pipelines. 6 | - [analysis.ipynb](https://nbviewer.org/github/akashmjn/tinydiarize/blob/main/tdrz_dev/notebooks/analysis.ipynb) walks through a comparison of various pipelines with a deep dive to understand sources of errors. 7 | - Code to reproduce finetuning will also be released shortly. 8 | 9 | It has extra [setup and dependencies](#setup) that are not required for inference. 
10 | 11 | ## Analysis 12 | 13 | In the accompanying notebook [analysis.ipynb](https://nbviewer.org/github/akashmjn/tinydiarize/blob/main/tdrz_dev/notebooks/analysis.ipynb) ([![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/akashmjn/tinydiarize/blob/main/tdrz_dev/notebooks/analysis.ipynb)), we show that: 14 | - Whisper models already have a good internal representation of speaker turns via both acoustic and semantic cues. 15 | - Their placement of `punctuation` tokens appears to be very consistent with speaker turns (high recall). 16 | - Whisper's time segments `segment_timestamped` (used for clustering [here](https://huggingface.co/spaces/vumichien/Whisper_speaker_diarization)) are less consistent. 17 | - Acoustic embedding-based diarization methods like [`pyannote_pre_sr`](https://github.com/pyannote/pyannote-audio) perform well overall, but struggle with short segments & quick speaker turns. This leaves a gap with the best individual precision or recall. 18 | - `tdrz_token` shows that we can extract Whisper's speaker representations cheaply, and with small word error rate impact. 19 | - With improvements to finetuning, strong performance can be expected as it can use both voice and semantic context to tell speakers apart, which is a unique benefit of this approach. 20 | 21 | *The following numbers are scored on a set of [3 earnings calls](https://github.com/revdotcom/speech-datasets/tree/main/earnings21) (~23k words, ~300 spk turns). 
Qualitative analysis/error inspection helps us draw these conclusions.* 22 | 23 | |model|small.en| | |small.en-tdrz| 24 | |:----|:----|:----|:----|:----| 25 | |method|punctuation|pyannote_pre_sr|segment_timestamped|tdrz_token| 26 | |metric| | | | | 27 | |spk_turn_precision|19.5|83.4|14.5|97.7| 28 | |spk_turn_recall|92.0|78.4|86.7|70.8| 29 | |wer_overall|11.0|12.9|11.0|10.3| 30 | |wer_speaker_switch|15.0|23.11|15.0|15.5| 31 | 32 | ![metrics](barplots.png) 33 | 34 | ## Runtime cost estimate 35 | 36 | Local diarization done by `tdrz` comes at a marginal added cost. If we account for an additional clustering step (implemented [here](scripts/diarize_post_sr.py)), this can still be quite cheap overall. 37 | 38 | |Stage|Runtime (s)|Extra cost (%)| 39 | |:----|:----|:----| 40 | |Whisper.transcribe|121.2|-| 41 | |Pyannote diarization|56.6|47%| 42 | |Clustering whisper time segments|3.4|3%| 43 | |Whisper.transcribe (tdrz)|131.5|8%| 44 | 45 | *These numbers were tested using the earnings21-4374910 call (33.8 min) on a Quadro RTX 5000 GPU. Whisper is run with beam_size=4 and condition_on_previous_text=True.* 46 | 47 | ## Setup 48 | 49 | You'll need to setup the following dependencies. 50 | 51 | 1. In a fresh python environment (I used python=3.9), run the following (from the root of the github repo): 52 | ``` 53 | pip install -e . 54 | cd tdrz_dev 55 | pip install -r extra_requirements.txt 56 | ``` 57 | 2. `docker pull revdotcom/fstalign`. This installs [revdotcom/fstalign](https://github.com/revdotcom/fstalign) for scoring 58 | 3. [Pyannote](https://github.com/pyannote/pyannote-audio) consent prerequisites: 59 | - visit hf.co/pyannote/speaker-diarization and accept user conditions 60 | - visit hf.co/pyannote/segmentation and accept user conditions 61 | - visit hf.co/settings/tokens to create an access token and save it to a text file at `tdrz_dev/scripts/HF_TOK.txt`. 
def parse_result_id(input):
    """Parse result id(s) into ``(model, call_id, method)``.

    A result id is matched against ``<model>__earnings21-<call_id>_<method>``
    (one or more underscores before the method). Anything after the first
    ``_`` in the model part is dropped, and a few method names are renamed to
    more readable ones.

    Args:
        input: a single result id string, or an iterable of result id strings.
            (The name shadows the builtin ``input``; it is kept so keyword
            callers keep working.)

    Returns:
        For a string: a ``(model, call_id, method)`` tuple.
        For an iterable: a ``zip`` yielding three tuples (all models, all
        call ids, all methods), convenient for assigning three columns.

    Raises:
        ValueError: if a result id does not match the expected pattern.
    """
    # rename some methods to more understandable names
    method_map = {
        "drz_pre_sr__segment": "pyannote_pre_sr",
        "segment": "segment_timestamped",
        "drz_post_sr__segment": "segment_timestamped_clustered",
        "token": "tdrz_token",
    }

    def _parse_result_id(result_id):
        match = re.search(r"(.*)__earnings21-([0-9]+)_+(.*)", result_id)
        if match is None:
            # fail with a readable message instead of an AttributeError on None
            raise ValueError(f"Unrecognized result_id format: {result_id!r}")
        model, call_id, method = match.groups()
        model = model.split("_")[0]  # remove the suffix after -ft_*
        method = method_map.get(method, method)
        return model, call_id, method

    if isinstance(input, str):
        return _parse_result_id(input)
    return zip(*[_parse_result_id(r) for r in input])


def compile_results(results_dir):
    """Collect scoring and side-by-side analysis results under *results_dir*.

    Args:
        results_dir: directory searched recursively for
            ``scoring_results.tsv`` and ``spk_turn/results/*.sbs`` files.

    Returns:
        ``(results_df, analysis_results)`` where ``results_df`` is the
        concatenation of all tsv files with ``model``, ``call_id`` and
        ``method`` columns added, and ``analysis_results`` maps a
        ``(model, call_id, method)`` tuple to a dict with
        ``precision_errors`` and ``recall_errors``.

    Raises:
        ValueError: if no ``scoring_results.tsv`` file is found.
    """
    # compile all scoring_results.tsv files
    tsv_list = glob.glob(f"{results_dir}/**/scoring_results.tsv", recursive=True)
    if not tsv_list:
        # pd.concat([]) would raise a ValueError that does not name the cause
        raise ValueError(f"No scoring_results.tsv files found under {results_dir}")
    scored_tsvs = [pd.read_csv(tsv, sep="\t") for tsv in tsv_list]
    results_df = pd.concat(scored_tsvs)
    results_df["model"], results_df["call_id"], results_df["method"] = parse_result_id(
        results_df.result_id
    )
    print(f"Read {len(results_df)} results from {len(tsv_list)} files")

    # collect all side-by-side analysis results
    analysis_results = dict()
    analysis_sbs_list = glob.glob(
        f"{results_dir}/**/spk_turn/results/*.sbs", recursive=True
    )
    for sbs in analysis_sbs_list:
        # <result_id>/spk_turn/results/<file>.sbs -> result_id is 4th from the end
        result_id = Path(sbs).parts[-4]
        precision_errors, recall_errors = parse_analysis_file(sbs)
        key_tuple = parse_result_id(result_id)
        analysis_results[key_tuple] = dict(
            precision_errors=precision_errors, recall_errors=recall_errors
        )
    print(f"Read {len(analysis_results)} side-by-side analysis results")

    return results_df, analysis_results


def summarize_results(results_df, omit_extra_results=True):
    """Aggregate per-call results into one value per metric/model/method.

    Numerators and denominators are summed over all rows of a
    metric+model+method group and the value is recomputed as a percentage
    rounded to one decimal.

    Args:
        results_df: frame with at least ``metric``, ``model``, ``method``,
            ``numerator`` and ``denominator`` columns; it must contain the
            ``wer_overall`` and ``spk_turn_recall`` metrics, whose
            denominators are printed as summary statistics.
        omit_extra_results: if true, drop the
            ``segment_timestamped_clustered`` method and every
            ``small.en-tdrz`` row whose method is not ``tdrz_token``.

    Returns:
        ``(summary, results_df)``: the pivoted summary (one row per metric,
        one column per model+method) and the possibly filtered input frame.
    """
    # omit some methods for brevity
    if omit_extra_results:
        # NOTE: the filters build new frames and rebind the local name; the
        # caller's frame is untouched, which is why the filtered frame is
        # returned alongside the summary.
        results_df = results_df[results_df["method"] != "segment_timestamped_clustered"]
        results_df = results_df[
            ~(
                (results_df["model"] == "small.en-tdrz")
                & (results_df["method"] != "tdrz_token")
            )
        ]
    # create a summarized dataframe that sums errors for each model+method combination
    summary_df = (
        results_df.groupby(["metric", "model", "method"])
        .sum(numeric_only=True)
        .reset_index()
    )
    summary_df["value"] = round(
        summary_df["numerator"] / summary_df["denominator"] * 100, 1
    )
    # print some summary statistics
    num_words = summary_df.query('metric == "wer_overall"')["denominator"].values[0]
    num_speaker_turns = summary_df.query('metric == "spk_turn_recall"')[
        "denominator"
    ].values[0]
    print(
        f"Total # of words: {num_words}, Total # of speaker turns: {num_speaker_turns}"
    )
    # pivot to a row for each metric, and a column for each model+method combination
    return (
        summary_df.pivot(index="metric", columns=["model", "method"], values="value"),
        results_df,
    )


def query_metric_results(results_df, metric, groups=("call_id", "method")):
    """Return the values of one metric, one row/column per group combination.

    Args:
        results_df: frame with ``metric``, ``value`` and the group columns.
        metric: metric name to select.
        groups: columns to group by; the last one becomes the columns of the
            result (via ``unstack``). The default is an immutable tuple
            rather than a shared mutable list.

    Returns:
        DataFrame holding the first ``value`` of every group, rounded to two
        decimals.
    """
    metric_df = (
        results_df.query(f'metric=="{metric}"')
        # groupby treats a tuple as a single key, so always pass a list
        .groupby(list(groups))["value"]
        .first()
        .unstack()
        .round(2)
    )
    return metric_df


def plot_metric_results(metric_df, title=None, ax=None, legend=True):
    """Draw *metric_df* as a horizontal bar plot.

    Args:
        metric_df: frame to plot.
        title: optional plot title.
        ax: matplotlib axes to draw on; defaults to the current axes.
        legend: whether to show a legend (placed outside the plot area).
    """
    if ax is None:
        ax = plt.gca()
    metric_df.plot.barh(title=title, ax=ax, legend=legend, grid=True)
    if legend:
        # edit the legend of the plot so that it doesn't overlap with the plot
        _ = ax.legend(bbox_to_anchor=(1.05, 1), loc="upper left", borderaxespad=0.0)


def inspect_spk_errors(
    results_df,
    analysis_results,
    result_key,
    precision_errors=(),
    recall_errors=(),
    display_results=False,
):
    """Print speaker-turn precision/recall and selected errors for one result.

    Args:
        results_df: scoring results with ``model``, ``call_id``, ``method``,
            ``metric`` and ``value`` columns.
        analysis_results: mapping of ``(model, call_id, method)`` to a dict
            with ``precision_errors`` and ``recall_errors`` lists, whose
            entries have ``line`` and ``context`` keys.
        result_key: the ``(model, call_id, method)`` tuple to inspect.
        precision_errors: indices of precision errors whose context is printed.
        recall_errors: indices of recall errors whose context is printed.
        display_results: if true, also display the matching result rows.
    """
    model_id, call_id, method = result_key
    subset_df = results_df.query(
        f'model=="{model_id}" and call_id=="{call_id}" and method=="{method}"'
    ).iloc[:, :7]

    if display_results:
        ipd.display(subset_df)

    print(f"Results for: {result_key}")
    spk_turn_errors = analysis_results[result_key]
    precision = subset_df.query('metric == "spk_turn_precision"')["value"].values[0]
    num_false_positives = len(spk_turn_errors["precision_errors"])
    recall = subset_df.query('metric == "spk_turn_recall"')["value"].values[0]
    num_false_negatives = len(spk_turn_errors["recall_errors"])
    print(f"Precision: {precision:.2f}, # of false positives: {num_false_positives}")
    print(f"Recall: {recall:.2f}, # of false negatives: {num_false_negatives}")

    if len(precision_errors) > 0:
        print("\n", "--" * 5, "Spk turn precision errors:", "--" * 5)
        for idx in precision_errors:
            print(
                f"\nLine: {spk_turn_errors['precision_errors'][idx]['line']}, Index: {idx}"
            )
            print(spk_turn_errors["precision_errors"][idx]["context"])

    if len(recall_errors) > 0:
        print("\n", "--" * 5, "Spk turn recall errors:", "--" * 5)
        for idx in recall_errors:
            print(
                f"\nLine: {spk_turn_errors['recall_errors'][idx]['line']}, Index: {idx}"
            )
            print(spk_turn_errors["recall_errors"][idx]["context"])


"""
Nice-to-have TODOs:
- parse analysis results into custom class with
    errors uniquely identified by ref word #
    configurable context
- enable diff between two sets of errors
- make a neat side-by-side fixed width print
"""


# function to print two strings side-by-side with a fixed width
def print_side_by_side(s1, s2, width=50):
    """Print two multi-line strings next to each other in fixed-width columns.

    Each line is cut at its second-to-last tab (dropping the last two
    tab-separated columns) and left-aligned in a column of *width*
    characters. The shorter text is padded with empty lines.

    Args:
        s1: text of the left column.
        s2: text of the right column.
        width: minimum width of each column.
    """
    left_lines = s1.splitlines()
    right_lines = s2.splitlines()
    # pad the shorter side with empty strings so both have the same length
    max_lines = max(len(left_lines), len(right_lines))
    left_lines += [""] * (max_lines - len(left_lines))
    right_lines += [""] * (max_lines - len(right_lines))
    for left, right in zip(left_lines, right_lines):
        # remove the last 2 columns, keep only words and ERR hints
        left = left.rsplit("\t", 2)[0]
        right = right.rsplit("\t", 2)[0]
        print(f"{left: <{width}}{right: <{width}}")
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/akashmjn/tinydiarize/7cba47def707514fe68bbd6663f97663ea482158/tdrz_dev/notebooks/duplicate-serialization.png -------------------------------------------------------------------------------- /tdrz_dev/score.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import logging 4 | import os 5 | import re 6 | import shutil 7 | import subprocess 8 | from copy import deepcopy 9 | from glob import glob 10 | from pathlib import Path 11 | 12 | import pandas as pd 13 | 14 | DESCRIPTION = """ 15 | Score whisper reco with speaker turns added. 16 | WER and speaker turn errors are computed using revdotcom/fstalign via edit 17 | distance of ref/reco transcripts. Speaker turn errors are measured by inserting a 18 | special token, re-aligning via edit distance 19 | and using the token-level errors (e.g. precision/recall) exposed by the tool. 20 | """ 21 | 22 | 23 | # need to replace <|speakerturn|> since fstalign treats '|' as a separator 24 | SCRIPT = Path(__file__).parent.absolute() / "score_fstalign.sh" 25 | WHISPER_ST_TOKEN = "<|speakerturn|>" 26 | ST_TOKEN = "SPEAKER__TURN" 27 | PUNCTUATION = set([".", "?", "!", ",", ";"]) 28 | ENDING_PUNCTUATION = set([".", "?", "!"]) 29 | 30 | 31 | # TODO@Akash - replace with proper whisper normalizer/tokenizer 32 | # returns list of tuples (token, trailing punctuation) 33 | def _tokenize_line(line): 34 | # handle "word.<|speakerturn|>", "word. <|speakerturn|> word", "word. 
<|speakerturn|>word" 35 | line = line.replace(WHISPER_ST_TOKEN, " " + ST_TOKEN + " ") 36 | # split into tokens 37 | tokens = [t for t in re.split(r"[-\s]", line) if len(t.strip()) > 0] 38 | 39 | # handle ending punctuation 40 | result = [] 41 | for t in tokens: 42 | if t[-1] in PUNCTUATION: 43 | if len(t) == 1: 44 | print( 45 | f"WARNING: skipping token with only punctuation: `{t}` in line: `{line}`" 46 | ) 47 | continue 48 | p = t[-1] 49 | t = t[:-1] 50 | else: 51 | p = "" 52 | result.append((t, p)) 53 | 54 | return result 55 | 56 | 57 | # take earnings nlp file, convert and add speaker turn tags, remove extra entity tags 58 | def prepare_ref_nlp(nlp_file, output_dir=None, tag_speaker_turns=True): 59 | with open(nlp_file) as fp: 60 | raw_lines = fp.read().splitlines() 61 | 62 | suffix = "_tagged.nlp" if tag_speaker_turns else ".nlp" 63 | fname = Path(nlp_file).name.replace(".nlp", suffix) 64 | output_dir = Path(nlp_file).parent if output_dir is None else Path(output_dir) 65 | output_file = output_dir / fname 66 | n = 0 67 | with open(output_file, "w") as fp: 68 | for i, line in enumerate(raw_lines): 69 | if i == 0: 70 | fp.write(line + "\n") 71 | continue 72 | elif i == 1: 73 | speaker = line.split("|")[1] 74 | 75 | if line.split("|")[1] != speaker and tag_speaker_turns: 76 | # add the speaker turn tag 77 | fp.write( 78 | f"{ST_TOKEN}|0||||LC|['{n}:SPEAKER_TURN']|['{n}']" + "\n" 79 | ) # entity ids must be unique 80 | n += 1 81 | 82 | # replace all with blank tags 83 | line = "|".join(line.split("|")[:6]) + "|[]|[]" 84 | fp.write(line + "\n") 85 | speaker = line.split("|")[1] 86 | 87 | logging.debug(f"Written {i+n} lines to {output_file}") 88 | return output_file 89 | 90 | 91 | # take whisper reco json, write out in nlp format, with speaker turn tokens added in two modes - 92 | # one where speaker turn are added at segment boundaries, and one where they are added after every [.?!] 
punctuation token 93 | def whisper_reco_to_nlp(reco_json, output_dir=None, speaker_turn_mode="segment"): 94 | """ 95 | token|speaker|ts|endTs|punctuation|case|tags 96 | good|1||||| 97 | """ 98 | with open(reco_json) as fp: 99 | reco = json.load(fp) 100 | 101 | output_dir = Path(reco_json).parent if output_dir is None else Path(output_dir) 102 | suffix = f"_spkturn_{speaker_turn_mode}.nlp" if speaker_turn_mode else ".nlp" 103 | fname = Path(reco_json).stem + suffix 104 | output_file = output_dir / fname 105 | 106 | n = 0 107 | with open(output_file, "w") as fp: 108 | fp.write("token|speaker|ts|endTs|punctuation|case|tags\n") 109 | n += 1 110 | for i, segment in enumerate(reco["segments"]): 111 | if i == 0 and speaker_turn_mode == "segment": 112 | curr_speaker = segment.get("speaker", "") 113 | 114 | if speaker_turn_mode == "segment": 115 | # if speakers are provided, add speaker turn if speaker differs from previous one 116 | if "speaker" in segment: 117 | if segment["speaker"] != curr_speaker: 118 | fp.write(f"{ST_TOKEN}|0|||||\n") 119 | n += 1 120 | curr_speaker = segment["speaker"] 121 | # else default add speaker turn before every new segment 122 | elif n > 0: 123 | fp.write(f"{ST_TOKEN}|0|||||\n") 124 | n += 1 125 | 126 | for token, punc in _tokenize_line(segment["text"]): 127 | if token == ST_TOKEN: 128 | if speaker_turn_mode and "token" in speaker_turn_mode: 129 | fp.write(f"{ST_TOKEN}|0|||{punc}||\n") 130 | n += 1 131 | # strips out speaker turn tokens if speaker_turn_mode is None 132 | continue 133 | fp.write(f"{token}|0|||{punc}||\n") 134 | n += 1 135 | # add speaker turn after every punctuation token 136 | if speaker_turn_mode and "punctuation" in speaker_turn_mode: 137 | if punc in ENDING_PUNCTUATION: 138 | fp.write(f"{ST_TOKEN}|0|||||\n") 139 | n += 1 140 | 141 | logging.debug(f"Written {n} lines to {output_file}") 142 | return output_file 143 | 144 | 145 | # function to read an .nlp file and strip out all speaker turn tokens 146 | def 
strip_speaker_turn_tokens(nlp_file, output_dir=None): 147 | with open(nlp_file) as fp: 148 | raw_lines = fp.read().splitlines() 149 | 150 | output_dir = Path(nlp_file).parent if output_dir is None else Path(output_dir) 151 | output_file = output_dir / Path(nlp_file).name.replace(".nlp", "_for_wer.nlp") 152 | n = 0 153 | with open(output_file, "w") as fp: 154 | for i, line in enumerate(raw_lines): 155 | if i == 0: 156 | fp.write(line + "\n") 157 | n += 1 158 | continue 159 | 160 | if line.split("|")[0] == ST_TOKEN: 161 | continue 162 | 163 | fp.write(line + "\n") 164 | n += 1 165 | 166 | logging.debug(f"Written {i+n} lines to {output_file}") 167 | 168 | return output_file 169 | 170 | 171 | def parse_result(wer_json, speaker_turn_json): 172 | with open(wer_json) as fp: 173 | wer_result = json.load(fp) 174 | with open(speaker_turn_json) as fp: 175 | speaker_turn_result = json.load(fp) 176 | 177 | columns = [ 178 | "value", 179 | "denominator", 180 | "numerator", 181 | "deletions", 182 | "insertions", 183 | "substitutions", 184 | ] 185 | metrics = [ 186 | "wer_overall", 187 | "wer_speaker_switch", 188 | "spk_turn_precision", 189 | "spk_turn_recall", 190 | ] 191 | result = dict.fromkeys(metrics) 192 | 193 | # WER results 194 | wer_map = { 195 | "wer": "value", 196 | "numWordsInReference": "denominator", 197 | "numErrors": "numerator", 198 | } 199 | for k in ["deletions", "insertions", "substitutions"]: 200 | wer_map[k] = k 201 | # insert values corresponding to columns 202 | result["wer_overall"] = { 203 | wer_map[k]: wer_result["wer"]["bestWER"][k] for k in wer_map 204 | } 205 | result["wer_overall"]["value"] *= 100.0 206 | result["wer_speaker_switch"] = { 207 | wer_map[k]: wer_result["wer"]["speakerSwitchWER"][k] for k in wer_map 208 | } 209 | result["wer_speaker_switch"]["value"] *= 100.0 210 | 211 | # speaker turn results 212 | speaker_turn = deepcopy(speaker_turn_result["wer"]["unigrams"][ST_TOKEN.lower()]) 213 | speaker_turn["numPredictions"] = sum( 214 | 
[speaker_turn[k] for k in ["correct", "insertions", "substitutions_fp"]] 215 | ) 216 | speaker_turn["numWordsInReference"] = sum( 217 | [speaker_turn[k] for k in ["correct", "deletions", "substitutions_fn"]] 218 | ) 219 | # insert values corresponding to columns 220 | precision_map = { 221 | "precision": "value", 222 | "numPredictions": "denominator", 223 | "correct": "numerator", 224 | "insertions": "insertions", 225 | "substitutions_fp": "substitutions", 226 | } 227 | result["spk_turn_precision"] = { 228 | precision_map[k]: speaker_turn[k] for k in precision_map 229 | } 230 | result["spk_turn_precision"]["deletions"] = 0 231 | recall_map = { 232 | "recall": "value", 233 | "numWordsInReference": "denominator", 234 | "correct": "numerator", 235 | "deletions": "deletions", 236 | "substitutions_fn": "substitutions", 237 | } 238 | result["spk_turn_recall"] = {recall_map[k]: speaker_turn[k] for k in recall_map} 239 | result["spk_turn_recall"]["insertions"] = 0 240 | 241 | result = pd.DataFrame.from_dict(result, orient="index", columns=columns) 242 | # reset the index to make the metric name a column 243 | result = result.reset_index().rename(columns={"index": "metric"}) 244 | 245 | return result 246 | 247 | 248 | def score_fstalign( 249 | ref_nlp, 250 | reco_file, 251 | result_name, 252 | work_dir="./workdir_analysis/fstalign_scoring/results", 253 | speaker_turn_mode="segment", 254 | ): 255 | """ 256 | Output directory structure: 257 | work_dir 258 | ├── result_name-{speaker_turn_mode} 259 | │ ├── wer 260 | │ │ ├── inputs 261 | │ │ ├── results 262 | | ├── spk_turn 263 | | | ├── ... 
(same as wer) 264 | │ ├── scoring_results.tsv 265 | """ 266 | 267 | # assert that the current directory is the parent dir of this file 268 | assert ( 269 | Path(__file__).parent.absolute() == Path.cwd() 270 | ), f"Please call score_fstalign from the parent directory of {__file__} (so that docker mounted paths work correctly)" 271 | 272 | # make reco/ref_nlp paths relative to parent dir of this file (so that docker mounted paths work correctly) 273 | ref_nlp = str(Path(ref_nlp).relative_to(Path(__file__).parent)) 274 | reco_file = str(Path(reco_file).relative_to(Path(__file__).parent)) 275 | 276 | def _make_subdir(parent_dir, subdir_name): 277 | subdir = Path(parent_dir) / subdir_name 278 | os.makedirs(subdir, exist_ok=True) 279 | return subdir 280 | 281 | # prepare output directories 282 | result_id = f"{result_name}__{speaker_turn_mode}" 283 | output_dir = Path(work_dir) / result_id 284 | wer_dir = _make_subdir(output_dir, "wer") 285 | spk_turn_dir = _make_subdir(output_dir, "spk_turn") 286 | 287 | # prepare inputs 288 | wer_inputs_dir = _make_subdir(wer_dir, "inputs") 289 | ref_nlp_for_wer = prepare_ref_nlp(ref_nlp, wer_inputs_dir, tag_speaker_turns=False) 290 | spk_turn_inputs_dir = _make_subdir(spk_turn_dir, "inputs") 291 | ref_nlp = prepare_ref_nlp(ref_nlp, spk_turn_inputs_dir, tag_speaker_turns=True) 292 | 293 | if Path(reco_file).suffix == ".nlp": 294 | # copy the reco file to inputs_dir 295 | reco_nlp = spk_turn_inputs_dir / Path(reco_file).name 296 | shutil.copyfile(reco_file, reco_nlp) 297 | reco_nlp_for_wer = strip_speaker_turn_tokens(reco_nlp, wer_inputs_dir) 298 | elif Path(reco_file).suffix == ".json": 299 | # convert to nlp format 300 | logging.debug(f"Converting reco to nlp format for scoring: {reco_file}") 301 | reco_nlp = whisper_reco_to_nlp( 302 | reco_file, spk_turn_inputs_dir, speaker_turn_mode=speaker_turn_mode 303 | ) 304 | reco_nlp_for_wer = whisper_reco_to_nlp( 305 | reco_file, wer_inputs_dir, speaker_turn_mode=None 306 | ) 307 | 308 | def 
_run_script(cmdlist): 309 | cmdlist = [str(c) for c in cmdlist] 310 | logging.debug(f"Running command: {' '.join(cmdlist)}") 311 | result = subprocess.check_output(cmdlist).decode("utf-8").splitlines()[-1] 312 | assert result.startswith("RESULT="), "Unexpected output from " + SCRIPT 313 | result_file = result.split("=")[1] 314 | assert os.path.exists(result_file), "Result file not found" 315 | return result_file 316 | 317 | # we need to call fstalign twice, once for WER and once for speaker turn errors 318 | wer_result_dir = _make_subdir(wer_dir, "results") 319 | wer_json = _run_script( 320 | ["sh", SCRIPT, ref_nlp_for_wer, reco_nlp_for_wer, wer_result_dir] 321 | ) 322 | spk_turn_result_dir = _make_subdir(spk_turn_dir, "results") 323 | speaker_turn_json = _run_script( 324 | ["sh", SCRIPT, ref_nlp, reco_nlp, spk_turn_result_dir] 325 | ) 326 | 327 | # process result 328 | result = parse_result(wer_json, speaker_turn_json) 329 | result["result_id"] = result_id 330 | # write to tsv file 331 | result.to_csv(output_dir / "scoring_results.tsv", sep="\t", index=False) 332 | 333 | return result, output_dir 334 | 335 | 336 | # TODO@Akash - make into a class 337 | def parse_analysis_file(sbs_analysis_file, context_lines=10): 338 | # precision error regex: (\tspeaker__turn\s+ERR) 339 | # recall error regex: (speaker__turn\t) 340 | with open(sbs_analysis_file) as f: 341 | lines = f.read().splitlines() 342 | 343 | def _get_context(i, lines, context_lines): 344 | before = lines[max(0, i - context_lines) : i] 345 | after = lines[i + 1 : i + context_lines + 1] 346 | # apply a regex to clean up lines before and after 347 | before = [re.sub(r"\t___.*", "\t\t", l).replace("ERR", "") for l in before] 348 | after = [re.sub(r"\t___.*", "\t\t", l).replace("ERR", "") for l in after] 349 | return "\n".join([*before, lines[i], *after]) 350 | 351 | # save line numbers where precision/recall errors occur 352 | precision_errors = [] 353 | recall_errors = [] 354 | # TODO@Akash - add the index of 
the token in the reference 355 | for i, line in enumerate(lines): 356 | if re.match(r"^.*\tspeaker__turn\s+ERR", line): # erroneous prediction 357 | # also save context of lines before and after 358 | context = _get_context(i, lines, context_lines) 359 | precision_errors.append(dict(line=i, context=context)) 360 | if re.match(r"^\s+speaker__turn\t.*\s+ERR", line): # missed label 361 | # also save context of 5 lines before and after 362 | context = _get_context(i, lines, context_lines) 363 | recall_errors.append(dict(line=i, context=context)) 364 | 365 | return precision_errors, recall_errors 366 | 367 | 368 | if __name__ == "__main__": 369 | print("NOTE: dont forget to pass the glob pattern inside quotes") 370 | parser = argparse.ArgumentParser(description=DESCRIPTION) 371 | parser.add_argument("glob_pattern", help="glob pattern for reco files") 372 | parser.add_argument( 373 | "--ref_nlp", 374 | help="reference nlp file", 375 | default="./workdir_analysis/fstalign_scoring/references/earnings21-4341191-ref_tagged.nlp", 376 | ) 377 | parser.add_argument( 378 | "--work_dir", 379 | help="working directory for fstalign scoring", 380 | default="./workdir_analysis/fstalign_scoring/results", 381 | ) 382 | parser.add_argument( 383 | "--speaker_turn_mode", 384 | help="speaker turn mode", 385 | choices=["segment", "punctuation", "token", "punctuation_token"], 386 | default="segment", 387 | ) 388 | args = parser.parse_args() 389 | 390 | ref_nlp = args.ref_nlp 391 | glob_pattern = args.glob_pattern 392 | 393 | reco_files = glob(glob_pattern) 394 | if ( 395 | input(f"Scoring {len(reco_files)} files under: {glob_pattern}, [y/n]\t").lower() 396 | == "n" 397 | ): 398 | exit() 399 | 400 | results = [] 401 | for reco_file in reco_files: 402 | # convert reco_file to result_name in this way 403 | # e.g. 
/home/whisper/tdrz_dev/tiny.en/d1/f.json -> tiny.en-d1-f 404 | result_name = "__".join(Path(reco_file).parts[-3:]).replace(".json", "") 405 | df, _ = score_fstalign( 406 | ref_nlp, 407 | reco_file, 408 | result_name, 409 | work_dir=args.work_dir, 410 | speaker_turn_mode=args.speaker_turn_mode, 411 | ) 412 | results.append(df) 413 | 414 | result_df = pd.concat(results) 415 | print(result_df.groupby(["metric", "result_id"]).sum()) 416 | -------------------------------------------------------------------------------- /tdrz_dev/score_fstalign.sh: -------------------------------------------------------------------------------- 1 | # Description: Score a hypothesis file using fstalign to get WER, per-token metrics, and side-by-side analysis 2 | # Usage: score_fstalign.sh [] 3 | # (the --wer-sidecar option is not used in the current version of the code) 4 | # Setup: docker pull revdotcom/fstalign (more info at https://github.com/revdotcom/fstalign) 5 | 6 | 7 | # these paths must fall within the current directory so that they can be mounted in the docker container 8 | REF=$1 9 | HYP=$2 10 | OUTDIR=$3 11 | FNAME=$(basename $HYP .nlp) 12 | 13 | # set -x 14 | 15 | # the current directory is mounted as /fstalign/workdir so all relative paths have to be relative to that 16 | PREFIX="/fstalign/workdir" 17 | CMD="/fstalign/build/fstalign wer \ 18 | --ref $PREFIX/$REF --hyp $PREFIX/$HYP \ 19 | --log $PREFIX/$OUTDIR/$FNAME.log --json-log $PREFIX/$OUTDIR/$FNAME.json --output-sbs $PREFIX/$OUTDIR/$FNAME.sbs" 20 | 21 | # this argument doesn't really seem to be necessary 22 | if ! 
[ -z "$4" ] 23 | then 24 | CMD="$CMD --wer-sidecar $4" 25 | fi 26 | 27 | # command is run inside the docker container 28 | docker run --rm -it -v $PWD:/fstalign/workdir revdotcom/fstalign $CMD 29 | 30 | # print absolute path of output filename to stdout 31 | echo "RESULT="$(realpath $OUTDIR/$FNAME.json) 32 | -------------------------------------------------------------------------------- /tdrz_dev/scripts/diarize_post_sr.py: -------------------------------------------------------------------------------- 1 | # Script adapted from https://huggingface.co/spaces/dwarkesh/whisper-speaker-recognition 2 | import contextlib 3 | import datetime 4 | import json 5 | import logging 6 | import os 7 | import subprocess 8 | import wave 9 | from pathlib import Path 10 | 11 | import numpy as np 12 | import torch 13 | from pyannote.audio import Audio 14 | from pyannote.audio.pipelines.speaker_verification import PretrainedSpeakerEmbedding 15 | from pyannote.core import Segment 16 | from sklearn.cluster import AgglomerativeClustering 17 | from tqdm import tqdm 18 | 19 | WHISPERMODEL = "tiny.en" 20 | 21 | pyannote_audio = Audio() 22 | 23 | 24 | def convert_to_wav(path): 25 | if path[-3:] != "wav": 26 | wav_path = ".".join(path.split(".")[:-1]) + ".wav" 27 | try: 28 | subprocess.call(["ffmpeg", "-i", path, "-ar", "16000", wav_path, "-y"]) 29 | except Exception: 30 | return path, "Error: Could not convert file to .wav" 31 | path = wav_path 32 | return path, None 33 | 34 | 35 | def get_duration(wav_path): 36 | with contextlib.closing(wave.open(wav_path, "r")) as f: 37 | frames = f.getnframes() 38 | rate = f.getframerate() 39 | return frames / float(rate) 40 | 41 | 42 | def make_embeddings(embedding_model, wav_path, segments, duration): 43 | embeddings = np.zeros(shape=(len(segments), 192)) 44 | for i, segment in enumerate(tqdm(segments)): 45 | embeddings[i] = segment_embedding(embedding_model, wav_path, segment, duration) 46 | return np.nan_to_num(embeddings) 47 | 48 | 49 | def 
def add_speakers_to_segments(audio, segments, num_speakers):
    """Label every transcript segment with a speaker via embedding clustering.

    Each segment dict gets a ``"speaker"`` entry of the form ``"SPEAKER <n>"``.
    The list is modified in place and also returned.

    Args:
        audio: Path of the audio file the segments were transcribed from.
        segments: List of segment dicts; the helpers called here read their
            ``"start"`` and ``"end"`` entries.
        num_speakers: Requested number of clusters. It is rounded and clamped
            to the range ``[1, len(segments)]``.

    Returns:
        The labelled ``segments`` list, or an error message string when the
        audio could not be converted to wav or is longer than four hours.

    Raises:
        TypeError: If ``num_speakers`` is ``None``.
    """
    # Validate cheap preconditions before converting audio or loading a model.
    if num_speakers is None:
        raise TypeError("num_speakers must be a number, got None")
    if not segments:
        # Nothing to label; clustering zero embeddings would only fail.
        return segments

    wav_path, error = convert_to_wav(audio)
    if error is not None:
        return error

    duration = get_duration(wav_path)
    if duration > 4 * 60 * 60:
        return "Audio duration too long"

    num_speakers = min(max(round(num_speakers), 1), len(segments))
    if len(segments) == 1:
        # A single segment needs no embeddings and no clustering.
        segments[0]["speaker"] = "SPEAKER 1"
        return segments

    # The embedding model is only created once it is certain to be used.
    embedding_model = PretrainedSpeakerEmbedding(
        "speechbrain/spkrec-ecapa-voxceleb",
        device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    )
    logging.info(f"Creating embeddings for {len(segments)} segments ..")
    embeddings = make_embeddings(embedding_model, wav_path, segments, duration)
    logging.info(f"Clustering embeddings into {num_speakers} speakers ..")
    add_speaker_labels(segments, embeddings, num_speakers)

    return segments
def run_pyannote_pipeline(audio_file, hf_token_file, num_speakers=None):
    """Run the pyannote speaker-diarization pipeline on an audio file.

    Args:
        audio_file: Path of the audio to diarize; it is passed through
            ``convert_to_wav`` before processing.
        hf_token_file: Text file holding the HuggingFace access token.
        num_speakers: Optional speaker count forwarded to the pipeline.

    Returns:
        A list of dicts with ``"start"``, ``"end"`` and ``"speaker"`` keys, one
        per track yielded by the diarization result.

    Raises:
        FileNotFoundError: If ``hf_token_file`` does not exist.
    """
    # Check the token first so a missing file fails fast, before any audio
    # conversion or model download is attempted.
    if not Path(hf_token_file).is_file():
        logging.error(
            f"Could not find the HuggingFace token file at {hf_token_file}. "
            "Please create an account at https://huggingface.co/ and "
            " 1. visit hf.co/pyannote/speaker-diarization and accept user conditions "
            " 2. visit hf.co/pyannote/segmentation and accept user conditions "
            " 3. visit hf.co/settings/tokens to create an access token "
            f" and save it to a text file at {hf_token_file}."
        )
        raise FileNotFoundError(
            f"Could not find the HuggingFace token file at {hf_token_file}."
        )

    audio_file, error = convert_to_wav(audio_file)
    if error is not None:
        # Best effort: keep going with the returned path, but leave a trace.
        logging.warning(f"{error}; continuing with {audio_file}")

    # read the auth token
    with open(hf_token_file) as f:
        hf_tok = f.read().strip()

    logging.info("Creating pyannote diarization pipeline ..")
    pipeline = Pipeline.from_pretrained(
        "pyannote/speaker-diarization@2.1", use_auth_token=hf_tok
    )
    pipeline.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))

    logging.info("Processing audio with pyannote diarization pipeline ..")
    diarization = pipeline(audio_file, num_speakers=num_speakers)
    return [
        {"start": turn.start, "end": turn.end, "speaker": speaker}
        for turn, _, speaker in diarization.itertracks(yield_label=True)
    ]
def run_pre_sr_pipeline(
    audio_file,
    output_dir,
    num_speakers=None,
    hf_token_file=TOKEN_FILE,
    whisper_model=WHISPERMODEL,
):
    """Diarize the audio with pyannote, then transcribe each turn with whisper.

    Diarization output is cached in ``<output_dir>/<audio stem>-diarization.json``
    and reused on later calls. A transcript at ``<output_dir>/<audio stem>.json``
    is loaded instead of re-transcribing when it already exists.

    Args:
        audio_file: Path of the audio file; passed through ``convert_to_wav``.
        output_dir: Directory for cached results (created if missing).
        num_speakers: Optional speaker count forwarded to the diarizer.
        hf_token_file: Text file holding the HuggingFace token.
        whisper_model: Whisper model name or checkpoint path.

    Returns:
        The transcription result: either the output of
        ``transcribe_cropped_segments`` or the JSON loaded from the cache file.
    """
    os.makedirs(output_dir, exist_ok=True)
    audio_file, _ = convert_to_wav(audio_file)

    # TODO@Akash - wrap this pattern into a decorator
    diarization_result_file = (
        Path(output_dir) / f"{Path(audio_file).stem}-diarization.json"
    )
    if diarization_result_file.is_file():
        with open(diarization_result_file) as f:
            diarized_segments = json.load(f)
    else:
        diarized_segments = run_pyannote_pipeline(
            audio_file, hf_token_file, num_speakers
        )
        with open(diarization_result_file, "w") as f:
            json.dump(diarized_segments, f, indent=4)

    # summarize diarization result
    if diarized_segments:
        segment_durations = [s["end"] - s["start"] for s in diarized_segments]
        # Separate name: this is the detected count, not the requested one.
        num_detected = len({s["speaker"] for s in diarized_segments})
        logging.info(
            f"Diarized into {len(segment_durations)} segments with "
            "min/median/max duration: "
            f"{min(segment_durations):.2f}/{median(segment_durations):.2f}/"
            f"{max(segment_durations):.2f}"
        )
        logging.info(f"Detected {num_detected} unique speakers")
    else:
        # min()/max()/median() raise on an empty sequence.
        logging.warning("Diarization produced no segments")

    # TODO@Akash - wrap this pattern into a decorator
    # NOTE(review): this function never writes final_reco_file itself;
    # presumably the caller's writer produces it - confirm.
    final_reco_file = Path(output_dir) / f"{Path(audio_file).stem}.json"
    if final_reco_file.is_file():
        with open(final_reco_file) as f:
            result = json.load(f)
    else:
        result = transcribe_cropped_segments(
            audio_file, diarized_segments, whisper_model
        )

    return result
loop through the IDs 17 | for i in ${IDS[@]}; 18 | do 19 | # URL of audio file 20 | URL=https://github.com/revdotcom/speech-datasets/blob/main/earnings21/media/$i.mp3?raw=true 21 | # use wget to download file from a URL 22 | wget $URL -O $AUDIODIR/earnings21-$i.mp3 23 | 24 | # URL of transcript file 25 | URL=https://github.com/revdotcom/speech-datasets/raw/main/earnings21/transcripts/nlp_references/$i.nlp 26 | # use wget to download file from a URL 27 | wget $URL -O $REFDIR/earnings21-$i-ref.nlp 28 | done 29 | -------------------------------------------------------------------------------- /tdrz_dev/scripts/run_pipelines.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | import sys 5 | from pathlib import Path 6 | 7 | import pandas as pd 8 | from diarize_post_sr import add_speakers_to_segments 9 | from diarize_pre_sr import run_pre_sr_pipeline 10 | 11 | import whisper 12 | import whisper.utils as wutils 13 | 14 | sys.path.append(str(Path(__file__).resolve().parent.parent)) 15 | from score import score_fstalign # noqa: E402 16 | 17 | DESCRIPTION = """ 18 | Script that will run the following pipelines: 19 | 1. transcribe audio with whisper 20 | 2. apply post_sr diarization by clustering whisper segments 21 | 3. run pyannote pre_sr diarization and retranscribe segmented audio 22 | 4. transcribe audio with whisper-tdrz 23 | 5. 
def setup_logging(output_dir, audio_name):
    """Configure console logging plus a per-audio log file.

    Args:
        output_dir: Base output directory; logs go to
            ``<output_dir>/run_pipeline_logs``.
        audio_name: Name used in the log file name
            ``run_pipelines-<audio_name>.log``.

    Returns:
        The log directory path as a string.
    """
    log_format = "%(asctime)s %(levelname)s %(message)s"
    date_format = "%Y-%m-%d %H:%M:%S"

    log_dir = f"{output_dir}/run_pipeline_logs"
    os.makedirs(log_dir, exist_ok=True)
    logging.basicConfig(format=log_format, datefmt=date_format, level=logging.DEBUG)

    # FileHandler stores the absolute path in baseFilename; compare like with like.
    log_file = os.path.abspath(f"{log_dir}/run_pipelines-{audio_name}.log")
    root_logger = logging.getLogger()
    already_attached = any(
        isinstance(handler, logging.FileHandler) and handler.baseFilename == log_file
        for handler in root_logger.handlers
    )
    if not already_attached:
        # A second call for the same file must not duplicate every log line.
        file_handler = logging.FileHandler(log_file)
        file_handler.setLevel(logging.DEBUG)
        # Pass datefmt too so the file really uses the same format as the console.
        file_handler.setFormatter(logging.Formatter(log_format, datefmt=date_format))
        root_logger.addHandler(file_handler)
    return log_dir
either 'all' or a comma separated list of numbers 1-5", 65 | default="all", 66 | ) 67 | parser.add_argument( 68 | "--whisper_model", 69 | help="valid whisper model name or path to checkpoint", 70 | default=WHISPERMODEL, 71 | ) 72 | parser.add_argument( 73 | "--hf_token_file", 74 | help="text file containing HuggingFace token required for pyannote", 75 | default=TOKEN_FILE, 76 | ) 77 | args = parser.parse_args() 78 | 79 | files_to_score = [] 80 | pipelines_to_run = ( 81 | args.pipelines_to_run.split(",") 82 | if args.pipelines_to_run != "all" 83 | else ["1", "2", "3", "4", "5"] 84 | ) 85 | 86 | audio_file, audio_name = args.audio_file, Path(args.audio_file).stem 87 | log_dir = setup_logging(args.output_dir, audio_name) 88 | reco_fname = f"{audio_name}.json" 89 | ref_file = Path(args.ref_file).resolve() 90 | whisper_output_dir = f"{args.output_dir}/{args.whisper_model}/{audio_name}" 91 | 92 | # 1. transcribe audio with whisper 93 | if "1" in pipelines_to_run: 94 | logging.info("Transcribing audio with whisper ..") 95 | os.makedirs(whisper_output_dir, exist_ok=True) 96 | transcribe_result_file = (Path(whisper_output_dir) / reco_fname).resolve() 97 | if not Path(transcribe_result_file).is_file(): 98 | model = whisper.load_model(args.whisper_model) 99 | result = model.transcribe( 100 | audio_file, verbose=False, condition_on_previous_text=True, beam_size=4 101 | ) 102 | writer = wutils.get_writer("all", whisper_output_dir) 103 | writer(result, audio_file) 104 | else: 105 | with open(transcribe_result_file) as f: 106 | result = json.load(f) 107 | 108 | files_to_score.append((transcribe_result_file, "segment")) 109 | files_to_score.append((transcribe_result_file, "punctuation")) 110 | 111 | # 2. 
apply post_sr diarization by clustering whisper segments 112 | if "2" in pipelines_to_run: 113 | logging.info("Applying post_sr diarization ..") 114 | drz_post_sr_output_dir = whisper_output_dir + "_drz_post_sr" 115 | os.makedirs(drz_post_sr_output_dir, exist_ok=True) 116 | drz_post_sr_reco_file = (Path(drz_post_sr_output_dir) / reco_fname).resolve() 117 | if not Path(drz_post_sr_reco_file).is_file(): 118 | result["segments"] = add_speakers_to_segments( 119 | audio_file, result["segments"], num_speakers=args.num_speakers 120 | ) 121 | writer = wutils.get_writer("all", drz_post_sr_output_dir) 122 | writer(result, audio_file) 123 | 124 | files_to_score.append((drz_post_sr_reco_file, "segment")) 125 | 126 | # 3. run pyannote pre_sr diarization and retranscribe segmented audio 127 | if "3" in pipelines_to_run: 128 | logging.info("Running pre_sr diarization and retranscribing segmented audio ..") 129 | drz_pre_sr_output_dir = whisper_output_dir + "_drz_pre_sr" 130 | os.makedirs(drz_pre_sr_output_dir, exist_ok=True) 131 | drz_pre_sr_reco_file = (Path(drz_pre_sr_output_dir) / reco_fname).resolve() 132 | if not Path(drz_pre_sr_reco_file).is_file(): 133 | result = run_pre_sr_pipeline( 134 | audio_file, 135 | drz_pre_sr_output_dir, 136 | num_speakers=args.num_speakers, 137 | hf_token_file=args.hf_token_file, 138 | whisper_model=args.whisper_model, 139 | ) 140 | writer = wutils.get_writer("all", drz_pre_sr_output_dir) 141 | writer(result, audio_file) 142 | 143 | files_to_score.append((drz_pre_sr_reco_file, "segment")) 144 | 145 | # 4. 
transcribe audio with whisper-tdrz 146 | if "4" in pipelines_to_run: 147 | logging.info("Transcribing audio with whisper tinydiarize..") 148 | tdrz_output_dir = f"{args.output_dir}/{args.whisper_model}-tdrz/{audio_name}" 149 | os.makedirs(tdrz_output_dir, exist_ok=True) 150 | transcribe_result_file = (Path(tdrz_output_dir) / reco_fname).resolve() 151 | if not Path(transcribe_result_file).is_file(): 152 | model = whisper.load_model(args.whisper_model + "-tdrz") 153 | result = model.transcribe( 154 | audio_file, verbose=False, condition_on_previous_text=True, beam_size=4 155 | ) 156 | writer = wutils.get_writer("all", tdrz_output_dir) 157 | writer(result, audio_file) 158 | 159 | files_to_score.append((transcribe_result_file, "token")) 160 | files_to_score.append((transcribe_result_file, "segment")) 161 | files_to_score.append((transcribe_result_file, "punctuation")) 162 | 163 | # 5. score all the results 164 | if "5" in pipelines_to_run: 165 | logging.info("Scoring all the results ..") 166 | results = [] 167 | 168 | cwd = os.getcwd() # record the current working directory 169 | os.chdir(Path(__file__).parent.parent) # change to tdrz_dev parent directory 170 | for reco_file, scoring_mode in files_to_score: 171 | # convert reco_file to result_name in this way 172 | # e.g. 
/home/whisper/tdrz_dev/tiny.en/d1/f.json -> tiny.en-d1 173 | result_name = "__".join(Path(reco_file).parts[-3:-1]) 174 | logging.info(f"Scoring {result_name} with mode {scoring_mode} ..") 175 | result, _ = score_fstalign( 176 | ref_file, reco_file, result_name, speaker_turn_mode=scoring_mode 177 | ) 178 | results.append(result) 179 | os.chdir(cwd) # change back 180 | 181 | results_df = pd.concat(results) 182 | results_df["audio_file"] = Path(audio_file).name 183 | print(results_df) 184 | 185 | # save results to tsv 186 | results_file = f"{log_dir}/scoring_results-{audio_name}.tsv" 187 | results_df.to_csv(results_file, sep="\t", index=False) 188 | logging.info(f"Saved results to {results_file}") 189 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import random as rand 2 | 3 | import numpy 4 | import pytest 5 | 6 | 7 | def pytest_configure(config): 8 | config.addinivalue_line("markers", "requires_cuda") 9 | 10 | 11 | @pytest.fixture 12 | def random(): 13 | rand.seed(42) 14 | numpy.random.seed(42) 15 | -------------------------------------------------------------------------------- /tests/jfk.flac: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/akashmjn/tinydiarize/7cba47def707514fe68bbd6663f97663ea482158/tests/jfk.flac -------------------------------------------------------------------------------- /tests/test_audio.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | 3 | import numpy as np 4 | 5 | from whisper.audio import SAMPLE_RATE, load_audio, log_mel_spectrogram 6 | 7 | 8 | def test_audio(): 9 | audio_path = os.path.join(os.path.dirname(__file__), "jfk.flac") 10 | audio = load_audio(audio_path) 11 | assert audio.ndim == 1 12 | assert SAMPLE_RATE * 10 < audio.shape[0] < SAMPLE_RATE * 12 13 | assert 0 < 
audio.std() < 1 14 | 15 | mel_from_audio = log_mel_spectrogram(audio) 16 | mel_from_file = log_mel_spectrogram(audio_path) 17 | 18 | assert np.allclose(mel_from_audio, mel_from_file) 19 | assert mel_from_audio.max() - mel_from_audio.min() <= 2.0 20 | -------------------------------------------------------------------------------- /tests/test_normalizer.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from whisper.normalizers import EnglishTextNormalizer 4 | from whisper.normalizers.english import ( 5 | EnglishNumberNormalizer, 6 | EnglishSpellingNormalizer, 7 | ) 8 | 9 | 10 | @pytest.mark.parametrize("std", [EnglishNumberNormalizer(), EnglishTextNormalizer()]) 11 | def test_number_normalizer(std): 12 | assert std("two") == "2" 13 | assert std("thirty one") == "31" 14 | assert std("five twenty four") == "524" 15 | assert std("nineteen ninety nine") == "1999" 16 | assert std("twenty nineteen") == "2019" 17 | 18 | assert std("two point five million") == "2500000" 19 | assert std("four point two billions") == "4200000000s" 20 | assert std("200 thousand") == "200000" 21 | assert std("200 thousand dollars") == "$200000" 22 | assert std("$20 million") == "$20000000" 23 | assert std("€52.4 million") == "€52400000" 24 | assert std("£77 thousands") == "£77000s" 25 | 26 | assert std("two double o eight") == "2008" 27 | 28 | assert std("three thousand twenty nine") == "3029" 29 | assert std("forty three thousand two hundred sixty") == "43260" 30 | assert std("forty three thousand two hundred and sixty") == "43260" 31 | 32 | assert std("nineteen fifties") == "1950s" 33 | assert std("thirty first") == "31st" 34 | assert std("thirty three thousand and three hundred and thirty third") == "33333rd" 35 | 36 | assert std("three billion") == "3000000000" 37 | assert std("millions") == "1000000s" 38 | 39 | assert std("july third twenty twenty") == "july 3rd 2020" 40 | assert std("august twenty sixth twenty twenty one") == 
"august 26th 2021" 41 | assert std("3 14") == "3 14" 42 | assert std("3.14") == "3.14" 43 | assert std("3 point 2") == "3.2" 44 | assert std("3 point 14") == "3.14" 45 | assert std("fourteen point 4") == "14.4" 46 | assert std("two point two five dollars") == "$2.25" 47 | assert std("two hundred million dollars") == "$200000000" 48 | assert std("$20.1 million") == "$20100000" 49 | 50 | assert std("ninety percent") == "90%" 51 | assert std("seventy six per cent") == "76%" 52 | 53 | assert std("double oh seven") == "007" 54 | assert std("double zero seven") == "007" 55 | assert std("nine one one") == "911" 56 | assert std("nine double one") == "911" 57 | assert std("one triple oh one") == "10001" 58 | 59 | assert std("two thousandth") == "2000th" 60 | assert std("thirty two thousandth") == "32000th" 61 | 62 | assert std("minus 500") == "-500" 63 | assert std("positive twenty thousand") == "+20000" 64 | 65 | assert std("two dollars and seventy cents") == "$2.70" 66 | assert std("3 cents") == "¢3" 67 | assert std("$0.36") == "¢36" 68 | assert std("three euros and sixty five cents") == "€3.65" 69 | 70 | assert std("three and a half million") == "3500000" 71 | assert std("forty eight and a half dollars") == "$48.5" 72 | assert std("b747") == "b 747" 73 | assert std("10 th") == "10th" 74 | assert std("10th") == "10th" 75 | 76 | 77 | def test_spelling_normalizer(): 78 | std = EnglishSpellingNormalizer() 79 | 80 | assert std("mobilisation") == "mobilization" 81 | assert std("cancelation") == "cancellation" 82 | 83 | 84 | def test_text_normalizer(): 85 | std = EnglishTextNormalizer() 86 | assert std("Let's") == "let us" 87 | assert std("he's like") == "he is like" 88 | assert std("she's been like") == "she has been like" 89 | assert std("10km") == "10 km" 90 | assert std("10mm") == "10 mm" 91 | assert std("RC232") == "rc 232" 92 | 93 | assert ( 94 | std("Mr. Park visited Assoc. Prof. 
Kim Jr.") 95 | == "mister park visited associate professor kim junior" 96 | ) 97 | -------------------------------------------------------------------------------- /tests/test_timing.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | import scipy.ndimage 4 | import torch 5 | 6 | from whisper.timing import dtw_cpu, dtw_cuda, median_filter 7 | 8 | sizes = [ 9 | (10, 20), 10 | (32, 16), 11 | (123, 1500), 12 | (234, 189), 13 | ] 14 | shapes = [ 15 | (10,), 16 | (1, 15), 17 | (4, 5, 345), 18 | (6, 12, 240, 512), 19 | ] 20 | 21 | 22 | @pytest.mark.parametrize("N, M", sizes) 23 | def test_dtw(N: int, M: int): 24 | steps = np.concatenate([np.zeros(N - 1), np.ones(M - 1)]) 25 | np.random.shuffle(steps) 26 | x = np.random.random((N, M)).astype(np.float32) 27 | 28 | i, j, k = 0, 0, 0 29 | trace = [] 30 | while True: 31 | x[i, j] -= 1 32 | trace.append((i, j)) 33 | 34 | if k == len(steps): 35 | break 36 | 37 | if k + 1 < len(steps) and steps[k] != steps[k + 1]: 38 | i += 1 39 | j += 1 40 | k += 2 41 | continue 42 | 43 | if steps[k] == 0: 44 | i += 1 45 | if steps[k] == 1: 46 | j += 1 47 | k += 1 48 | 49 | trace = np.array(trace).T 50 | dtw_trace = dtw_cpu(x) 51 | 52 | assert np.allclose(trace, dtw_trace) 53 | 54 | 55 | @pytest.mark.requires_cuda 56 | @pytest.mark.parametrize("N, M", sizes) 57 | def test_dtw_cuda_equivalence(N: int, M: int): 58 | x_numpy = np.random.randn(N, M).astype(np.float32) 59 | x_cuda = torch.from_numpy(x_numpy).cuda() 60 | 61 | trace_cpu = dtw_cpu(x_numpy) 62 | trace_cuda = dtw_cuda(x_cuda) 63 | 64 | assert np.allclose(trace_cpu, trace_cuda) 65 | 66 | 67 | @pytest.mark.parametrize("shape", shapes) 68 | def test_median_filter(shape): 69 | x = torch.randn(*shape) 70 | 71 | for filter_width in [3, 5, 7, 13]: 72 | filtered = median_filter(x, filter_width) 73 | 74 | # using np.pad to reflect-pad, because Scipy's behavior is different near the edges. 
75 | pad_width = filter_width // 2 76 | padded_x = np.pad( 77 | x, [(0, 0)] * (x.ndim - 1) + [(pad_width, pad_width)], mode="reflect" 78 | ) 79 | scipy_filtered = scipy.ndimage.median_filter( 80 | padded_x, [1] * (x.ndim - 1) + [filter_width] 81 | ) 82 | scipy_filtered = scipy_filtered[..., pad_width:-pad_width] 83 | 84 | assert np.allclose(filtered, scipy_filtered) 85 | 86 | 87 | @pytest.mark.requires_cuda 88 | @pytest.mark.parametrize("shape", shapes) 89 | def test_median_filter_equivalence(shape): 90 | x = torch.randn(*shape) 91 | 92 | for filter_width in [3, 5, 7, 13]: 93 | filtered_cpu = median_filter(x, filter_width) 94 | filtered_gpu = median_filter(x.cuda(), filter_width).cpu() 95 | 96 | assert np.allclose(filtered_cpu, filtered_gpu) 97 | -------------------------------------------------------------------------------- /tests/test_tokenizer.py: -------------------------------------------------------------------------------- 1 | from whisper.tokenizer import get_tokenizer 2 | 3 | 4 | def test_tokenizer(): 5 | gpt2_tokenizer = get_tokenizer(multilingual=False) 6 | multilingual_tokenizer = get_tokenizer(multilingual=True) 7 | 8 | text = "다람쥐 헌 쳇바퀴에 타고파" 9 | gpt2_tokens = gpt2_tokenizer.encode(text) 10 | multilingual_tokens = multilingual_tokenizer.encode(text) 11 | 12 | assert gpt2_tokenizer.decode(gpt2_tokens) == text 13 | assert multilingual_tokenizer.decode(multilingual_tokens) == text 14 | assert len(gpt2_tokens) > len(multilingual_tokens) 15 | -------------------------------------------------------------------------------- /tests/test_transcribe.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pytest 4 | import torch 5 | 6 | import whisper 7 | 8 | 9 | @pytest.mark.parametrize("model_name", whisper.available_models()) 10 | def test_transcribe(model_name: str): 11 | device = "cuda" if torch.cuda.is_available() else "cpu" 12 | model = whisper.load_model(model_name).to(device) 13 | audio_path = 
os.path.join(os.path.dirname(__file__), "jfk.flac") 14 | 15 | language = "en" if model_name.endswith(".en") else None 16 | result = model.transcribe( 17 | audio_path, language=language, temperature=0.0, word_timestamps=True 18 | ) 19 | assert result["language"] == "en" 20 | assert result["text"] == "".join([s["text"] for s in result["segments"]]) 21 | 22 | transcription = result["text"].lower() 23 | assert "my fellow americans" in transcription 24 | assert "your country" in transcription 25 | assert "do for you" in transcription 26 | 27 | timing_checked = False 28 | for segment in result["segments"]: 29 | for timing in segment["words"]: 30 | assert timing["start"] < timing["end"] 31 | if timing["word"].strip(" ,") == "Americans": 32 | assert timing["start"] <= 1.8 33 | assert timing["end"] >= 1.8 34 | print(timing) 35 | timing_checked = True 36 | 37 | assert timing_checked 38 | -------------------------------------------------------------------------------- /trim-tinydiarize.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/akashmjn/tinydiarize/7cba47def707514fe68bbd6663f97663ea482158/trim-tinydiarize.gif -------------------------------------------------------------------------------- /whisper/__init__.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import io 3 | import os 4 | import urllib 5 | import warnings 6 | from typing import List, Optional, Union 7 | 8 | import torch 9 | from tqdm import tqdm 10 | 11 | from .audio import load_audio, log_mel_spectrogram, pad_or_trim 12 | from .decoding import DecodingOptions, DecodingResult, decode, detect_language 13 | from .model import ModelDimensions, Whisper 14 | from .transcribe import transcribe 15 | from .version import __version__ 16 | 17 | _MODELS = { 18 | "tiny.en": 
"https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt", 19 | "tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt", 20 | "base.en": "https://openaipublic.azureedge.net/main/whisper/models/25a8566e1d0c1e2231d1c762132cd20e0f96a85d16145c3a00adf5d1ac670ead/base.en.pt", 21 | "base": "https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt", 22 | "small.en": "https://openaipublic.azureedge.net/main/whisper/models/f953ad0fd29cacd07d5a9eda5624af0f6bcf2258be67c92b79389873d91e0872/small.en.pt", 23 | "small.en-tdrz": "https://sharedstorage7190.blob.core.windows.net/tinydiarize/whisper/models/53dfb0a7f5393bd3612173f84cad3fa2b347a3106b53c116628ead31641e9a53/small.en-tdrz.pt", 24 | "small": "https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt", 25 | "medium.en": "https://openaipublic.azureedge.net/main/whisper/models/d7440d1dc186f76616474e0ff0b3b6b879abc9d1a4926b7adfa41db2d497ab4f/medium.en.pt", 26 | "medium": "https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt", 27 | "large-v1": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large-v1.pt", 28 | "large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt", 29 | "large": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt", 30 | } 31 | 32 | # base85-encoded (n_layers, n_heads) boolean arrays indicating the cross-attention heads that are 33 | # highly correlated to the word-level timing, i.e. 
def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
    """Fetch a model checkpoint into ``root``, verifying its SHA256 checksum.

    The expected checksum is the second-to-last path component of ``url`` and
    the local file name is the last one. A file already present in ``root``
    with a matching checksum is reused without any network access.

    Parameters
    ----------
    url : str
        download location of the checkpoint
    root : str
        directory to store the file in; created if missing
    in_memory : bool
        return the file contents instead of the file path

    Returns
    -------
    Union[bytes, str]
        the checkpoint bytes if `in_memory`, otherwise the local file path

    Raises
    ------
    RuntimeError
        if the target path exists but is not a regular file, or the checksum
        of the freshly downloaded file does not match
    """
    # A bare `import urllib` does not guarantee that the `request` submodule is
    # loaded, so import it explicitly where it is used.
    import urllib.request

    os.makedirs(root, exist_ok=True)

    expected_sha256 = url.split("/")[-2]
    download_target = os.path.join(root, os.path.basename(url))

    if os.path.exists(download_target) and not os.path.isfile(download_target):
        raise RuntimeError(f"{download_target} exists and is not a regular file")

    if os.path.isfile(download_target):
        with open(download_target, "rb") as f:
            model_bytes = f.read()
        if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
            return model_bytes if in_memory else download_target
        warnings.warn(
            f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
        )

    with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
        # The header can be absent; tqdm accepts total=None in that case.
        content_length = source.info().get("Content-Length")
        with tqdm(
            total=int(content_length) if content_length is not None else None,
            ncols=80,
            unit="iB",
            unit_scale=True,
            unit_divisor=1024,
        ) as loop:
            while True:
                buffer = source.read(8192)
                if not buffer:
                    break

                output.write(buffer)
                loop.update(len(buffer))

    # Read back through a context manager so the handle is closed promptly.
    with open(download_target, "rb") as f:
        model_bytes = f.read()
    if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
        raise RuntimeError(
            "Model has been downloaded but the SHA256 checksum does not match. Please retry loading the model."
        )

    return model_bytes if in_memory else download_target
open(checkpoint_file, "rb") 145 | ) as fp: 146 | checkpoint = torch.load(fp, map_location=device) 147 | del checkpoint_file 148 | 149 | dims = ModelDimensions(**checkpoint["dims"]) 150 | model = Whisper(dims) 151 | model.load_state_dict(checkpoint["model_state_dict"]) 152 | model.is_tdrz = "-tdrz" in name # tinydiarize finetuned model 153 | 154 | if alignment_heads is not None: 155 | model.set_alignment_heads(alignment_heads) 156 | 157 | return model.to(device) 158 | -------------------------------------------------------------------------------- /whisper/__main__.py: -------------------------------------------------------------------------------- 1 | from .transcribe import cli 2 | 3 | cli() 4 | -------------------------------------------------------------------------------- /whisper/assets/gpt2/special_tokens_map.json: -------------------------------------------------------------------------------- 1 | {"bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "unk_token": "<|endoftext|>"} -------------------------------------------------------------------------------- /whisper/assets/gpt2/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | {"unk_token": "<|endoftext|>", "bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "add_prefix_space": false, "model_max_length": 1024, "special_tokens_map_file": null, "name_or_path": "gpt2", "tokenizer_class": "GPT2Tokenizer"} -------------------------------------------------------------------------------- /whisper/assets/mel_filters.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/akashmjn/tinydiarize/7cba47def707514fe68bbd6663f97663ea482158/whisper/assets/mel_filters.npz -------------------------------------------------------------------------------- /whisper/assets/multilingual/added_tokens.json: -------------------------------------------------------------------------------- 1 | 
{"<|endoftext|>": 50257} 2 | -------------------------------------------------------------------------------- /whisper/assets/multilingual/special_tokens_map.json: -------------------------------------------------------------------------------- 1 | {"bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "unk_token": "<|endoftext|>"} -------------------------------------------------------------------------------- /whisper/assets/multilingual/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | {"unk_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "bos_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "eos_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "add_prefix_space": false, "model_max_length": 1024, "special_tokens_map_file": null, "name_or_path": "multilingual", "errors": "replace", "tokenizer_class": "GPT2Tokenizer"} -------------------------------------------------------------------------------- /whisper/audio.py: -------------------------------------------------------------------------------- 1 | import os 2 | from functools import lru_cache 3 | from typing import Optional, Union 4 | 5 | import ffmpeg 6 | import numpy as np 7 | import torch 8 | import torch.nn.functional as F 9 | 10 | from .utils import exact_div 11 | 12 | # hard-coded audio hyperparameters 13 | SAMPLE_RATE = 16000 14 | N_FFT = 400 15 | N_MELS = 80 16 | HOP_LENGTH = 160 17 | CHUNK_LENGTH = 30 18 | N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000 samples in a 30-second chunk 19 | N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH) # 3000 frames in a mel spectrogram input 20 | 21 | N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2 # the initial convolutions has stride 2 22 | FRAMES_PER_SECOND 
def load_audio(file: str, sr: int = SAMPLE_RATE):
    """
    Decode an audio file into a mono float32 waveform at the requested sample rate

    Parameters
    ----------
    file: str
        The audio file to open

    sr: int
        The sample rate to resample the audio to if necessary

    Returns
    -------
    A NumPy array containing the audio waveform, in float32 dtype.
    """
    # Decoding, down-mixing to one channel and resampling are delegated to an ffmpeg
    # subprocess that writes raw signed 16-bit little-endian PCM to its stdout.
    # Requires the ffmpeg CLI and `ffmpeg-python` package to be installed.
    stream = ffmpeg.input(file, threads=0).output(
        "-", format="s16le", acodec="pcm_s16le", ac=1, ar=sr
    )
    try:
        pcm_bytes, _ = stream.run(
            cmd=["ffmpeg", "-nostdin"], capture_stdout=True, capture_stderr=True
        )
    except ffmpeg.Error as e:
        raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e

    samples = np.frombuffer(pcm_bytes, np.int16).flatten()
    # scale the int16 samples by 1/32768
    return samples.astype(np.float32) / 32768.0
def log_mel_spectrogram(
    audio: Union[str, np.ndarray, torch.Tensor],
    n_mels: int = N_MELS,
    padding: int = 0,
    device: Optional[Union[str, torch.device]] = None,
):
    """
    Compute the log-Mel spectrogram of an audio waveform

    Parameters
    ----------
    audio: Union[str, np.ndarray, torch.Tensor], shape = (*)
        The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz

    n_mels: int
        The number of Mel-frequency filters, only 80 is supported

    padding: int
        Number of zero samples to pad to the right

    device: Optional[Union[str, torch.device]]
        If given, the audio tensor is moved to this device before STFT

    Returns
    -------
    torch.Tensor, shape = (80, n_frames)
        A Tensor that contains the Mel spectrogram
    """
    # normalise the input to a tensor: path -> waveform array -> tensor
    if isinstance(audio, str):
        audio = load_audio(audio)
    if not torch.is_tensor(audio):
        audio = torch.from_numpy(audio)

    if device is not None:
        audio = audio.to(device)
    if padding > 0:
        audio = F.pad(audio, (0, padding))

    hann = torch.hann_window(N_FFT).to(audio.device)
    spectrum = torch.stft(audio, N_FFT, HOP_LENGTH, window=hann, return_complex=True)
    # power spectrum, with the final STFT frame dropped
    power = spectrum[..., :-1].abs() ** 2

    mel_power = mel_filters(audio.device, n_mels) @ power

    # log10 with a floor, then limit the range to 8.0 below the maximum and rescale
    log_spec = torch.clamp(mel_power, min=1e-10).log10()
    floor = log_spec.max() - 8.0
    log_spec = torch.maximum(log_spec, floor)
    return (log_spec + 4.0) / 4.0
def sinusoids(length, channels, max_timescale=10000):
    """
    Returns sinusoids for positional embedding

    The result has shape (length, channels): the first half of the channels holds
    sines and the second half cosines, over geometrically spaced timescales
    from 1 up to `max_timescale`.
    """
    assert channels % 2 == 0
    half = channels // 2
    log_timescale_increment = np.log(max_timescale) / (half - 1)
    inv_timescales = torch.exp(-log_timescale_increment * torch.arange(half))
    # outer product of positions (rows) and inverse timescales (columns)
    positions = torch.arange(length).unsqueeze(1)
    scaled_time = positions * inv_timescales.unsqueeze(0)
    sines = torch.sin(scaled_time)
    cosines = torch.cos(scaled_time)
    return torch.cat([sines, cosines], dim=1)
class ResidualAttentionBlock(nn.Module):
    """
    Transformer block made of self-attention, optional cross-attention and an MLP.

    Each sub-layer is applied to a layer-normalized copy of its input and the
    result is added back to that input (residual connection).
    """

    def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
        """
        n_state : int
            width of the hidden state
        n_head : int
            number of attention heads
        cross_attention : bool
            whether to add a cross-attention sub-layer after the self-attention
        """
        super().__init__()

        self.attn = MultiHeadAttention(n_state, n_head)
        self.attn_ln = LayerNorm(n_state)

        # both cross-attention members are None when cross_attention is False
        self.cross_attn = (
            MultiHeadAttention(n_state, n_head) if cross_attention else None
        )
        self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None

        # the MLP hidden layer is four times as wide as the block's state
        n_mlp = n_state * 4
        self.mlp = nn.Sequential(
            Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state)
        )
        self.mlp_ln = LayerNorm(n_state)

    def forward(
        self,
        x: Tensor,
        xa: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
        kv_cache: Optional[dict] = None,
    ):
        """
        x : Tensor
            the hidden states to transform
        xa : Optional[Tensor]
            the states attended to by the cross-attention sub-layer, when present
        mask : Optional[Tensor]
            additive mask passed to the self-attention only
        kv_cache : Optional[dict]
            key/value cache forwarded to both attention sub-layers
        """
        # the attention modules return a tuple; [0] selects the attended output
        x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0]
        if self.cross_attn:
            # no mask is passed to the cross-attention
            x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0]
        x = x + self.mlp(self.mlp_ln(x))
        return x
class TextDecoder(nn.Module):
    """
    Stack of cross-attending residual blocks that maps text tokens and encoded
    audio features to logits over the vocabulary.
    """

    def __init__(
        self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
    ):
        """
        n_vocab : int
            size of the token vocabulary
        n_ctx : int
            maximum number of text positions
        n_state : int
            width of the hidden state
        n_head : int
            number of attention heads per block
        n_layer : int
            number of residual attention blocks
        """
        super().__init__()

        self.token_embedding = nn.Embedding(n_vocab, n_state)
        # allocated uninitialized here; presumably filled when a checkpoint is loaded
        self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))

        self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
            [
                ResidualAttentionBlock(n_state, n_head, cross_attention=True)
                for _ in range(n_layer)
            ]
        )
        self.ln = LayerNorm(n_state)

        # additive mask: -inf strictly above the diagonal, 0 on and below it
        mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
        # persistent=False keeps the mask out of the module's state_dict
        self.register_buffer("mask", mask, persistent=False)

    def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
        """
        x : torch.LongTensor, shape = (batch_size, <= n_ctx)
            the text tokens
        xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)
            the encoded audio features to be attended on
        kv_cache : Optional[dict]
            cache of key/value tensors from previous positions, if any

        Returns the logits as a float tensor of shape (batch_size, n_tokens, n_vocab).
        """
        # with a non-empty cache, positions continue from the length (dim 1) of its first entry
        offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
        x = (
            self.token_embedding(x)
            + self.positional_embedding[offset : offset + x.shape[-1]]
        )
        x = x.to(xa.dtype)

        for block in self.blocks:
            x = block(x, xa, mask=self.mask, kv_cache=kv_cache)

        x = self.ln(x)
        # the output projection reuses the token embedding matrix (transposed)
        logits = (
            x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
        ).float()

        return logits
@property 267 | def device(self): 268 | return next(self.parameters()).device 269 | 270 | @property 271 | def is_multilingual(self): 272 | return self.dims.n_vocab == 51865 273 | 274 | def install_kv_cache_hooks(self, cache: Optional[dict] = None): 275 | """ 276 | The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value 277 | tensors calculated for the previous positions. This method returns a dictionary that stores 278 | all caches, and the necessary hooks for the key and value projection modules that save the 279 | intermediate tensors to be reused during later calculations. 280 | 281 | Returns 282 | ------- 283 | cache : Dict[nn.Module, torch.Tensor] 284 | A dictionary object mapping the key/value projection modules to its cache 285 | hooks : List[RemovableHandle] 286 | List of PyTorch RemovableHandle objects to stop the hooks to be called 287 | """ 288 | cache = {**cache} if cache is not None else {} 289 | hooks = [] 290 | 291 | def save_to_cache(module, _, output): 292 | if module not in cache or output.shape[1] > self.dims.n_text_ctx: 293 | # save as-is, for the first token or cross attention 294 | cache[module] = output 295 | else: 296 | cache[module] = torch.cat([cache[module], output], dim=1).detach() 297 | return cache[module] 298 | 299 | def install_hooks(layer: nn.Module): 300 | if isinstance(layer, MultiHeadAttention): 301 | hooks.append(layer.key.register_forward_hook(save_to_cache)) 302 | hooks.append(layer.value.register_forward_hook(save_to_cache)) 303 | 304 | self.decoder.apply(install_hooks) 305 | return cache, hooks 306 | 307 | detect_language = detect_language_function 308 | transcribe = transcribe_function 309 | decode = decode_function 310 | -------------------------------------------------------------------------------- /whisper/normalizers/__init__.py: -------------------------------------------------------------------------------- 1 | from .basic import BasicTextNormalizer as BasicTextNormalizer 2 | from 
def remove_symbols(s: str):
    """
    Replace every mark, symbol and punctuation character with a space, keeping diacritics

    The text is NFKC-normalized first, so a letter followed by a combining accent
    is composed into a single letter where possible and therefore kept.
    """
    cleaned = []
    for ch in unicodedata.normalize("NFKC", s):
        # the first letter of the Unicode category is its major class:
        # M = mark, S = symbol, P = punctuation
        major_class = unicodedata.category(ch)[0]
        cleaned.append(" " if major_class in "MSP" else ch)
    return "".join(cleaned)
self.split_letters: 70 | s = " ".join(regex.findall(r"\X", s, regex.U)) 71 | 72 | s = re.sub( 73 | r"\s+", " ", s 74 | ) # replace any successive whitespace characters with a space 75 | 76 | return s 77 | -------------------------------------------------------------------------------- /whisper/normalizers/english.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import re 4 | from fractions import Fraction 5 | from typing import Iterator, List, Match, Optional, Union 6 | 7 | from more_itertools import windowed 8 | 9 | from .basic import remove_symbols_and_diacritics 10 | 11 | 12 | class EnglishNumberNormalizer: 13 | """ 14 | Convert any spelled-out numbers into arabic numbers, while handling: 15 | 16 | - remove any commas 17 | - keep the suffixes such as: `1960s`, `274th`, `32nd`, etc. 18 | - spell out currency symbols after the number. e.g. `$20 million` -> `20000000 dollars` 19 | - spell out `one` and `ones` 20 | - interpret successive single-digit numbers as nominal: `one oh one` -> `101` 21 | """ 22 | 23 | def __init__(self): 24 | super().__init__() 25 | 26 | self.zeros = {"o", "oh", "zero"} 27 | self.ones = { 28 | name: i 29 | for i, name in enumerate( 30 | [ 31 | "one", 32 | "two", 33 | "three", 34 | "four", 35 | "five", 36 | "six", 37 | "seven", 38 | "eight", 39 | "nine", 40 | "ten", 41 | "eleven", 42 | "twelve", 43 | "thirteen", 44 | "fourteen", 45 | "fifteen", 46 | "sixteen", 47 | "seventeen", 48 | "eighteen", 49 | "nineteen", 50 | ], 51 | start=1, 52 | ) 53 | } 54 | self.ones_plural = { 55 | "sixes" if name == "six" else name + "s": (value, "s") 56 | for name, value in self.ones.items() 57 | } 58 | self.ones_ordinal = { 59 | "zeroth": (0, "th"), 60 | "first": (1, "st"), 61 | "second": (2, "nd"), 62 | "third": (3, "rd"), 63 | "fifth": (5, "th"), 64 | "twelfth": (12, "th"), 65 | **{ 66 | name + ("h" if name.endswith("t") else "th"): (value, "th") 67 | for name, value in self.ones.items() 
68 | if value > 3 and value != 5 and value != 12 69 | }, 70 | } 71 | self.ones_suffixed = {**self.ones_plural, **self.ones_ordinal} 72 | 73 | self.tens = { 74 | "twenty": 20, 75 | "thirty": 30, 76 | "forty": 40, 77 | "fifty": 50, 78 | "sixty": 60, 79 | "seventy": 70, 80 | "eighty": 80, 81 | "ninety": 90, 82 | } 83 | self.tens_plural = { 84 | name.replace("y", "ies"): (value, "s") for name, value in self.tens.items() 85 | } 86 | self.tens_ordinal = { 87 | name.replace("y", "ieth"): (value, "th") 88 | for name, value in self.tens.items() 89 | } 90 | self.tens_suffixed = {**self.tens_plural, **self.tens_ordinal} 91 | 92 | self.multipliers = { 93 | "hundred": 100, 94 | "thousand": 1_000, 95 | "million": 1_000_000, 96 | "billion": 1_000_000_000, 97 | "trillion": 1_000_000_000_000, 98 | "quadrillion": 1_000_000_000_000_000, 99 | "quintillion": 1_000_000_000_000_000_000, 100 | "sextillion": 1_000_000_000_000_000_000_000, 101 | "septillion": 1_000_000_000_000_000_000_000_000, 102 | "octillion": 1_000_000_000_000_000_000_000_000_000, 103 | "nonillion": 1_000_000_000_000_000_000_000_000_000_000, 104 | "decillion": 1_000_000_000_000_000_000_000_000_000_000_000, 105 | } 106 | self.multipliers_plural = { 107 | name + "s": (value, "s") for name, value in self.multipliers.items() 108 | } 109 | self.multipliers_ordinal = { 110 | name + "th": (value, "th") for name, value in self.multipliers.items() 111 | } 112 | self.multipliers_suffixed = { 113 | **self.multipliers_plural, 114 | **self.multipliers_ordinal, 115 | } 116 | self.decimals = {*self.ones, *self.tens, *self.zeros} 117 | 118 | self.preceding_prefixers = { 119 | "minus": "-", 120 | "negative": "-", 121 | "plus": "+", 122 | "positive": "+", 123 | } 124 | self.following_prefixers = { 125 | "pound": "£", 126 | "pounds": "£", 127 | "euro": "€", 128 | "euros": "€", 129 | "dollar": "$", 130 | "dollars": "$", 131 | "cent": "¢", 132 | "cents": "¢", 133 | } 134 | self.prefixes = set( 135 | list(self.preceding_prefixers.values()) 
136 | + list(self.following_prefixers.values()) 137 | ) 138 | self.suffixers = { 139 | "per": {"cent": "%"}, 140 | "percent": "%", 141 | } 142 | self.specials = {"and", "double", "triple", "point"} 143 | 144 | self.words = set( 145 | [ 146 | key 147 | for mapping in [ 148 | self.zeros, 149 | self.ones, 150 | self.ones_suffixed, 151 | self.tens, 152 | self.tens_suffixed, 153 | self.multipliers, 154 | self.multipliers_suffixed, 155 | self.preceding_prefixers, 156 | self.following_prefixers, 157 | self.suffixers, 158 | self.specials, 159 | ] 160 | for key in mapping 161 | ] 162 | ) 163 | self.literal_words = {"one", "ones"} 164 | 165 | def process_words(self, words: List[str]) -> Iterator[str]: 166 | prefix: Optional[str] = None 167 | value: Optional[Union[str, int]] = None 168 | skip = False 169 | 170 | def to_fraction(s: str): 171 | try: 172 | return Fraction(s) 173 | except ValueError: 174 | return None 175 | 176 | def output(result: Union[str, int]): 177 | nonlocal prefix, value 178 | result = str(result) 179 | if prefix is not None: 180 | result = prefix + result 181 | value = None 182 | prefix = None 183 | return result 184 | 185 | if len(words) == 0: 186 | return 187 | 188 | for prev, current, next in windowed([None] + words + [None], 3): 189 | if skip: 190 | skip = False 191 | continue 192 | 193 | next_is_numeric = next is not None and re.match(r"^\d+(\.\d+)?$", next) 194 | has_prefix = current[0] in self.prefixes 195 | current_without_prefix = current[1:] if has_prefix else current 196 | if re.match(r"^\d+(\.\d+)?$", current_without_prefix): 197 | # arabic numbers (potentially with signs and fractions) 198 | f = to_fraction(current_without_prefix) 199 | assert f is not None 200 | if value is not None: 201 | if isinstance(value, str) and value.endswith("."): 202 | # concatenate decimals / ip address components 203 | value = str(value) + str(current) 204 | continue 205 | else: 206 | yield output(value) 207 | 208 | prefix = current[0] if has_prefix else prefix 
209 | if f.denominator == 1: 210 | value = f.numerator # store integers as int 211 | else: 212 | value = current_without_prefix 213 | elif current not in self.words: 214 | # non-numeric words 215 | if value is not None: 216 | yield output(value) 217 | yield output(current) 218 | elif current in self.zeros: 219 | value = str(value or "") + "0" 220 | elif current in self.ones: 221 | ones = self.ones[current] 222 | 223 | if value is None: 224 | value = ones 225 | elif isinstance(value, str) or prev in self.ones: 226 | if ( 227 | prev in self.tens and ones < 10 228 | ): # replace the last zero with the digit 229 | assert value[-1] == "0" 230 | value = value[:-1] + str(ones) 231 | else: 232 | value = str(value) + str(ones) 233 | elif ones < 10: 234 | if value % 10 == 0: 235 | value += ones 236 | else: 237 | value = str(value) + str(ones) 238 | else: # eleven to nineteen 239 | if value % 100 == 0: 240 | value += ones 241 | else: 242 | value = str(value) + str(ones) 243 | elif current in self.ones_suffixed: 244 | # ordinal or cardinal; yield the number right away 245 | ones, suffix = self.ones_suffixed[current] 246 | if value is None: 247 | yield output(str(ones) + suffix) 248 | elif isinstance(value, str) or prev in self.ones: 249 | if prev in self.tens and ones < 10: 250 | assert value[-1] == "0" 251 | yield output(value[:-1] + str(ones) + suffix) 252 | else: 253 | yield output(str(value) + str(ones) + suffix) 254 | elif ones < 10: 255 | if value % 10 == 0: 256 | yield output(str(value + ones) + suffix) 257 | else: 258 | yield output(str(value) + str(ones) + suffix) 259 | else: # eleven to nineteen 260 | if value % 100 == 0: 261 | yield output(str(value + ones) + suffix) 262 | else: 263 | yield output(str(value) + str(ones) + suffix) 264 | value = None 265 | elif current in self.tens: 266 | tens = self.tens[current] 267 | if value is None: 268 | value = tens 269 | elif isinstance(value, str): 270 | value = str(value) + str(tens) 271 | else: 272 | if value % 100 == 0: 
273 | value += tens 274 | else: 275 | value = str(value) + str(tens) 276 | elif current in self.tens_suffixed: 277 | # ordinal or cardinal; yield the number right away 278 | tens, suffix = self.tens_suffixed[current] 279 | if value is None: 280 | yield output(str(tens) + suffix) 281 | elif isinstance(value, str): 282 | yield output(str(value) + str(tens) + suffix) 283 | else: 284 | if value % 100 == 0: 285 | yield output(str(value + tens) + suffix) 286 | else: 287 | yield output(str(value) + str(tens) + suffix) 288 | elif current in self.multipliers: 289 | multiplier = self.multipliers[current] 290 | if value is None: 291 | value = multiplier 292 | elif isinstance(value, str) or value == 0: 293 | f = to_fraction(value) 294 | p = f * multiplier if f is not None else None 295 | if f is not None and p.denominator == 1: 296 | value = p.numerator 297 | else: 298 | yield output(value) 299 | value = multiplier 300 | else: 301 | before = value // 1000 * 1000 302 | residual = value % 1000 303 | value = before + residual * multiplier 304 | elif current in self.multipliers_suffixed: 305 | multiplier, suffix = self.multipliers_suffixed[current] 306 | if value is None: 307 | yield output(str(multiplier) + suffix) 308 | elif isinstance(value, str): 309 | f = to_fraction(value) 310 | p = f * multiplier if f is not None else None 311 | if f is not None and p.denominator == 1: 312 | yield output(str(p.numerator) + suffix) 313 | else: 314 | yield output(value) 315 | yield output(str(multiplier) + suffix) 316 | else: # int 317 | before = value // 1000 * 1000 318 | residual = value % 1000 319 | value = before + residual * multiplier 320 | yield output(str(value) + suffix) 321 | value = None 322 | elif current in self.preceding_prefixers: 323 | # apply prefix (positive, minus, etc.) 
if it precedes a number 324 | if value is not None: 325 | yield output(value) 326 | 327 | if next in self.words or next_is_numeric: 328 | prefix = self.preceding_prefixers[current] 329 | else: 330 | yield output(current) 331 | elif current in self.following_prefixers: 332 | # apply prefix (dollars, cents, etc.) only after a number 333 | if value is not None: 334 | prefix = self.following_prefixers[current] 335 | yield output(value) 336 | else: 337 | yield output(current) 338 | elif current in self.suffixers: 339 | # apply suffix symbols (percent -> '%') 340 | if value is not None: 341 | suffix = self.suffixers[current] 342 | if isinstance(suffix, dict): 343 | if next in suffix: 344 | yield output(str(value) + suffix[next]) 345 | skip = True 346 | else: 347 | yield output(value) 348 | yield output(current) 349 | else: 350 | yield output(str(value) + suffix) 351 | else: 352 | yield output(current) 353 | elif current in self.specials: 354 | if next not in self.words and not next_is_numeric: 355 | # apply special handling only if the next word can be numeric 356 | if value is not None: 357 | yield output(value) 358 | yield output(current) 359 | elif current == "and": 360 | # ignore "and" after hundreds, thousands, etc. 361 | if prev not in self.multipliers: 362 | if value is not None: 363 | yield output(value) 364 | yield output(current) 365 | elif current == "double" or current == "triple": 366 | if next in self.ones or next in self.zeros: 367 | repeats = 2 if current == "double" else 3 368 | ones = self.ones.get(next, 0) 369 | value = str(value or "") + str(ones) * repeats 370 | skip = True 371 | else: 372 | if value is not None: 373 | yield output(value) 374 | yield output(current) 375 | elif current == "point": 376 | if next in self.decimals or next_is_numeric: 377 | value = str(value or "") + "." 
def preprocess(self, s: str):
    """Rewrite raw text so the word-by-word number parser can consume it.

    Three rewrites are applied, in order:

    1. "<number word> and a half" becomes "<number word> point five" when the
       word before it is in `self.decimals` or `self.multipliers`; otherwise
       the phrase "and a half" is put back unchanged.
    2. A space is inserted at every letter/digit boundary ("a1" -> "a 1").
    3. A space between a digit and an ordinal/plural suffix is removed again
       ("2 nd" -> "2nd"), since step 2 may have just introduced it.

    Parameters
    ----------
    s: str
        The text to prepare.

    Returns
    -------
    str
        The rewritten text. Segments are re-joined with single spaces, so
        doubled spaces may appear around a replaced "and a half".
    """
    # replace " and a half" with " point five"
    results = []

    segments = re.split(r"\band\s+a\s+half\b", s)
    last_index = len(segments) - 1
    for i, segment in enumerate(segments):
        if not segment.strip():
            continue

        results.append(segment)
        if i == last_index:
            # nothing was split off after the final segment
            continue

        # only the word immediately preceding "and a half" decides the rewrite
        last_word = segment.rsplit(maxsplit=1)[-1]
        if last_word in self.decimals or last_word in self.multipliers:
            results.append("point five")
        else:
            results.append("and a half")

    s = " ".join(results)

    # put a space at number/letter boundary
    s = re.sub(r"([a-z])([0-9])", r"\1 \2", s)
    s = re.sub(r"([0-9])([a-z])", r"\1 \2", s)

    # but remove spaces which could be a suffix
    s = re.sub(r"([0-9])\s+(st|nd|rd|th|s)\b", r"\1\2", s)

    return s


def postprocess(self, s: str):
    """Tidy up currency amounts and the digit "1" after number conversion.

    * "$2 and ¢7" (the "and " is optional) is merged into "$2.07".
    * A currency amount below one unit, e.g. "$0.07", is rewritten as cents
      ("¢7").
    * A standalone "1" / "1s" is spelled out as "one" / "ones".

    Parameters
    ----------
    s: str
        Text in which number words have already been converted to digits.

    Returns
    -------
    str
        The text with the rewrites above applied.
    """

    def combine_cents(m: Match):
        try:
            currency = m.group(1)
            integer = m.group(2)
            cents = int(m.group(3))
            return f"{currency}{integer}.{cents:02d}"
        except ValueError:
            # leave the matched text untouched; `m.string` would substitute
            # the *entire* input in place of the match
            return m.group(0)

    def extract_cents(m: Match):
        try:
            return f"¢{int(m.group(1))}"
        except ValueError:
            return m.group(0)

    # apply currency postprocessing; "$2 and ¢7" -> "$2.07"
    s = re.sub(r"([€£$])([0-9]+) (?:and )?¢([0-9]{1,2})\b", combine_cents, s)
    # the decimal point must be escaped: a bare "." matched any character, so
    # e.g. "£0815" used to be rewritten to "¢15"
    s = re.sub(r"[€£$]0\.([0-9]{1,2})\b", extract_cents, s)

    # write "one(s)" instead of "1(s)", just for the readability
    s = re.sub(r"\b1(s?)\b", r"one\1", s)

    return s


def __call__(self, s: str):
    """Normalize the numbers in `s`: preprocess, convert word by word, postprocess.

    Words for which `process_words` yields `None` are dropped before the
    remaining words are re-joined with single spaces.
    """
    s = self.preprocess(s)
    s = " ".join(word for word in self.process_words(s.split()) if word is not None)
    s = self.postprocess(s)

    return s
class EnglishSpellingNormalizer:
    """
    Applies British-American spelling mappings as listed in [1].

    The mapping is read once, at construction time, from the `english.json`
    file located next to this module.

    [1] https://www.tysto.com/uk-us-spelling-list.html
    """

    def __init__(self):
        mapping_path = os.path.join(os.path.dirname(__file__), "english.json")
        # a context manager closes the handle deterministically instead of
        # leaving it to the garbage collector; JSON text is read as UTF-8
        with open(mapping_path, encoding="utf-8") as f:
            self.mapping = json.load(f)

    def __call__(self, s: str):
        """Replace every whitespace-separated word found in the mapping.

        Words are re-joined with single spaces, so runs of whitespace in `s`
        are collapsed.
        """
        return " ".join(self.mapping.get(word, word) for word in s.split())


class EnglishTextNormalizer:
    """
    Normalizes English text: lowercases it, strips bracketed/parenthesized
    spans and filler words, expands contractions and title abbreviations,
    removes symbols, and standardizes numbers and spellings.
    """

    def __init__(self):
        # filler words that are removed outright
        self.ignore_patterns = r"\b(hmm|mm|mhm|mmm|uh|um)\b"
        # regex -> replacement; applied in insertion order, so the specific
        # patterns must stay ahead of the general contractions at the end
        self.replacers = {
            # common contractions
            r"\bwon't\b": "will not",
            r"\bcan't\b": "can not",
            r"\blet's\b": "let us",
            r"\bain't\b": "aint",
            r"\by'all\b": "you all",
            r"\bwanna\b": "want to",
            r"\bgotta\b": "got to",
            r"\bgonna\b": "going to",
            r"\bi'ma\b": "i am going to",
            r"\bimma\b": "i am going to",
            r"\bwoulda\b": "would have",
            r"\bcoulda\b": "could have",
            r"\bshoulda\b": "should have",
            r"\bma'am\b": "madam",
            # contractions in titles/prefixes
            r"\bmr\b": "mister ",
            r"\bmrs\b": "missus ",
            r"\bst\b": "saint ",
            r"\bdr\b": "doctor ",
            r"\bprof\b": "professor ",
            r"\bcapt\b": "captain ",
            r"\bgov\b": "governor ",
            r"\bald\b": "alderman ",
            r"\bgen\b": "general ",
            r"\bsen\b": "senator ",
            r"\brep\b": "representative ",
            r"\bpres\b": "president ",
            r"\brev\b": "reverend ",
            r"\bhon\b": "honorable ",
            r"\basst\b": "assistant ",
            r"\bassoc\b": "associate ",
            r"\blt\b": "lieutenant ",
            r"\bcol\b": "colonel ",
            r"\bjr\b": "junior ",
            r"\bsr\b": "senior ",
            r"\besq\b": "esquire ",
            # perfect tenses, ideally it should be any past participles, but it's harder..
            r"'d been\b": " had been",
            r"'s been\b": " has been",
            r"'d gone\b": " had gone",
            r"'s gone\b": " has gone",
            r"'d done\b": " had done",  # "'s done" is ambiguous
            r"'s got\b": " has got",
            # general contractions
            r"n't\b": " not",
            r"'re\b": " are",
            r"'s\b": " is",
            r"'d\b": " would",
            r"'ll\b": " will",
            r"'t\b": " not",
            r"'ve\b": " have",
            r"'m\b": " am",
        }
        self.standardize_numbers = EnglishNumberNormalizer()
        self.standardize_spellings = EnglishSpellingNormalizer()

    def __call__(self, s: str):
        """Return the normalized form of `s` (see the class docstring)."""
        s = s.lower()

        s = re.sub(r"[<\[][^>\]]*[>\]]", "", s)  # remove words between brackets
        s = re.sub(r"\(([^)]+?)\)", "", s)  # remove words between parenthesis
        s = re.sub(self.ignore_patterns, "", s)
        s = re.sub(r"\s+'", "'", s)  # when there's a space before an apostrophe

        for pattern, replacement in self.replacers.items():
            s = re.sub(pattern, replacement, s)

        s = re.sub(r"(\d),(\d)", r"\1\2", s)  # remove commas between digits
        s = re.sub(r"\.([^0-9]|$)", r" \1", s)  # remove periods not followed by numbers
        s = remove_symbols_and_diacritics(s, keep=".%$¢€£")  # keep numeric symbols

        s = self.standardize_numbers(s)
        s = self.standardize_spellings(s)

        # now remove prefix/suffix symbols that are not preceded/followed by numbers
        s = re.sub(r"[.$¢€£]([^0-9])", r" \1", s)
        s = re.sub(r"([^0-9])%", r"\1 ", s)

        s = re.sub(r"\s+", " ", s)  # replace any successive whitespaces with a space

        return s
def median_filter(x: torch.Tensor, filter_width: int):
    """Median-filter `x` along its last dimension with a window of `filter_width`.

    The input is reflect-padded by `filter_width // 2` on both sides so the
    output has the same shape as the input. When the last dimension is not
    longer than that padding, `x` is returned unchanged (and unvalidated).
    """
    half_width = filter_width // 2
    if x.shape[-1] <= half_width:
        # F.pad requires the padding width to be smaller than the input dimension
        return x

    # `F.pad` does not support 1D or 2D inputs for reflect padding but supports
    # 3D and 4D, so low-rank inputs temporarily get two leading axes
    add_leading_dims = x.ndim <= 2
    if add_leading_dims:
        x = x[None, None, :]

    assert (
        filter_width > 0 and filter_width % 2 == 1
    ), "`filter_width` should be an odd number"

    padded = F.pad(x, (half_width, half_width, 0, 0), mode="reflect")

    filtered = None
    if padded.is_cuda:
        try:
            from .triton_ops import median_filter_cuda

            filtered = median_filter_cuda(padded, filter_width)
        except (RuntimeError, subprocess.CalledProcessError):
            warnings.warn(
                "Failed to launch Triton kernels, likely due to missing CUDA toolkit; "
                "falling back to a slower median kernel implementation..."
            )

    if filtered is None:
        # sort() is faster than torch.median (https://github.com/pytorch/pytorch/issues/51450)
        windows = padded.unfold(-1, filter_width, 1)
        ordered, _ = windows.sort()
        filtered = ordered[..., half_width]

    # strip the two helper axes again for inputs that needed them
    return filtered[0, 0] if add_leading_dims else filtered
@numba.jit
def backtrace(trace: np.ndarray):
    """Recover the alignment path encoded in a DTW `trace` matrix.

    Starting at the bottom-right cell, the move code stored in each cell is
    followed until cell (0, 0) is reached: 0 steps diagonally (i and j both
    decrease), 1 decreases only i, 2 decreases only j.

    Note that `trace` is modified in place: its first row is overwritten with
    2 and its first column with 1 so the walk always ends at (0, 0).

    Returns a 2-row array in forward (start-to-end) order; row 0 holds the
    `i - 1` indices and row 1 the `j - 1` indices of the visited cells.

    Raises ValueError if a visited cell holds anything other than 0, 1 or 2.
    """
    i = trace.shape[0] - 1
    j = trace.shape[1] - 1
    # force the border cells to point back towards the origin
    trace[0, :] = 2
    trace[:, 0] = 1

    result = []
    while i > 0 or j > 0:
        # the matrices carry one extra leading row/column, hence the -1
        result.append((i - 1, j - 1))

        if trace[i, j] == 0:
            i -= 1
            j -= 1
        elif trace[i, j] == 1:
            i -= 1
        elif trace[i, j] == 2:
            j -= 1
        else:
            raise ValueError("Unexpected trace[i, j]")

    result = np.array(result)
    # the path was collected end-to-start; reverse it and return as (2, steps)
    return result[::-1, :].T


@numba.jit(nopython=True, parallel=True)
def dtw_cpu(x: np.ndarray):
    """Dynamic time warping over the 2-D cost matrix `x`.

    Fills an (N + 1, M + 1) accumulated-cost table and a matching table of
    move codes (0: diagonal, 1: from the row above, 2: from the column to the
    left), then returns `backtrace(trace)`.
    """
    N, M = x.shape
    # every cell starts at +inf except the origin, so paths must start at (0, 0)
    cost = np.ones((N + 1, M + 1), dtype=np.float32) * np.inf
    # -1 marks "not computed"; backtrace() rejects it if it is ever visited
    trace = -np.ones((N + 1, M + 1), dtype=np.float32)

    cost[0, 0] = 0
    for j in range(1, M + 1):
        for i in range(1, N + 1):
            c0 = cost[i - 1, j - 1]
            c1 = cost[i - 1, j]
            c2 = cost[i, j - 1]

            # NOTE(review): the comparisons are strict, so whenever no single
            # candidate is strictly smallest (e.g. c0 == c1 < c2) the `else`
            # branch selects c2 — confirm this tie-breaking is intended
            if c0 < c1 and c0 < c2:
                c, t = c0, 0
            elif c1 < c0 and c1 < c2:
                c, t = c1, 1
            else:
                c, t = c2, 2

            cost[i, j] = x[i - 1, j - 1] + c
            trace[i, j] = t

    return backtrace(trace)


def dtw_cuda(x, BLOCK_SIZE=1024):
    """Dynamic time warping via the Triton `dtw_kernel`, returning `backtrace(...)`.

    `x` is a 2-D tensor whose first dimension must be smaller than
    `BLOCK_SIZE`. Note that M and N name the opposite dimensions compared to
    `dtw_cpu` (here M is the first dimension).

    NOTE(review): the skewed layout below presumably places each anti-diagonal
    of `x` in one row so the kernel can process it in a single step — confirm
    against `triton_ops.dtw_kernel`.
    """
    from .triton_ops import dtw_kernel

    M, N = x.shape
    assert M < BLOCK_SIZE, f"M should be smaller than {BLOCK_SIZE=}"

    # pad every row with M + 1 infs, then re-read the flat buffer with a row
    # length of N + M so each successive row is shifted by one more element
    x_skew = (
        F.pad(x, (0, M + 1), value=np.inf).flatten()[: M * (N + M)].reshape(M, N + M)
    )
    x_skew = x_skew.T.contiguous()
    cost = torch.ones(N + M + 2, M + 2) * np.inf
    cost[0, 0] = 0
    cost = cost.cuda()
    trace = torch.zeros_like(cost, dtype=torch.int32)

    # launched with a grid of a single program
    dtw_kernel[(1,)](
        cost,
        trace,
        x_skew,
        x_skew.stride(0),
        cost.stride(0),
        trace.stride(0),
        N,
        M,
        BLOCK_SIZE=BLOCK_SIZE,
    )

    # reverse the skewing on the trace table and keep its first N + 1 columns
    trace = trace.T.flatten()[: (M + 1) * (M + N + 3)].reshape(M + 1, M + N + 3)[
        :, : N + 1
    ]
    return backtrace(trace.cpu().numpy())
def dtw(x: torch.Tensor) -> np.ndarray:
    """Run dynamic time warping on `x`, preferring the CUDA implementation.

    For CUDA tensors `dtw_cuda` is tried first; if it raises `RuntimeError` or
    `subprocess.CalledProcessError` a warning is issued and the CPU
    implementation is used instead (on a float64 NumPy copy of `x`).
    """
    if x.is_cuda:
        try:
            return dtw_cuda(x)
        except (RuntimeError, subprocess.CalledProcessError):
            warnings.warn(
                "Failed to launch Triton kernels, likely due to missing CUDA toolkit; "
                "falling back to a slower DTW implementation..."
            )

    return dtw_cpu(x.double().cpu().numpy())


@dataclass
class WordTiming:
    """Timing and confidence of one word as produced by `find_alignment`."""

    word: str  # decoded text of the word
    tokens: List[int]  # token ids making up the word
    start: float  # start time: a DTW time index divided by TOKENS_PER_SECOND
    end: float  # end time, same unit as `start`
    probability: float  # mean probability of the word's tokens


def find_alignment(
    model: "Whisper",
    tokenizer: "Tokenizer",
    text_tokens: List[int],
    mel: torch.Tensor,
    num_frames: int,
    *,
    medfilt_width: int = 7,
    qk_scale: float = 1.0,
) -> List[WordTiming]:
    """Align `text_tokens` to the audio and return one `WordTiming` per word.

    The model is run once on `mel` with the given tokens while forward hooks
    on every decoder block's cross-attention collect attention weights. The
    weights of the model's alignment heads are normalized, median-filtered and
    averaged; DTW over the negated result yields a token-to-frame path that is
    converted into word start/end times.

    Parameters
    ----------
    model: the Whisper model; `alignment_heads` selects the (layer, head)
        pairs that are used.
    tokenizer: supplies the start-of-transcript sequence, special token ids
        and `split_to_word_tokens`.
    text_tokens: the text token ids to align.
    mel: the mel spectrogram of the audio (a batch dimension is added here).
    num_frames: number of mel frames to consider; attention is cut off at
        `num_frames // 2`.
    medfilt_width: window width of the median filter applied to the weights.
    qk_scale: factor applied to the captured weights before the softmax.
    """
    tokens = torch.tensor(
        [
            *tokenizer.sot_sequence,
            tokenizer.no_timestamps,
            *text_tokens,
            tokenizer.eot,
        ]
    ).to(model.device)

    # install hooks on the cross attention layers to retrieve the attention weights
    # NOTE(review): assumes the last output of `cross_attn` is the batched QK
    # tensor, of which the first batch item is kept — confirm against the model
    QKs = [None] * model.dims.n_text_layer
    hooks = [
        block.cross_attn.register_forward_hook(
            lambda _, ins, outs, index=i: QKs.__setitem__(index, outs[-1][0])
        )
        for i, block in enumerate(model.decoder.blocks)
    ]

    try:
        with torch.no_grad():
            logits = model(mel.unsqueeze(0), tokens.unsqueeze(0))[0]
            sampled_logits = logits[len(tokenizer.sot_sequence) :, : tokenizer.eot]
            token_probs = sampled_logits.softmax(dim=-1)
            text_token_probs = token_probs[np.arange(len(text_tokens)), text_tokens]
            text_token_probs = text_token_probs.tolist()
    finally:
        # always detach the hooks, even if the forward pass fails; otherwise
        # they would stay registered on the model and fire on later calls
        for hook in hooks:
            hook.remove()

    # heads * tokens * frames
    weights = torch.stack(
        [QKs[layer][head] for layer, head in model.alignment_heads.indices().T]
    )
    weights = weights[:, :, : num_frames // 2]
    weights = (weights * qk_scale).softmax(dim=-1)
    # standardize over the token axis before filtering along the frame axis
    std, mean = torch.std_mean(weights, dim=-2, keepdim=True, unbiased=False)
    weights = (weights - mean) / std
    weights = median_filter(weights, medfilt_width)

    matrix = weights.mean(axis=0)
    # drop the rows of the start-of-transcript prefix and of the final token
    matrix = matrix[len(tokenizer.sot_sequence) : -1]
    text_indices, time_indices = dtw(-matrix)

    words, word_tokens = tokenizer.split_to_word_tokens(text_tokens + [tokenizer.eot])
    word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens[:-1]]), (1, 0))

    # a "jump" is a path step on which the token index advances
    jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool)
    jump_times = time_indices[jumps] / TOKENS_PER_SECOND
    start_times = jump_times[word_boundaries[:-1]]
    end_times = jump_times[word_boundaries[1:]]
    word_probabilities = [
        np.mean(text_token_probs[i:j])
        for i, j in zip(word_boundaries[:-1], word_boundaries[1:])
    ]

    # hack: ensure the first and second word is not longer than twice the median word duration.
    # a better segmentation algorithm based on VAD should be able to replace this.
    word_durations = end_times - start_times
    word_durations = word_durations[word_durations.nonzero()]
    if len(word_durations) > 0:
        median_duration = np.median(word_durations)
        max_duration = median_duration * 2
        if len(word_durations) >= 2 and word_durations[1] > max_duration:
            boundary = max(end_times[2] / 2, end_times[2] - max_duration)
            end_times[0] = start_times[1] = boundary
        if len(word_durations) >= 1 and end_times[0] - start_times[0] > max_duration:
            start_times[0] = max(0, end_times[0] - max_duration)

    # zip() stops at the shortest input, which drops the trailing end-of-text "word"
    return [
        WordTiming(word, tokens, start, end, probability)
        for word, tokens, start, end, probability in zip(
            words, word_tokens, start_times, end_times, word_probabilities
        )
    ]
def merge_punctuations(alignment: List["WordTiming"], prepended: str, appended: str):
    """Attach punctuation-only words to their neighbours, in place.

    A word that starts with a space and whose stripped text occurs in
    `prepended` is merged into the following word; a word whose text occurs in
    `appended` is merged into the previous word, provided that previous word
    does not end with a space. Merged-away entries are kept in the list with
    an empty `word` and empty `tokens`, so the list length and the total
    number of tokens are unchanged.
    """
    # merge prepended punctuations (walk backwards so chains collapse forward)
    i = len(alignment) - 2
    j = len(alignment) - 1
    while i >= 0:
        previous = alignment[i]
        following = alignment[j]
        if previous.word.startswith(" ") and previous.word.strip() in prepended:
            # prepend it to the following word
            following.word = previous.word + following.word
            following.tokens = previous.tokens + following.tokens
            previous.word = ""
            previous.tokens = []
        else:
            j = i
        i -= 1

    # merge appended punctuations
    i = 0
    j = 1
    while j < len(alignment):
        previous = alignment[i]
        following = alignment[j]
        if not previous.word.endswith(" ") and following.word in appended:
            # append it to the previous word
            previous.word = previous.word + following.word
            previous.tokens = previous.tokens + following.tokens
            following.word = ""
            following.tokens = []
        else:
            i = j
        j += 1


def add_word_timestamps(
    *,
    segments: List[dict],
    model: "Whisper",
    tokenizer: "Tokenizer",
    mel: torch.Tensor,
    num_frames: int,
    prepend_punctuations: str = "\"'“¿([{-",
    append_punctuations: str = "\"'.。,,!!??::”)]}、",
    **kwargs,
):
    """Add a "words" list to every segment and tighten the segment times.

    All text tokens (ids below `tokenizer.eot`) of `segments` are aligned in a
    single `find_alignment` call; punctuation is merged into neighbouring
    words, and each word is stored as a dict with `word`, `start`, `end` and
    `probability` in the segment its first token came from. Segments that
    received words get their "start"/"end" replaced by the first word's start
    and the last word's end. `segments` is modified in place; nothing is
    returned.

    Extra keyword arguments are forwarded to `find_alignment`.
    """
    if len(segments) == 0:
        return

    text_tokens_per_segment = [
        [t for t in s["tokens"] if t < tokenizer.eot] for s in segments
    ]
    text_tokens = [t for tokens in text_tokens_per_segment for t in tokens]
    alignment = find_alignment(model, tokenizer, text_tokens, mel, num_frames, **kwargs)
    merge_punctuations(alignment, prepend_punctuations, append_punctuations)

    time_offset = segments[0]["seek"] * HOP_LENGTH / SAMPLE_RATE
    # word boundaries below index into `text_tokens`, so the token -> segment
    # map must count text tokens only; counting every token of a segment
    # (timestamps, speaker turns, ...) shifts words into the wrong segment
    segment_lengths = [len(tokens) for tokens in text_tokens_per_segment]
    token_sources = np.repeat(np.arange(len(segments)), segment_lengths)

    for segment in segments:
        segment["words"] = []

    word_boundaries = np.pad(np.cumsum([len(w.tokens) for w in alignment]), (1, 0))
    for i, timing in enumerate(alignment):
        if timing.word:  # skip entries emptied by merge_punctuations
            segment = segments[token_sources[word_boundaries[i]]]
            start = round(time_offset + timing.start, 2)
            end = round(time_offset + timing.end, 2)
            segment["words"].append(
                dict(
                    word=timing.word,
                    start=start,
                    end=end,
                    probability=timing.probability,
                )
            )

    for segment in segments:
        if len(words := segment["words"]) > 0:
            # adjust the segment-level timestamps based on the word-level timestamps
            segment["start"] = words[0]["start"]
            segment["end"] = words[-1]["end"]
# Language code -> lowercase English language name.
# NOTE: the insertion order is significant. `build_tokenizer` registers one
# "<|code|>" special token per entry in this order, and `get_tokenizer` derives
# a language token id from the entry's position in this dict, so append new
# entries rather than reordering existing ones.
LANGUAGES = {
    "en": "english",
    "zh": "chinese",
    "de": "german",
    "es": "spanish",
    "ru": "russian",
    "ko": "korean",
    "fr": "french",
    "ja": "japanese",
    "pt": "portuguese",
    "tr": "turkish",
    "pl": "polish",
    "ca": "catalan",
    "nl": "dutch",
    "ar": "arabic",
    "sv": "swedish",
    "it": "italian",
    "id": "indonesian",
    "hi": "hindi",
    "fi": "finnish",
    "vi": "vietnamese",
    "he": "hebrew",
    "uk": "ukrainian",
    "el": "greek",
    "ms": "malay",
    "cs": "czech",
    "ro": "romanian",
    "da": "danish",
    "hu": "hungarian",
    "ta": "tamil",
    "no": "norwegian",
    "th": "thai",
    "ur": "urdu",
    "hr": "croatian",
    "bg": "bulgarian",
    "lt": "lithuanian",
    "la": "latin",
    "mi": "maori",
    "ml": "malayalam",
    "cy": "welsh",
    "sk": "slovak",
    "te": "telugu",
    "fa": "persian",
    "lv": "latvian",
    "bn": "bengali",
    "sr": "serbian",
    "az": "azerbaijani",
    "sl": "slovenian",
    "kn": "kannada",
    "et": "estonian",
    "mk": "macedonian",
    "br": "breton",
    "eu": "basque",
    "is": "icelandic",
    "hy": "armenian",
    "ne": "nepali",
    "mn": "mongolian",
    "bs": "bosnian",
    "kk": "kazakh",
    "sq": "albanian",
    "sw": "swahili",
    "gl": "galician",
    "mr": "marathi",
    "pa": "punjabi",
    "si": "sinhala",
    "km": "khmer",
    "sn": "shona",
    "yo": "yoruba",
    "so": "somali",
    "af": "afrikaans",
    "oc": "occitan",
    "ka": "georgian",
    "be": "belarusian",
    "tg": "tajik",
    "sd": "sindhi",
    "gu": "gujarati",
    "am": "amharic",
    "yi": "yiddish",
    "lo": "lao",
    "uz": "uzbek",
    "fo": "faroese",
    "ht": "haitian creole",
    "ps": "pashto",
    "tk": "turkmen",
    "nn": "nynorsk",
    "mt": "maltese",
    "sa": "sanskrit",
    "lb": "luxembourgish",
    "my": "myanmar",
    "bo": "tibetan",
    "tl": "tagalog",
    "mg": "malagasy",
    "as": "assamese",
    "tt": "tatar",
    "haw": "hawaiian",
    "ln": "lingala",
    "ha": "hausa",
    "ba": "bashkir",
    "jw": "javanese",
    "su": "sundanese",
}

# language code lookup by name, with a few language aliases
# (the inverted LANGUAGES table comes first; the aliases below extend it)
TO_LANGUAGE_CODE = {
    **{language: code for code, language in LANGUAGES.items()},
    "burmese": "my",
    "valencian": "ca",
    "flemish": "nl",
    "haitian": "ht",
    "letzeburgesch": "lb",
    "pushto": "ps",
    "panjabi": "pa",
    "moldavian": "ro",
    "moldovan": "ro",
    "sinhalese": "si",
    "castilian": "es",
}
@dataclass(frozen=True)
class Tokenizer:
    """A thin wrapper around `GPT2TokenizerFast` providing quick access to special tokens"""

    tokenizer: "GPT2TokenizerFast"
    language: Optional[str]
    sot_sequence: Tuple[int]

    def encode(self, text, **kwargs):
        """Encode `text` with the wrapped tokenizer."""
        return self.tokenizer.encode(text, **kwargs)

    def decode(
        self, token_ids: Union[int, List[int], np.ndarray, torch.Tensor], **kwargs
    ):
        """Decode `token_ids` with the wrapped tokenizer."""
        return self.tokenizer.decode(token_ids, **kwargs)

    def decode_with_timestamps(self, tokens) -> str:
        """
        Decode `tokens`, rendering timestamp tokens inline, e.g. "<|1.08|>".

        Timestamp tokens sit above the special tokens' id range and are ignored by
        `decode()`; here each one is printed as its offset from `timestamp_begin`
        multiplied by 0.02.
        """
        # alternating runs: lists of ordinary token ids and rendered timestamps
        pieces = [[]]
        for token in tokens:
            if token < self.timestamp_begin:
                pieces[-1].append(token)
                continue
            pieces.append(f"<|{(token - self.timestamp_begin) * 0.02:.2f}|>")
            pieces.append([])

        rendered = []
        for piece in pieces:
            if isinstance(piece, str):
                rendered.append(piece)
            else:
                rendered.append(self.tokenizer.decode(piece))
        return "".join(rendered)

    @cached_property
    def eot(self) -> int:
        """Id of the end-of-text token (the wrapped tokenizer's EOS id)."""
        return self.tokenizer.eos_token_id

    @cached_property
    def transcribe(self) -> int:
        """Id of the "<|transcribe|>" token."""
        return self._get_single_token_id("<|transcribe|>")

    @cached_property
    def translate(self) -> int:
        """Id of the "<|translate|>" token."""
        return self._get_single_token_id("<|translate|>")

    @cached_property
    def sot(self) -> int:
        """Id of the "<|startoftranscript|>" token."""
        return self._get_single_token_id("<|startoftranscript|>")

    @cached_property
    def speaker_turn(self) -> int:  # replaces unused sot_lm token
        """Id of the "<|speakerturn|>" token."""
        return self._get_single_token_id("<|speakerturn|>")

    @cached_property
    def sot_prev(self) -> int:
        """Id of the "<|startofprev|>" token."""
        return self._get_single_token_id("<|startofprev|>")

    @cached_property
    def no_speech(self) -> int:
        """Id of the "<|nospeech|>" token."""
        return self._get_single_token_id("<|nospeech|>")

    @cached_property
    def no_timestamps(self) -> int:
        """Id of the "<|notimestamps|>" token."""
        return self._get_single_token_id("<|notimestamps|>")

    @cached_property
    def timestamp_begin(self) -> int:
        """First timestamp token id: one past the last special token id."""
        return self.tokenizer.all_special_ids[-1] + 1

    @cached_property
    def language_token(self) -> int:
        """Returns the token id corresponding to the value of the `language` field"""
        if self.language is None:
            raise ValueError("This tokenizer does not have language token configured")

        token_ids = {
            token: token_id
            for token, token_id in zip(
                self.tokenizer.additional_special_tokens,
                self.tokenizer.additional_special_tokens_ids,
            )
        }
        wanted = f"<|{self.language}|>"
        if wanted not in token_ids:
            raise KeyError(f"Language {self.language} not found in tokenizer.")

        return token_ids[wanted]

    @cached_property
    def all_language_tokens(self) -> Tuple[int]:
        """Ids of the additional special tokens whose name is a key of `LANGUAGES`."""
        pairs = zip(
            self.tokenizer.additional_special_tokens,
            self.tokenizer.additional_special_tokens_ids,
        )
        return tuple(
            token_id for token, token_id in pairs if token.strip("<|>") in LANGUAGES
        )

    @cached_property
    def all_language_codes(self) -> Tuple[str]:
        """Language codes decoded from `all_language_tokens`, in the same order."""
        return tuple(
            self.decode([token_id]).strip("<|>")
            for token_id in self.all_language_tokens
        )

    @cached_property
    def sot_sequence_including_notimestamps(self) -> Tuple[int]:
        """`sot_sequence` followed by the no-timestamps token."""
        return (*self.sot_sequence, self.no_timestamps)

    @cached_property
    def non_speech_tokens(self) -> Tuple[int]:
        """
        Returns the sorted token ids to suppress so that speaker tags and non-speech
        annotations are not sampled, i.e. text that is not actually spoken, e.g.

        - ♪♪♪
        - ( SPEAKING FOREIGN LANGUAGE )
        - [DAVID] Hey there,

        Basic punctuation (commas, periods, question marks, exclamation points, etc.)
        stays allowed.
        """
        symbols = list('"#()*+/:;<=>@[\\]^_`{|}~「」『』')
        symbols += (
            "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split()
        )

        # symbols that may be a single token or multiple tokens depending on the tokenizer.
        # In case they're multiple tokens, suppress the first token, which is safe because:
        # These are between U+2640 and U+267F miscellaneous symbols that are okay to suppress
        # in generations, and in the 3-byte UTF-8 representation they share the first two bytes.
        miscellaneous = set("♩♪♫♬♭♮♯")
        assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous)

        # allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
        suppressed = {
            self.tokenizer.encode(" -")[0],
            self.tokenizer.encode(" '")[0],
        }
        for symbol in [*symbols, *miscellaneous]:
            always_suppress = symbol in miscellaneous
            for variant in (symbol, " " + symbol):
                encoded = self.tokenizer.encode(variant)
                if len(encoded) == 1 or always_suppress:
                    suppressed.add(encoded[0])

        return tuple(sorted(suppressed))

    def _get_single_token_id(self, text) -> int:
        """Encode `text` and return its id; it must map to exactly one token."""
        tokens = self.tokenizer.encode(text)
        assert len(tokens) == 1, f"{text} is not encoded as a single token"
        return tokens[0]

    def split_to_word_tokens(self, tokens: List[int]):
        """Split `tokens` into words; returns `(words, word_tokens)`."""
        if self.language not in {"zh", "ja", "th", "lo", "my"}:
            return self.split_tokens_on_spaces(tokens)

        # These languages don't typically use spaces, so it is difficult to split words
        # without morpheme analysis. Here, we instead split words at any
        # position where the tokens are decoded as valid unicode points
        return self.split_tokens_on_unicode(tokens)

    def split_tokens_on_unicode(self, tokens: List[int]):
        """Group tokens into the shortest runs that decode without U+FFFD."""
        words = []
        word_tokens = []
        pending = []

        for token in tokens:
            pending.append(token)
            text = self.decode_with_timestamps(pending)
            if "\ufffd" in text:
                # not a complete character yet; keep accumulating
                continue
            words.append(text)
            word_tokens.append(pending)
            pending = []

        return words, word_tokens

    def split_tokens_on_spaces(self, tokens: List[int]):
        """Merge unicode-level pieces into words, splitting on spaces and punctuation."""
        pieces, piece_tokens = self.split_tokens_on_unicode(tokens)
        words = []
        word_tokens = []

        for piece, ids in zip(pieces, piece_tokens):
            starts_new_word = (
                ids[0] >= self.eot  # special token
                or piece.startswith(" ")
                or piece.strip() in string.punctuation
                or len(words) == 0
            )
            if not starts_new_word:
                words[-1] = words[-1] + piece
                word_tokens[-1].extend(ids)
                continue
            words.append(piece)
            word_tokens.append(ids)

        return words, word_tokens


@lru_cache(maxsize=None)
def build_tokenizer(name: str = "gpt2"):
    """Load the bundled tokenizer `name` and register the additional special tokens.

    The tokenizer files are read from the `assets/<name>` directory next to
    this module. Results are cached per `name`. As a side effect the
    environment variable TOKENIZERS_PARALLELISM is set to "false".
    """
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    assets_dir = os.path.join(os.path.dirname(__file__), "assets", name)
    tokenizer = GPT2TokenizerFast.from_pretrained(assets_dir)

    language_tokens = [f"<|{lang}|>" for lang in LANGUAGES.keys()]
    specials = [
        "<|startoftranscript|>",
        *language_tokens,
        "<|translate|>",
        "<|transcribe|>",
        "<|speakerturn|>",  # hack and override "<|startoflm|>"
        "<|startofprev|>",
        "<|nospeech|>",
        "<|notimestamps|>",
    ]

    tokenizer.add_special_tokens(dict(additional_special_tokens=specials))
    return tokenizer
@lru_cache(maxsize=None)
def get_tokenizer(
    multilingual: bool,
    *,
    task: Optional[str] = None,  # Literal["transcribe", "translate", None]
    language: Optional[str] = None,
) -> Tokenizer:
    """Build the `Tokenizer` wrapper for a multilingual or English-only model.

    `language` may be a language code or a (case-insensitive) language name or
    alias; anything else raises ValueError. For multilingual models `task`
    defaults to "transcribe" and `language` to "en"; for English-only models
    both are ignored. The start-of-transcript sequence consists of the sot
    token, then the language token and the task token when they apply.
    Results are cached per argument combination.
    """
    if language is not None:
        lowered = language.lower()
        if lowered in LANGUAGES:
            language = lowered
        elif lowered in TO_LANGUAGE_CODE:
            language = TO_LANGUAGE_CODE[lowered]
        else:
            raise ValueError(f"Unsupported language: {lowered}")

    if not multilingual:
        tokenizer_name, task, language = "gpt2", None, None
    else:
        tokenizer_name = "multilingual"
        task = task or "transcribe"
        language = language or "en"

    tokenizer = build_tokenizer(name=tokenizer_name)

    # the special token ids are addressed by their position in the id list
    all_special_ids: List[int] = tokenizer.all_special_ids
    sot: int = all_special_ids[1]
    translate: int = all_special_ids[-6]
    transcribe: int = all_special_ids[-5]

    language_codes = tuple(LANGUAGES.keys())
    sequence = [sot]
    if language is not None:
        # language tokens directly follow the sot token, in LANGUAGES order
        sequence.append(sot + 1 + language_codes.index(language))
    if task is not None:
        # any task other than "transcribe" selects the translate token
        sequence.append(transcribe if task == "transcribe" else translate)

    return Tokenizer(
        tokenizer=tokenizer, language=language, sot_sequence=tuple(sequence)
    )
LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer 22 | from .utils import ( 23 | exact_div, 24 | format_timestamp, 25 | get_writer, 26 | make_safe, 27 | optional_float, 28 | optional_int, 29 | str2bool, 30 | ) 31 | 32 | if TYPE_CHECKING: 33 | from .model import Whisper 34 | 35 | 36 | def transcribe( 37 | model: "Whisper", 38 | audio: Union[str, np.ndarray, torch.Tensor], 39 | *, 40 | verbose: Optional[bool] = None, 41 | temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0), 42 | compression_ratio_threshold: Optional[float] = 2.4, 43 | logprob_threshold: Optional[float] = -1.0, 44 | no_speech_threshold: Optional[float] = 0.6, 45 | condition_on_previous_text: bool = True, 46 | initial_prompt: Optional[str] = None, 47 | word_timestamps: bool = False, 48 | prepend_punctuations: str = "\"'“¿([{-", 49 | append_punctuations: str = "\"'.。,,!!??::”)]}、", 50 | **decode_options, 51 | ): 52 | """ 53 | Transcribe an audio file using Whisper 54 | 55 | Parameters 56 | ---------- 57 | model: Whisper 58 | The Whisper model instance 59 | 60 | audio: Union[str, np.ndarray, torch.Tensor] 61 | The path to the audio file to open, or the audio waveform 62 | 63 | verbose: bool 64 | Whether to display the text being decoded to the console. If True, displays all the details, 65 | If False, displays minimal details. If None, does not display anything 66 | 67 | temperature: Union[float, Tuple[float, ...]] 68 | Temperature for sampling. It can be a tuple of temperatures, which will be successively used 69 | upon failures according to either `compression_ratio_threshold` or `logprob_threshold`. 
70 | 71 | compression_ratio_threshold: float 72 | If the gzip compression ratio is above this value, treat as failed 73 | 74 | logprob_threshold: float 75 | If the average log probability over sampled tokens is below this value, treat as failed 76 | 77 | no_speech_threshold: float 78 | If the no_speech probability is higher than this value AND the average log probability 79 | over sampled tokens is below `logprob_threshold`, consider the segment as silent 80 | 81 | condition_on_previous_text: bool 82 | if True, the previous output of the model is provided as a prompt for the next window; 83 | disabling may make the text inconsistent across windows, but the model becomes less prone to 84 | getting stuck in a failure loop, such as repetition looping or timestamps going out of sync. 85 | 86 | word_timestamps: bool 87 | Extract word-level timestamps using the cross-attention pattern and dynamic time warping, 88 | and include the timestamps for each word in each segment. 89 | 90 | prepend_punctuations: str 91 | If word_timestamps is True, merge these punctuation symbols with the next word 92 | 93 | append_punctuations: str 94 | If word_timestamps is True, merge these punctuation symbols with the previous word 95 | 96 | initial_prompt: Optional[str] 97 | Optional text to provide as a prompt for the first window. This can be used to provide, or 98 | "prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns 99 | to make it more likely to predict those word correctly. 100 | 101 | decode_options: dict 102 | Keyword arguments to construct `DecodingOptions` instances 103 | 104 | Returns 105 | ------- 106 | A dictionary containing the resulting text ("text") and segment-level details ("segments"), and 107 | the spoken language ("language"), which is detected when `decode_options["language"]` is None. 
108 | """ 109 | dtype = torch.float16 if decode_options.get("fp16", True) else torch.float32 110 | if model.device == torch.device("cpu"): 111 | if torch.cuda.is_available(): 112 | warnings.warn("Performing inference on CPU when CUDA is available") 113 | if dtype == torch.float16: 114 | warnings.warn("FP16 is not supported on CPU; using FP32 instead") 115 | dtype = torch.float32 116 | 117 | if dtype == torch.float32: 118 | decode_options["fp16"] = False 119 | 120 | # Pad 30-seconds of silence to the input audio, for slicing 121 | mel = log_mel_spectrogram(audio, padding=N_SAMPLES) 122 | content_frames = mel.shape[-1] - N_FRAMES 123 | 124 | if decode_options.get("language", None) is None: 125 | if not model.is_multilingual: 126 | decode_options["language"] = "en" 127 | else: 128 | if verbose: 129 | print( 130 | "Detecting language using up to the first 30 seconds. Use `--language` to specify the language" 131 | ) 132 | mel_segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype) 133 | _, probs = model.detect_language(mel_segment) 134 | decode_options["language"] = max(probs, key=probs.get) 135 | if verbose is not None: 136 | print( 137 | f"Detected language: {LANGUAGES[decode_options['language']].title()}" 138 | ) 139 | 140 | # decode <|speakerturn|> tokens by default for tdrz (tinydiarize) models 141 | if model.is_tdrz and "with_speaker_turns" not in decode_options: 142 | decode_options["with_speaker_turns"] = True 143 | 144 | language: str = decode_options["language"] 145 | task: str = decode_options.get("task", "transcribe") 146 | tokenizer = get_tokenizer(model.is_multilingual, language=language, task=task) 147 | 148 | if word_timestamps and task == "translate": 149 | warnings.warn("Word-level timestamps on translations may not be reliable.") 150 | 151 | def decode_with_fallback(segment: torch.Tensor) -> DecodingResult: 152 | temperatures = ( 153 | [temperature] if isinstance(temperature, (int, float)) else temperature 154 | ) 155 | decode_result = None 
156 | 157 | for t in temperatures: 158 | kwargs = {**decode_options} 159 | if t > 0: 160 | # disable beam_size and patience when t > 0 161 | kwargs.pop("beam_size", None) 162 | kwargs.pop("patience", None) 163 | else: 164 | # disable best_of when t == 0 165 | kwargs.pop("best_of", None) 166 | 167 | options = DecodingOptions(**kwargs, temperature=t) 168 | decode_result = model.decode(segment, options) 169 | 170 | needs_fallback = False 171 | if ( 172 | compression_ratio_threshold is not None 173 | and decode_result.compression_ratio > compression_ratio_threshold 174 | ): 175 | needs_fallback = True # too repetitive 176 | if ( 177 | logprob_threshold is not None 178 | and decode_result.avg_logprob < logprob_threshold 179 | ): 180 | needs_fallback = True # average log probability is too low 181 | 182 | if not needs_fallback: 183 | break 184 | 185 | return decode_result 186 | 187 | seek = 0 188 | input_stride = exact_div( 189 | N_FRAMES, model.dims.n_audio_ctx 190 | ) # mel frames per output token: 2 191 | time_precision = ( 192 | input_stride * HOP_LENGTH / SAMPLE_RATE 193 | ) # time per output token: 0.02 (seconds) 194 | all_tokens = [] 195 | all_segments = [] 196 | prompt_reset_since = 0 197 | 198 | if initial_prompt is not None: 199 | initial_prompt_tokens = tokenizer.encode(" " + initial_prompt.strip()) 200 | all_tokens.extend(initial_prompt_tokens) 201 | else: 202 | initial_prompt_tokens = [] 203 | 204 | def new_segment( 205 | *, start: float, end: float, tokens: torch.Tensor, result: DecodingResult 206 | ): 207 | tokens = tokens.tolist() 208 | before_speaker_turn = tokens[-2] == tokenizer.speaker_turn 209 | text = tokenizer.decode([token for token in tokens if (token < tokenizer.eot)]) 210 | if before_speaker_turn: 211 | text = text + " [SPEAKER TURN]" 212 | return { 213 | "seek": seek, 214 | "start": start, 215 | "end": end, 216 | "text": text, 217 | "before_speaker_turn": before_speaker_turn, 218 | "tokens": tokens, 219 | "temperature": result.temperature, 220 
| "avg_logprob": result.avg_logprob, 221 | "compression_ratio": result.compression_ratio, 222 | "no_speech_prob": result.no_speech_prob, 223 | } 224 | 225 | # show the progress bar when verbose is False (if True, transcribed text will be printed) 226 | with tqdm.tqdm( 227 | total=content_frames, unit="frames", disable=verbose is not False 228 | ) as pbar: 229 | while seek < content_frames: 230 | time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE) 231 | mel_segment = mel[:, seek : seek + N_FRAMES] 232 | segment_size = min(N_FRAMES, content_frames - seek) 233 | segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE 234 | mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype) 235 | 236 | decode_options["prompt"] = all_tokens[prompt_reset_since:] 237 | result: DecodingResult = decode_with_fallback(mel_segment) 238 | tokens = torch.tensor(result.tokens) 239 | 240 | if no_speech_threshold is not None: 241 | # no voice activity check 242 | should_skip = result.no_speech_prob > no_speech_threshold 243 | if ( 244 | logprob_threshold is not None 245 | and result.avg_logprob > logprob_threshold 246 | ): 247 | # don't skip if the logprob is high enough, despite the no_speech_prob 248 | should_skip = False 249 | 250 | if should_skip: 251 | seek += segment_size # fast-forward to the next segment boundary 252 | continue 253 | 254 | previous_seek = seek 255 | current_segments = [] 256 | 257 | timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin) 258 | single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True] 259 | 260 | consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0] 261 | consecutive.add_(1) 262 | if len(consecutive) > 0: 263 | # if the output contains two consecutive timestamp tokens 264 | slices = consecutive.tolist() 265 | if single_timestamp_ending: 266 | slices.append(len(tokens)) 267 | 268 | last_slice = 0 269 | for current_slice in slices: 270 | sliced_tokens = 
tokens[last_slice:current_slice] 271 | start_timestamp_pos = ( 272 | sliced_tokens[0].item() - tokenizer.timestamp_begin 273 | ) 274 | end_timestamp_pos = ( 275 | sliced_tokens[-1].item() - tokenizer.timestamp_begin 276 | ) 277 | current_segments.append( 278 | new_segment( 279 | start=time_offset + start_timestamp_pos * time_precision, 280 | end=time_offset + end_timestamp_pos * time_precision, 281 | tokens=sliced_tokens, 282 | result=result, 283 | ) 284 | ) 285 | last_slice = current_slice 286 | 287 | if single_timestamp_ending: 288 | # single timestamp at the end means no speech after the last timestamp. 289 | seek += segment_size 290 | else: 291 | # otherwise, ignore the unfinished segment and seek to the last timestamp 292 | last_timestamp_pos = ( 293 | tokens[last_slice - 1].item() - tokenizer.timestamp_begin 294 | ) 295 | seek += last_timestamp_pos * input_stride 296 | else: 297 | duration = segment_duration 298 | timestamps = tokens[timestamp_tokens.nonzero().flatten()] 299 | if ( 300 | len(timestamps) > 0 301 | and timestamps[-1].item() != tokenizer.timestamp_begin 302 | ): 303 | # no consecutive timestamps but it has a timestamp; use the last one. 
304 | last_timestamp_pos = ( 305 | timestamps[-1].item() - tokenizer.timestamp_begin 306 | ) 307 | duration = last_timestamp_pos * time_precision 308 | 309 | current_segments.append( 310 | new_segment( 311 | start=time_offset, 312 | end=time_offset + duration, 313 | tokens=tokens, 314 | result=result, 315 | ) 316 | ) 317 | seek += segment_size 318 | 319 | if not condition_on_previous_text or result.temperature > 0.5: 320 | # do not feed the prompt tokens if a high temperature was used 321 | prompt_reset_since = len(all_tokens) 322 | 323 | if word_timestamps: 324 | add_word_timestamps( 325 | segments=current_segments, 326 | model=model, 327 | tokenizer=tokenizer, 328 | mel=mel_segment, 329 | num_frames=segment_size, 330 | prepend_punctuations=prepend_punctuations, 331 | append_punctuations=append_punctuations, 332 | ) 333 | word_end_timestamps = [ 334 | w["end"] for s in current_segments for w in s["words"] 335 | ] 336 | if not single_timestamp_ending and len(word_end_timestamps) > 0: 337 | seek_shift = round( 338 | (word_end_timestamps[-1] - time_offset) * FRAMES_PER_SECOND 339 | ) 340 | if seek_shift > 0: 341 | seek = previous_seek + seek_shift 342 | 343 | if verbose: 344 | for segment in current_segments: 345 | start, end, text = segment["start"], segment["end"], segment["text"] 346 | line = f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}" 347 | print(make_safe(line)) 348 | if segment["before_speaker_turn"]: 349 | print(" " * 25) 350 | # print("=" * 25) # debug decoded chunk boundaries 351 | 352 | # if a segment is instantaneous or does not contain text, clear it 353 | for i, segment in enumerate(current_segments): 354 | if segment["start"] == segment["end"] or segment["text"].strip() == "": 355 | segment["text"] = "" 356 | segment["tokens"] = [] 357 | segment["words"] = [] 358 | 359 | all_segments.extend( 360 | [ 361 | {"id": i, **segment} 362 | for i, segment in enumerate( 363 | current_segments, start=len(all_segments) 364 | ) 365 | ] 366 | ) 
367 | all_tokens.extend( 368 | [token for segment in current_segments for token in segment["tokens"]] 369 | ) 370 | 371 | # update progress bar 372 | pbar.update(min(content_frames, seek) - previous_seek) 373 | 374 | return dict( 375 | text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]), 376 | segments=all_segments, 377 | language=language, 378 | ) 379 | 380 | 381 | def cli(): 382 | from . import available_models 383 | 384 | # fmt: off 385 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 386 | parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe") 387 | parser.add_argument("--model", default="small.en-tdrz", choices=available_models(), help="name of the Whisper model to use") 388 | parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default") 389 | parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference") 390 | parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs") 391 | parser.add_argument("--output_format", "-f", type=str, default="all", choices=["txt", "vtt", "srt", "tsv", "json", "all"], help="format of the output file; if not specified, all available formats will be produced") 392 | parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages") 393 | 394 | parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')") 395 | parser.add_argument("--language", type=str, default=None, choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]), help="language spoken in the audio, specify None to perform language detection") 396 | 397 | 
parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling") 398 | parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature") 399 | parser.add_argument("--beam_size", type=optional_int, default=5, help="number of beams in beam search, only applicable when temperature is zero") 400 | parser.add_argument("--patience", type=float, default=None, help="optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search") 401 | parser.add_argument("--length_penalty", type=float, default=None, help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple length normalization by default") 402 | 403 | parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations") 404 | parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.") 405 | parser.add_argument("--condition_on_previous_text", type=str2bool, default=True, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop") 406 | parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default") 407 | 408 | parser.add_argument("--temperature_increment_on_fallback", type=optional_float, default=0.2, help="temperature to increase when falling back when the decoding fails to meet either of the thresholds below") 409 | parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, help="if the gzip compression ratio is higher than this value, treat the 
decoding as failed") 410 | parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed") 411 | parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence") 412 | parser.add_argument("--word_timestamps", type=str2bool, default=False, help="(experimental) extract word-level timestamps and refine the results based on them") 413 | parser.add_argument("--prepend_punctuations", type=str, default="\"\'“¿([{-", help="if word_timestamps is True, merge these punctuation symbols with the next word") 414 | parser.add_argument("--append_punctuations", type=str, default="\"\'.。,,!!??::”)]}、", help="if word_timestamps is True, merge these punctuation symbols with the previous word") 415 | parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS") 416 | # fmt: on 417 | 418 | args = parser.parse_args().__dict__ 419 | model_name: str = args.pop("model") 420 | model_dir: str = args.pop("model_dir") 421 | output_dir: str = args.pop("output_dir") 422 | output_format: str = args.pop("output_format") 423 | device: str = args.pop("device") 424 | os.makedirs(output_dir, exist_ok=True) 425 | 426 | if model_name.endswith(".en") and args["language"] not in {"en", "English"}: 427 | if args["language"] is not None: 428 | warnings.warn( 429 | f"{model_name} is an English-only model but receipted '{args['language']}'; using English instead." 
430 | ) 431 | args["language"] = "en" 432 | 433 | temperature = args.pop("temperature") 434 | if (increment := args.pop("temperature_increment_on_fallback")) is not None: 435 | temperature = tuple(np.arange(temperature, 1.0 + 1e-6, increment)) 436 | else: 437 | temperature = [temperature] 438 | 439 | if (threads := args.pop("threads")) > 0: 440 | torch.set_num_threads(threads) 441 | 442 | from . import load_model 443 | 444 | model = load_model(model_name, device=device, download_root=model_dir) 445 | 446 | writer = get_writer(output_format, output_dir) 447 | for audio_path in args.pop("audio"): 448 | result = transcribe(model, audio_path, temperature=temperature, **args) 449 | writer(result, audio_path) 450 | 451 | 452 | if __name__ == "__main__": 453 | cli() 454 | -------------------------------------------------------------------------------- /whisper/triton_ops.py: -------------------------------------------------------------------------------- 1 | from functools import lru_cache 2 | 3 | import numpy as np 4 | import torch 5 | 6 | try: 7 | import triton 8 | import triton.language as tl 9 | except ImportError: 10 | raise RuntimeError("triton import failed; try `pip install --pre triton`") 11 | 12 | 13 | @triton.jit 14 | def dtw_kernel( 15 | cost, trace, x, x_stride, cost_stride, trace_stride, N, M, BLOCK_SIZE: tl.constexpr 16 | ): 17 | offsets = tl.arange(0, BLOCK_SIZE) 18 | mask = offsets < M 19 | 20 | for k in range(1, N + M + 1): # k = i + j 21 | tl.debug_barrier() 22 | 23 | p0 = cost + (k - 1) * cost_stride 24 | p1 = cost + k * cost_stride 25 | p2 = cost + k * cost_stride + 1 26 | 27 | c0 = tl.load(p0 + offsets, mask=mask) 28 | c1 = tl.load(p1 + offsets, mask=mask) 29 | c2 = tl.load(p2 + offsets, mask=mask) 30 | 31 | x_row = tl.load(x + (k - 1) * x_stride + offsets, mask=mask, other=0) 32 | cost_row = x_row + tl.minimum(tl.minimum(c0, c1), c2) 33 | 34 | cost_ptr = cost + (k + 1) * cost_stride + 1 35 | tl.store(cost_ptr + offsets, cost_row, mask=mask) 36 | 37 
| trace_ptr = trace + (k + 1) * trace_stride + 1 38 | tl.store(trace_ptr + offsets, 2, mask=mask & (c2 <= c0) & (c2 <= c1)) 39 | tl.store(trace_ptr + offsets, 1, mask=mask & (c1 <= c0) & (c1 <= c2)) 40 | tl.store(trace_ptr + offsets, 0, mask=mask & (c0 <= c1) & (c0 <= c2)) 41 | 42 | 43 | @lru_cache(maxsize=None) 44 | def median_kernel(filter_width: int): 45 | @triton.jit 46 | def kernel( 47 | y, x, x_stride, y_stride, BLOCK_SIZE: tl.constexpr 48 | ): # x.shape[-1] == filter_width 49 | row_idx = tl.program_id(0) 50 | offsets = tl.arange(0, BLOCK_SIZE) 51 | mask = offsets < y_stride 52 | 53 | x_ptr = x + row_idx * x_stride # noqa: F841 54 | y_ptr = y + row_idx * y_stride 55 | 56 | LOAD_ALL_ROWS_HERE # noqa: F821 57 | 58 | BUBBLESORT_HERE # noqa: F821 59 | 60 | tl.store(y_ptr + offsets, MIDDLE_ROW_HERE, mask=mask) # noqa: F821 61 | 62 | kernel = triton.JITFunction(kernel.fn) 63 | kernel.src = kernel.src.replace( 64 | " LOAD_ALL_ROWS_HERE", 65 | "\n".join( 66 | [ 67 | f" row{i} = tl.load(x_ptr + offsets + {i}, mask=mask)" 68 | for i in range(filter_width) 69 | ] 70 | ), 71 | ) 72 | kernel.src = kernel.src.replace( 73 | " BUBBLESORT_HERE", 74 | "\n\n".join( 75 | [ 76 | "\n\n".join( 77 | [ 78 | "\n".join( 79 | [ 80 | f" smaller = tl.where(row{j} < row{j + 1}, row{j}, row{j + 1})", 81 | f" larger = tl.where(row{j} > row{j + 1}, row{j}, row{j + 1})", 82 | f" row{j} = smaller", 83 | f" row{j + 1} = larger", 84 | ] 85 | ) 86 | for j in range(filter_width - i - 1) 87 | ] 88 | ) 89 | for i in range(filter_width // 2 + 1) 90 | ] 91 | ), 92 | ) 93 | kernel.src = kernel.src.replace("MIDDLE_ROW_HERE", f"row{filter_width // 2}") 94 | 95 | return kernel 96 | 97 | 98 | def median_filter_cuda(x: torch.Tensor, filter_width: int): 99 | """Apply a median filter of given width along the last dimension of x""" 100 | slices = x.contiguous().unfold(-1, filter_width, 1) 101 | grid = np.prod(slices.shape[:-2]) 102 | 103 | kernel = median_kernel(filter_width) 104 | y = 
torch.empty_like(slices[..., 0]) 105 | 106 | BLOCK_SIZE = 1 << (y.stride(-2) - 1).bit_length() 107 | kernel[(grid,)](y, x, x.stride(-2), y.stride(-2), BLOCK_SIZE=BLOCK_SIZE) 108 | 109 | return y 110 | -------------------------------------------------------------------------------- /whisper/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import sys 4 | import zlib 5 | from typing import Callable, TextIO 6 | 7 | system_encoding = sys.getdefaultencoding() 8 | 9 | if system_encoding != "utf-8": 10 | 11 | def make_safe(string): 12 | # replaces any character not representable using the system default encoding with an '?', 13 | # avoiding UnicodeEncodeError (https://github.com/openai/whisper/discussions/729). 14 | return string.encode(system_encoding, errors="replace").decode(system_encoding) 15 | 16 | else: 17 | 18 | def make_safe(string): 19 | # utf-8 can encode any Unicode code point, so no need to do the round-trip encoding 20 | return string 21 | 22 | 23 | def exact_div(x, y): 24 | assert x % y == 0 25 | return x // y 26 | 27 | 28 | def str2bool(string): 29 | str2val = {"True": True, "False": False} 30 | if string in str2val: 31 | return str2val[string] 32 | else: 33 | raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}") 34 | 35 | 36 | def optional_int(string): 37 | return None if string == "None" else int(string) 38 | 39 | 40 | def optional_float(string): 41 | return None if string == "None" else float(string) 42 | 43 | 44 | def compression_ratio(text) -> float: 45 | text_bytes = text.encode("utf-8") 46 | return len(text_bytes) / len(zlib.compress(text_bytes)) 47 | 48 | 49 | def format_timestamp( 50 | seconds: float, always_include_hours: bool = False, decimal_marker: str = "." 
51 | ): 52 | assert seconds >= 0, "non-negative timestamp expected" 53 | milliseconds = round(seconds * 1000.0) 54 | 55 | hours = milliseconds // 3_600_000 56 | milliseconds -= hours * 3_600_000 57 | 58 | minutes = milliseconds // 60_000 59 | milliseconds -= minutes * 60_000 60 | 61 | seconds = milliseconds // 1_000 62 | milliseconds -= seconds * 1_000 63 | 64 | hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else "" 65 | return ( 66 | f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}" 67 | ) 68 | 69 | 70 | class ResultWriter: 71 | extension: str 72 | 73 | def __init__(self, output_dir: str): 74 | self.output_dir = output_dir 75 | 76 | def __call__(self, result: dict, audio_path: str): 77 | audio_basename = os.path.basename(audio_path) 78 | audio_basename = os.path.splitext(audio_basename)[0] 79 | output_path = os.path.join( 80 | self.output_dir, audio_basename + "." + self.extension 81 | ) 82 | 83 | with open(output_path, "w", encoding="utf-8") as f: 84 | self.write_result(result, file=f) 85 | 86 | def write_result(self, result: dict, file: TextIO): 87 | raise NotImplementedError 88 | 89 | 90 | class WriteTXT(ResultWriter): 91 | extension: str = "txt" 92 | 93 | def write_result(self, result: dict, file: TextIO): 94 | for segment in result["segments"]: 95 | print(segment["text"].strip(), file=file, flush=True) 96 | 97 | 98 | class SubtitlesWriter(ResultWriter): 99 | always_include_hours: bool 100 | decimal_marker: str 101 | 102 | def iterate_result(self, result: dict): 103 | for segment in result["segments"]: 104 | segment_start = self.format_timestamp(segment["start"]) 105 | segment_end = self.format_timestamp(segment["end"]) 106 | segment_text = segment["text"].strip().replace("-->", "->") 107 | 108 | if word_timings := segment.get("words", None): 109 | all_words = [timing["word"] for timing in word_timings] 110 | all_words[0] = all_words[0].strip() # remove the leading space, if any 111 | last = segment_start 112 | 
for i, this_word in enumerate(word_timings): 113 | start = self.format_timestamp(this_word["start"]) 114 | end = self.format_timestamp(this_word["end"]) 115 | if last != start: 116 | yield last, start, segment_text 117 | 118 | yield start, end, "".join( 119 | [ 120 | f"{word}" if j == i else word 121 | for j, word in enumerate(all_words) 122 | ] 123 | ) 124 | last = end 125 | 126 | if last != segment_end: 127 | yield last, segment_end, segment_text 128 | else: 129 | yield segment_start, segment_end, segment_text 130 | 131 | def format_timestamp(self, seconds: float): 132 | return format_timestamp( 133 | seconds=seconds, 134 | always_include_hours=self.always_include_hours, 135 | decimal_marker=self.decimal_marker, 136 | ) 137 | 138 | 139 | class WriteVTT(SubtitlesWriter): 140 | extension: str = "vtt" 141 | always_include_hours: bool = False 142 | decimal_marker: str = "." 143 | 144 | def write_result(self, result: dict, file: TextIO): 145 | print("WEBVTT\n", file=file) 146 | for start, end, text in self.iterate_result(result): 147 | print(f"{start} --> {end}\n{text}\n", file=file, flush=True) 148 | 149 | 150 | class WriteSRT(SubtitlesWriter): 151 | extension: str = "srt" 152 | always_include_hours: bool = True 153 | decimal_marker: str = "," 154 | 155 | def write_result(self, result: dict, file: TextIO): 156 | for i, (start, end, text) in enumerate(self.iterate_result(result), start=1): 157 | print(f"{i}\n{start} --> {end}\n{text}\n", file=file, flush=True) 158 | 159 | 160 | class WriteTSV(ResultWriter): 161 | """ 162 | Write a transcript to a file in TSV (tab-separated values) format containing lines like: 163 | <start time in integer milliseconds>\t<end time in integer milliseconds>\t<transcript text> 164 | 165 | Using integer milliseconds as start and end times means there's no chance of interference from 166 | an environment setting a language encoding that causes the decimal in a floating point number 167 | to appear as a comma; also is faster and more efficient to parse & store, e.g., in C++.
168 | """ 169 | 170 | extension: str = "tsv" 171 | 172 | def write_result(self, result: dict, file: TextIO): 173 | print("start", "end", "text", sep="\t", file=file) 174 | for segment in result["segments"]: 175 | print(round(1000 * segment["start"]), file=file, end="\t") 176 | print(round(1000 * segment["end"]), file=file, end="\t") 177 | print(segment["text"].strip().replace("\t", " "), file=file, flush=True) 178 | 179 | 180 | class WriteJSON(ResultWriter): 181 | extension: str = "json" 182 | 183 | def write_result(self, result: dict, file: TextIO): 184 | json.dump(result, file) 185 | 186 | 187 | def get_writer(output_format: str, output_dir: str) -> Callable[[dict, TextIO], None]: 188 | writers = { 189 | "txt": WriteTXT, 190 | "vtt": WriteVTT, 191 | "srt": WriteSRT, 192 | "tsv": WriteTSV, 193 | "json": WriteJSON, 194 | } 195 | 196 | if output_format == "all": 197 | all_writers = [writer(output_dir) for writer in writers.values()] 198 | 199 | def write_all(result: dict, file: TextIO): 200 | for writer in all_writers: 201 | writer(result, file) 202 | 203 | return write_all 204 | 205 | return writers[output_format](output_dir) 206 | -------------------------------------------------------------------------------- /whisper/version.py: -------------------------------------------------------------------------------- 1 | __version__ = "20230308" 2 | --------------------------------------------------------------------------------