├── .dockerignore
├── .gitignore
├── .vscode
│   └── launch.json
├── Dockerfile.gpu
├── Dockerfile.tensorrt
├── LICENSE
├── README.md
├── test.py
├── test.sh
├── test_faster_whisper.py
├── test_faster_whisper.sh
└── whisper
    ├── __init__.py
    ├── __main__.py
    ├── assets
    │   ├── gpt2
    │   │   ├── merges.txt
    │   │   ├── special_tokens_map.json
    │   │   ├── tokenizer_config.json
    │   │   └── vocab.json
    │   ├── mel_filters.npz
    │   └── multilingual
    │       ├── added_tokens.json
    │       ├── merges.txt
    │       ├── special_tokens_map.json
    │       ├── tokenizer_config.json
    │       └── vocab.json
    ├── audio.py
    ├── decoding.py
    ├── model.py
    ├── normalizers
    │   ├── __init__.py
    │   ├── basic.py
    │   ├── english.json
    │   └── english.py
    ├── tokenizer.py
    ├── transcribe.py
    └── utils.py

--------------------------------------------------------------------------------
/.dockerignore:
--------------------------------------------------------------------------------
*
!nv-tensorrt-local-repo-ubuntu2204-8.5.3-cuda-11.8_1.0-1_amd64.deb

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
*.onnx
*.mp3
*.mp4
*.npy
*.srt
*.txt
*.vtt
*.engine
*.profile
__pycache__
packages
nv-tensorrt-local-repo-ubuntu2204-8.5.3-cuda-11.8_1.0-1_amd64.deb

--------------------------------------------------------------------------------
/.vscode/launch.json:
--------------------------------------------------------------------------------
{
    "version": "0.2.0",
    "configurations": [
        {
            "name": "Transcribe",
            "type": "python",
            "request": "launch",
            "program": "${file}",
            "console": "integratedTerminal",
            "justMyCode": true,
            "args": [
                "carmack.mp3",
                "--model", "tiny.en",
                "--disable_cupy",
                "--beam_size", "2",
            ]
        }
    ]
}

--------------------------------------------------------------------------------
/Dockerfile.gpu:
--------------------------------------------------------------------------------
FROM pinto0309/ubuntu22.04-cuda11.8:latest
ENV DEBIAN_FRONTEND=noninteractive
ARG USERNAME=user
ARG OS=ubuntu2204
ARG ONNXVER=1.13.1
ARG ONNXRUNTIMEVER=1.13.1
ARG CUDAVER=11.8
ARG CUDNNVER=8.9
ARG TENSORRTVER=8.5.3

SHELL ["/bin/bash", "-c"]

ENV CUDA_HOME=/usr/local/cuda
ENV PATH=${PATH}:${CUDA_HOME}/bin
ENV LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:${CUDA_HOME}/lib64

RUN apt-get update \
    && apt-get upgrade -y \
    && apt-get install -y --no-install-recommends \
        curl \
        wget \
        gcc \
        git \
        make \
        sudo \
        build-essential \
        ca-certificates \
        pciutils \
        software-properties-common \
        python3-all-dev \
        python-is-python3 \
        python3-pip \
        ffmpeg \
    && pip install pip -U \
    && pip install requests==2.31.0 \
    && pip install tqdm==4.65.0 \
    && pip install more-itertools==8.10.0 \
    && pip install ffmpeg-python==0.2.0 \
    && pip install transformers==4.29.2 \
    && pip install onnx==${ONNXVER} \
    && pip install onnxsim==0.4.17 \
    && pip install nvidia-pyindex \
    && pip install onnx-graphsurgeon \
    && pip install protobuf==3.20.3 \
    && pip install h5py==3.7.0 \
    && pip install pynvml==11.5.0 \
    && wget https://s3.us-central-1.wasabisys.com/tensorrt-installers/${OS}-tensorrt${TENSORRTVER}-cuda${CUDAVER}-cudnn${CUDNNVER}-onnxruntime${ONNXRUNTIMEVER}/onnxruntime_gpu-${ONNXRUNTIMEVER}-cp310-cp310-linux_x86_64.whl \
    && pip uninstall -y onnxruntime onnxruntime-gpu \
    && pip install onnxruntime_gpu-${ONNXRUNTIMEVER}-cp310-cp310-linux_x86_64.whl \
    && rm onnxruntime_gpu-${ONNXRUNTIMEVER}-cp310-cp310-linux_x86_64.whl \
    && apt clean \
    && rm -rf /var/lib/apt/lists/* \
    && rm /etc/apt/apt.conf.d/docker-clean

ENV USERNAME=user
RUN echo "root:root" | chpasswd \
    && adduser --disabled-password --gecos "" "${USERNAME}" \
    && echo "${USERNAME}:${USERNAME}" | chpasswd \
    && echo "%${USERNAME} ALL=(ALL) NOPASSWD: ALL" >> /etc/sudoers.d/${USERNAME} \
    && chmod 0440 /etc/sudoers.d/${USERNAME}
USER ${USERNAME}

ARG CUPYDIR=/app
WORKDIR ${CUPYDIR}
RUN sudo chown ${USERNAME}:${USERNAME} ${CUPYDIR}
RUN git clone --recursive -b v12.0.0 https://github.com/cupy/cupy.git \
    && pushd cupy \
    && pip install . \
    && popd

ARG WKDIR=/workdir
WORKDIR ${WKDIR}
RUN sudo chown ${USERNAME}:${USERNAME} ${WKDIR}

RUN echo 'export PATH=${PATH}:${HOME}/.local/bin' >> ~/.bashrc \
    && echo "export USER=`whoami`" >> ~/.bashrc \
    && echo 'export PATH=/usr/local/cuda/bin:$PATH' >> ~/.bashrc \
    && echo 'export LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH' >> ~/.bashrc

--------------------------------------------------------------------------------
/Dockerfile.tensorrt:
--------------------------------------------------------------------------------
FROM pinto0309/whisper-onnx-cuda:latest
ENV DEBIAN_FRONTEND=noninteractive
ARG USERNAME=user
ARG OS=ubuntu2204
ARG CUDAVER=11.8
ARG CUDNNVER=8.9
ARG TENSORRTVER=8.5.3
ARG PYCUDAVER=2022.2

COPY nv-tensorrt-local-repo-${OS}-${TENSORRTVER}-cuda-${CUDAVER}_1.0-1_amd64.deb .

SHELL ["/bin/bash", "-c"]

ENV CUDA_HOME=/usr/local/cuda
ENV PATH=${PATH}:${CUDA_HOME}/bin
ENV LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:${CUDA_HOME}/lib64

# Install TensorRT
# https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/secure/8.5.3/local_repos/nv-tensorrt-local-repo-ubuntu2204-8.5.3-cuda-11.8_1.0-1_amd64.deb
RUN sudo dpkg -i nv-tensorrt-local-repo-${OS}-${TENSORRTVER}-cuda-${CUDAVER}_1.0-1_amd64.deb \
    && sudo cp /var/nv-tensorrt-local-repo-${OS}-${TENSORRTVER}-cuda-${CUDAVER}/*-keyring.gpg /usr/share/keyrings/ \
    && sudo apt-get update \
    && sudo apt-get install -y --no-install-recommends \
        tensorrt=${TENSORRTVER}.1-1+cuda${CUDAVER} \
        tensorrt-dev=${TENSORRTVER}.1-1+cuda${CUDAVER} \
        tensorrt-libs=${TENSORRTVER}.1-1+cuda${CUDAVER} \
        uff-converter-tf=${TENSORRTVER}-1+cuda${CUDAVER} \
        python3-libnvinfer-dev=${TENSORRTVER}-1+cuda${CUDAVER} \
        python3-libnvinfer=${TENSORRTVER}-1+cuda${CUDAVER} \
        libnvparsers-dev=${TENSORRTVER}-1+cuda${CUDAVER} \
        libnvparsers8=${TENSORRTVER}-1+cuda${CUDAVER} \
        libnvonnxparsers-dev=${TENSORRTVER}-1+cuda${CUDAVER} \
        libnvonnxparsers8=${TENSORRTVER}-1+cuda${CUDAVER} \
        libnvinfer-samples=${TENSORRTVER}-1+cuda${CUDAVER} \
        libnvinfer-plugin-dev=${TENSORRTVER}-1+cuda${CUDAVER} \
        libnvinfer-plugin8=${TENSORRTVER}-1+cuda${CUDAVER} \
        libnvinfer-dev=${TENSORRTVER}-1+cuda${CUDAVER} \
        libnvinfer-bin=${TENSORRTVER}-1+cuda${CUDAVER} \
        libnvinfer8=${TENSORRTVER}-1+cuda${CUDAVER} \
        graphsurgeon-tf=${TENSORRTVER}-1+cuda${CUDAVER} \
        onnx-graphsurgeon=${TENSORRTVER}-1+cuda${CUDAVER} \
        libprotobuf-dev \
        protobuf-compiler \
        cmake \
    && rm nv-tensorrt-local-repo-${OS}-${TENSORRTVER}-cuda-${CUDAVER}_1.0-1_amd64.deb \
    && cd /usr/src/tensorrt/samples/trtexec \
    && sudo make \
    && sudo apt clean \
    && sudo rm -rf /var/lib/apt/lists/*

# Install onnx-tensorrt
RUN git clone -b release/8.5-GA --recursive https://github.com/onnx/onnx-tensorrt /home/${USERNAME}/onnx-tensorrt \
    && pushd /home/${USERNAME}/onnx-tensorrt \
    && mkdir build \
    && pushd build \
    && cmake .. -DTENSORRT_ROOT=/usr/src/tensorrt \
    && make -j$(nproc) \
    && sudo make install \
    && popd \
    && popd \
    && pip install pycuda==${PYCUDAVER} \
    && echo "pushd /home/${USERNAME}/onnx-tensorrt > /dev/null" >> ~/.bashrc \
    # At docker build time, setup.py fails because NVIDIA's physical GPU device cannot be detected.
    # Therefore, a workaround is applied so that setup.py runs on first access instead.
    && echo 'python setup.py install --user 1>/dev/null 2>/dev/null' >> ~/.bashrc \
    && echo 'popd > /dev/null' >> ~/.bashrc \
    && echo 'export CUDA_MODULE_LOADING=LAZY' >> ~/.bashrc \
    && echo 'export PATH=${PATH}:/usr/src/tensorrt/bin:${HOME}/onnx-tensorrt/build' >> ~/.bashrc

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2023 Katsuya Hyodo

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# whisper-onnx-tensorrt
ONNX and TensorRT implementation of Whisper.

This repository reimplements Whisper with ONNX and TensorRT, using [zhuzilin/whisper-openvino](https://github.com/zhuzilin/whisper-openvino) as a reference.

It runs with onnxruntime alone, with the CUDA and TensorRT Execution Providers enabled; there is no need to install PyTorch or TensorFlow. All backend logic that previously used PyTorch was rewritten from scratch as a NumPy/CuPy implementation.

Click here for the CPU version: https://github.com/PINTO0309/whisper-onnx-cpu

## 1. Environment
Although it can run directly on the host PC, I strongly recommend using Docker to avoid breaking your environment.

1. Docker
2. NVIDIA GPU (VRAM 16 GB or more recommended)
3. onnx 1.13.1
4. onnxruntime-gpu 1.13.1 (custom build with the TensorRT Execution Provider enabled)
5. CUDA 11.8
6. cuDNN 8.9
7. TensorRT 8.5.3
8. onnx-tensorrt 8.5-GA
9. cupy v12.0.0
10. etc. (see Dockerfile.xxx)

## 2. Converted Models
https://github.com/PINTO0309/PINTO_model_zoo/tree/main/381_Whisper

## 3. Docker run
```bash
git clone https://github.com/PINTO0309/whisper-onnx-tensorrt.git && cd whisper-onnx-tensorrt
```
### 3-1. CUDA ver
```bash
docker run --rm -it --gpus all -v `pwd`:/workdir pinto0309/whisper-onnx-cuda
```
### 3-2. TensorRT ver
```bash
docker run --rm -it --gpus all -v `pwd`:/workdir pinto0309/whisper-onnx-tensorrt
```

## 4. Docker build
If you do not need to build the Docker image yourself, skip this step.
### 4-1. CUDA ver
```bash
docker build -t whisper-onnx -f Dockerfile.gpu .
```
### 4-2. TensorRT ver
```bash
docker build -t whisper-onnx -f Dockerfile.tensorrt .
```
### 4-3. docker run
```bash
docker run --rm -it --gpus all -v `pwd`:/workdir whisper-onnx
```

## 5. Transcribe
- `--model` option
```
tiny.en
tiny
base.en
base
small.en
small
medium.en
medium
large-v1
large-v2
```
- command

The ONNX file is automatically downloaded the first time the sample is run. Note that the `Decoder` runs on the CUDA Execution Provider, not TensorRT, because all of its input tensor shapes must remain undefined. When running the TensorRT version, expect a 5 to 10 minute wait during the first inference while the ONNX model is compiled into a TensorRT Engine. If `--language` is not specified, the tokenizer will auto-detect the language. A sketch of the corresponding onnxruntime provider configuration follows the command below.
```bash
python whisper/transcribe.py xxxx.mp4 --model small --beam_size 3
```
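For reference, the provider split described above can be expressed directly with the onnxruntime API. The following is a minimal sketch only, not the exact wiring inside `whisper/model.py`, and the model file names and cache directory are placeholders:

```python
import onnxruntime as ort

# Encoder: TensorRT EP first (the engine cache avoids recompiling on later runs),
# with the CUDA EP as a fallback.
encoder_sess = ort.InferenceSession(
    "tiny_encoder_11_float16.onnx",  # placeholder file name
    providers=[
        ("TensorrtExecutionProvider", {
            "trt_engine_cache_enable": True,  # reuse the compiled engine across runs
            "trt_engine_cache_path": ".",
            "trt_fp16_enable": True,
        }),
        "CUDAExecutionProvider",
    ],
)

# Decoder: CUDA EP only, because its input shapes must stay dynamic.
decoder_sess = ort.InferenceSession(
    "tiny_decoder_11_float16.onnx",  # placeholder file name
    providers=["CUDAExecutionProvider"],
)
```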
- results
```
Detecting language using up to the first 30 seconds. Use `--language` to specify the language
Detected language: Japanese
[00:00.000 --> 00:07.200] ストレオシンの推定モデルの最適化 としまして 後半のパート2は 実際
[00:07.200 --> 00:11.600] のデモを交えまして 普段私がどのように モデルを最適化して 様々な
[00:11.600 --> 00:15.600] フレームワークの環境でプロイしてる かというのを実際に操作をこの
[00:15.600 --> 00:18.280] 画面上で見ていただきながら ご理解いただけるように努めたい
[00:18.280 --> 00:21.600] と思います それでは早速ですが こちらの
[00:21.600 --> 00:26.320] GitHubの方に本日の公演内容について は すべてチュートリアルをまとめて
[00:26.320 --> 00:31.680] コミットしております 2021.0.20.28 インテルティブラーニング
[00:31.680 --> 00:35.200] でヒットネットデモという ちょっと長い名前なんですけれども 現状
[00:35.200 --> 00:39.120] はプライベートになってますが この公演のタイミングでパブリック
[00:39.120 --> 00:43.440] の方に変更したいと思っております 基本的にはこちらの上から順前
[00:43.440 --> 00:48.000] ですね チュートリアルを謎って いくという形になります
[00:48.000 --> 00:52.640] まず本日対象にするモデルの内容 なんですけれども Google Research
```
- parameters
```
usage: transcribe.py
       [-h]
       [--model {tiny.en,tiny,base.en,base,small.en,small,medium.en,medium,large-v1,large-v2}]
       [--output_dir OUTPUT_DIR]
       [--verbose VERBOSE]
       [--disable_cupy]
       [--task {transcribe,translate}]
       [--language {af, am, ...}]
       [--temperature TEMPERATURE]
       [--best_of BEST_OF]
       [--beam_size BEAM_SIZE]
       [--patience PATIENCE]
       [--length_penalty LENGTH_PENALTY]
       [--suppress_tokens SUPPRESS_TOKENS]
       [--initial_prompt INITIAL_PROMPT]
       [--condition_on_previous_text CONDITION_ON_PREVIOUS_TEXT]
       [--temperature_increment_on_fallback TEMPERATURE_INCREMENT_ON_FALLBACK]
       [--compression_ratio_threshold COMPRESSION_RATIO_THRESHOLD]
       [--logprob_threshold LOGPROB_THRESHOLD]
       [--no_speech_threshold NO_SPEECH_THRESHOLD]
       audio [audio ...]

positional arguments:
  audio
    audio file(s) to transcribe

optional arguments:
  -h, --help
    show this help message and exit
  --model {tiny.en,tiny,base.en,base,small.en,small,medium.en,medium,large-v1,large-v2}
    name of the Whisper model to use
    (default: small)
  --output_dir OUTPUT_DIR, -o OUTPUT_DIR
    directory to save the outputs
    (default: .)
  --verbose VERBOSE
    whether to print out the progress and debug messages
    (default: True)
  --disable_cupy
    When Out of Memory occurs due to insufficient GPU RAM, this option suppresses GPU
    RAM consumption.
  --task {transcribe,translate}
    whether to perform X->X speech recognition ('transcribe') or
    X->English translation ('translate')
    (default: transcribe)
  --language {af, am, ...}
    language spoken in the audio, specify None to perform language detection
    (default: None)
  --temperature TEMPERATURE
    temperature to use for sampling
    (default: 0)
  --best_of BEST_OF
    number of candidates when sampling with non-zero temperature
    (default: 5)
  --beam_size BEAM_SIZE
    number of beams in beam search, only applicable when temperature is zero
    (default: 5)
  --patience PATIENCE
    optional patience value to use in beam decoding,
    as in https://arxiv.org/abs/2204.05424,
    the default (1.0) is equivalent to conventional beam search
    (default: None)
  --length_penalty LENGTH_PENALTY
    optional token length penalty coefficient (alpha) as in
    https://arxiv.org/abs/1609.08144, uses simple length normalization by default
    (default: None)
  --suppress_tokens SUPPRESS_TOKENS
    comma-separated list of token ids to suppress during sampling;
    '-1' will suppress most special characters except common punctuations
    (default: -1)
  --initial_prompt INITIAL_PROMPT
    optional text to provide as a prompt for the first window.
    (default: None)
  --condition_on_previous_text CONDITION_ON_PREVIOUS_TEXT
    if True, provide the previous output of the model as a prompt for the next window;
    disabling may make the text inconsistent across windows, but the model becomes
    less prone to getting stuck in a failure loop
    (default: True)
  --temperature_increment_on_fallback TEMPERATURE_INCREMENT_ON_FALLBACK
    temperature to increase when falling back when the decoding fails to meet either of
    the thresholds below
    (default: 0.2)
  --compression_ratio_threshold COMPRESSION_RATIO_THRESHOLD
    if the gzip compression ratio is higher than this value, treat the decoding as failed
    (default: 2.4)
  --logprob_threshold LOGPROB_THRESHOLD
    if the average log probability is lower than this value, treat the decoding as failed
    (default: -1.0)
  --no_speech_threshold NO_SPEECH_THRESHOLD
    if the probability of the <|nospeech|> token is higher than this value AND
    the decoding has failed due to `logprob_threshold`, consider the segment as silence
    (default: 0.6)
```
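As a concrete example of combining these options, the following invocation translates Japanese speech to English with the `medium` model and writes the outputs to a separate directory (the file name and option values are illustrative):
```bash
python whisper/transcribe.py xxxx.mp4 \
  --model medium \
  --task translate \
  --language ja \
  --beam_size 3 \
  --output_dir ./results
```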
## 6. Languages
https://github.com/PINTO0309/whisper-onnx-tensorrt/blob/main/whisper/tokenizer.py
```
LANGUAGES = {
    "en": "english",
    "zh": "chinese",
    "de": "german",
    "es": "spanish",
    "ru": "russian",
    "ko": "korean",
    "fr": "french",
    "ja": "japanese",
    "pt": "portuguese",
    "tr": "turkish",
    "pl": "polish",
    "ca": "catalan",
    "nl": "dutch",
    "ar": "arabic",
    "sv": "swedish",
    "it": "italian",
    "id": "indonesian",
    "hi": "hindi",
    "fi": "finnish",
    "vi": "vietnamese",
    "iw": "hebrew",
    "uk": "ukrainian",
    "el": "greek",
    "ms": "malay",
    "cs": "czech",
    "ro": "romanian",
    "da": "danish",
    "hu": "hungarian",
    "ta": "tamil",
    "no": "norwegian",
    "th": "thai",
    "ur": "urdu",
    "hr": "croatian",
    "bg": "bulgarian",
    "lt": "lithuanian",
    "la": "latin",
    "mi": "maori",
    "ml": "malayalam",
    "cy": "welsh",
    "sk": "slovak",
    "te": "telugu",
    "fa": "persian",
    "lv": "latvian",
    "bn": "bengali",
    "sr": "serbian",
    "az": "azerbaijani",
    "sl": "slovenian",
    "kn": "kannada",
    "et": "estonian",
    "mk": "macedonian",
    "br": "breton",
    "eu": "basque",
    "is": "icelandic",
    "hy": "armenian",
    "ne": "nepali",
    "mn": "mongolian",
    "bs": "bosnian",
    "kk": "kazakh",
    "sq": "albanian",
    "sw": "swahili",
    "gl": "galician",
    "mr": "marathi",
    "pa": "punjabi",
    "si": "sinhala",
    "km": "khmer",
    "sn": "shona",
    "yo": "yoruba",
    "so": "somali",
    "af": "afrikaans",
    "oc": "occitan",
    "ka": "georgian",
    "be": "belarusian",
    "tg": "tajik",
    "sd": "sindhi",
    "gu": "gujarati",
    "am": "amharic",
    "yi": "yiddish",
    "lo": "lao",
    "uz": "uzbek",
    "fo": "faroese",
    "ht": "haitian creole",
    "ps": "pashto",
    "tk": "turkmen",
    "nn": "nynorsk",
    "mt": "maltese",
    "sa": "sanskrit",
    "lb": "luxembourgish",
    "my": "myanmar",
    "bo": "tibetan",
    "tl": "tagalog",
    "mg": "malagasy",
    "as": "assamese",
    "tt": "tatar",
    "haw": "hawaiian",
    "ln": "lingala",
    "ha": "hausa",
    "ba": "bashkir",
    "jw": "javanese",
    "su": "sundanese",
}
```
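Since this table is simply the `LANGUAGES` dict exported by `whisper/tokenizer.py`, a language code can also be checked programmatically. A minimal sketch, assuming the repository root is on `PYTHONPATH`:
```python
from whisper.tokenizer import LANGUAGES

code = "ja"
print(code in LANGUAGES)    # True
print(LANGUAGES.get(code))  # japanese
```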
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
import io
import onnx
import requests
import onnxruntime as ort

def model_download(name: str, onnx_file_save_path: str='.') -> bytes:
    URL = f'https://s3.ap-northeast-2.wasabisys.com/pinto-model-zoo/381_Whisper/onnx/{name}_11_float16.onnx'
    onnx_serialized_graph = requests.get(URL).content
    with io.BytesIO(onnx_serialized_graph) as f:
        onnx_graph: onnx.ModelProto = onnx.load(f)
        onnx.save(onnx_graph, f'{onnx_file_save_path}/{name}_11_float16.onnx')
    return onnx_serialized_graph

if __name__ == '__main__':
    onnx_serialized_graph = model_download('tiny_encoder')
    ort_sess_encoder = \
        ort.InferenceSession(
            path_or_bytes=onnx_serialized_graph,
            providers=['CUDAExecutionProvider'],
        )
    onnx_serialized_graph = model_download('tiny_decoder')
    ort_sess_decoder = \
        ort.InferenceSession(
            path_or_bytes=onnx_serialized_graph,
            providers=['CUDAExecutionProvider'],
        )
    a=0

--------------------------------------------------------------------------------
/test.sh:
--------------------------------------------------------------------------------
#!/bin/bash

start_time=`date +%s`

python whisper/transcribe.py carmack.mp3 --model tiny.en --beam_size 2

end_time=`date +%s`

run_time=$((end_time - start_time))

echo $run_time

--------------------------------------------------------------------------------
/test_faster_whisper.py:
--------------------------------------------------------------------------------
from faster_whisper import WhisperModel

model_size = "tiny.en"

# Run on GPU with FP16
model = WhisperModel(model_size, device="cuda", compute_type="float16")

# or run on GPU with INT8
# model = WhisperModel(model_size, device="cuda", compute_type="int8_float16")
# or run on CPU with INT8
# model = WhisperModel(model_size, device="cpu", compute_type="int8")

segments, info = model.transcribe("carmack.mp3", beam_size=5)

print("Detected language '%s' with probability %f" % (info.language, info.language_probability))

for segment in segments:
    print("[%.2fs -> %.2fs] %s" % (segment.start, segment.end, segment.text))

--------------------------------------------------------------------------------
/test_faster_whisper.sh:
--------------------------------------------------------------------------------
#!/bin/bash

start_time=`date +%s`

python test_faster_whisper.py

end_time=`date +%s`

run_time=$((end_time - start_time))

echo $run_time

--------------------------------------------------------------------------------
/whisper/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PINTO0309/whisper-onnx-tensorrt/9449f557d221b627d8478ea066eb93a5f6bf2dc4/whisper/__init__.py

--------------------------------------------------------------------------------
/whisper/__main__.py:
--------------------------------------------------------------------------------
from .transcribe import cli


cli()

--------------------------------------------------------------------------------
/whisper/assets/gpt2/special_tokens_map.json:
--------------------------------------------------------------------------------
{"bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "unk_token": "<|endoftext|>"}

--------------------------------------------------------------------------------
/whisper/assets/gpt2/tokenizer_config.json:
--------------------------------------------------------------------------------
{"unk_token": "<|endoftext|>", "bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "add_prefix_space": false, "model_max_length": 1024, "special_tokens_map_file": null, "name_or_path": "gpt2", "tokenizer_class": "GPT2Tokenizer"}

--------------------------------------------------------------------------------
/whisper/assets/mel_filters.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PINTO0309/whisper-onnx-tensorrt/9449f557d221b627d8478ea066eb93a5f6bf2dc4/whisper/assets/mel_filters.npz

--------------------------------------------------------------------------------
/whisper/assets/multilingual/added_tokens.json:
--------------------------------------------------------------------------------
{"<|endoftext|>": 50257}

--------------------------------------------------------------------------------
/whisper/assets/multilingual/special_tokens_map.json:
--------------------------------------------------------------------------------
{"bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "unk_token": "<|endoftext|>"}

--------------------------------------------------------------------------------
/whisper/assets/multilingual/tokenizer_config.json:
--------------------------------------------------------------------------------
{"unk_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "bos_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "eos_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "add_prefix_space": false, "model_max_length": 1024, "special_tokens_map_file": null, "name_or_path": "multilingual", "errors": "replace", "tokenizer_class": "GPT2Tokenizer"}

--------------------------------------------------------------------------------
/whisper/audio.py:
--------------------------------------------------------------------------------
import os
from functools import lru_cache
from typing import Union

import ffmpeg
import numpy as np
import cupy as cp

from whisper.utils import exact_div

# hard-coded audio hyperparameters
SAMPLE_RATE = 16000
N_FFT = 400
N_MELS = 80
HOP_LENGTH = 160
CHUNK_LENGTH = 30
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE  # 480000: number of samples in a chunk
N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH)  # 3000: number of frames in a mel spectrogram input


def load_audio(file: str, sr: int = SAMPLE_RATE):
    """
    Open an audio file and read as mono waveform, resampling as necessary

    Parameters
    ----------
    file: str
        The audio file to open

    sr: int
        The sample rate to resample the audio if necessary

    Returns
    -------
    A NumPy array containing the audio waveform, in float32 dtype.
    """
    try:
        # This launches a subprocess to decode audio while down-mixing and resampling as necessary.
        # Requires the ffmpeg CLI and `ffmpeg-python` package to be installed.
        out, _ = (
            ffmpeg.input(file, threads=0)
            .output("-", format="s16le", acodec="pcm_s16le", ac=1, ar=sr)
            .run(cmd="ffmpeg", capture_stdout=True, capture_stderr=True)
        )
    except ffmpeg.Error as e:
        raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e

    return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0


def pad_or_trim(array: np.ndarray, length: int = N_SAMPLES, *, axis: int = -1):
    """
    Pad or trim the audio array to N_SAMPLES, as expected by the encoder.
    """
    if array.shape[axis] > length:
        array = array.take(indices=range(length), axis=axis)

    if array.shape[axis] < length:
        pad_widths = [(0, 0)] * array.ndim
        pad_widths[axis] = (0, length - array.shape[axis])
        array = np.pad(array, pad_widths)

    return array


@lru_cache(maxsize=None)
def mel_filters(n_mels: int = N_MELS) -> np.ndarray:
    """
    load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
    Allows decoupling librosa dependency; saved using:

        np.savez_compressed(
            "mel_filters.npz",
            mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
        )
    """
    assert n_mels == 80, f"Unsupported n_mels: {n_mels}"
    with np.load(os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")) as f:
        return f[f"mel_{n_mels}"]

def cupy_sliding_window_view(x, window_shape, step=1):
    shape = ((x.shape[-1] - window_shape) // step + 1,) + (window_shape,)
    strides = (step * x.strides[-1],) + x.strides
    return cp.lib.stride_tricks.as_strided(x, shape=shape, strides=strides)


def cupy_stft(audio: np.ndarray, N_FFT: int, HOP_LENGTH: int):
    window = cp.hanning(N_FFT)
    cpaudio = cp.asarray(audio)
    num_frames = 1 + (cpaudio.size - N_FFT) // HOP_LENGTH
    if (cpaudio.size - N_FFT) % HOP_LENGTH > 0:
        num_frames += 1
    audio_padded = cp.pad(cpaudio, pad_width=(N_FFT//2, N_FFT//2), mode='constant')
    frames = cupy_sliding_window_view(audio_padded, N_FFT, HOP_LENGTH)
    frames = frames[:num_frames]
    stft = cp.fft.rfft(frames * window, axis=-1)

    cpstft = (cp.abs(stft[:,:N_FFT//2 + 1]) ** 2).T
    magnitudes = cp.asnumpy(cpstft).astype(audio.dtype)
    return magnitudes


def numpy_sliding_window_view(x, window_shape, step=1):
    shape = ((x.shape[-1] - window_shape) // step + 1,) + (window_shape,)
    strides = (step * x.strides[-1],) + x.strides
    return np.lib.stride_tricks.as_strided(x, shape=shape, strides=strides)


def numpy_stft(audio: np.ndarray, N_FFT: int, HOP_LENGTH: int):
    window = np.hanning(N_FFT)
    num_frames = 1 + (audio.size - N_FFT) // HOP_LENGTH
    if (audio.size - N_FFT) % HOP_LENGTH > 0:
        num_frames += 1
    audio_padded = np.pad(audio, pad_width=(N_FFT//2, N_FFT//2), mode='constant')
    frames = numpy_sliding_window_view(audio_padded, N_FFT, HOP_LENGTH)
    frames = frames[:num_frames]
    stft = np.fft.rfft(frames * window, axis=-1)

    cpstft = (np.abs(stft[:,:N_FFT//2 + 1]) ** 2).T
    magnitudes = cpstft.astype(audio.dtype)
    return magnitudes


def log_mel_spectrogram(audio: Union[str, np.ndarray], disable_cupy: bool, n_mels: int = N_MELS):
    """
    Compute the log-Mel spectrogram of the given audio.

    Parameters
    ----------
    audio: Union[str, np.ndarray], shape = (*)
        The path to an audio file, or a NumPy array containing the audio waveform at 16 kHz

    disable_cupy: bool
        If True, compute the STFT with NumPy on the CPU instead of CuPy on the GPU

    n_mels: int
        The number of Mel-frequency filters, only 80 is supported

    Returns
    -------
    np.ndarray, shape = (80, n_frames)
        An array that contains the Mel spectrogram
    """
    if isinstance(audio, str):
        audio = load_audio(audio)

    if not disable_cupy:
        magnitudes = cupy_stft(audio, N_FFT, HOP_LENGTH)
    else:
        magnitudes = numpy_stft(audio, N_FFT, HOP_LENGTH)

    filters = mel_filters(n_mels)
    mel_spec = filters @ magnitudes

    log_spec = np.log10(np.clip(mel_spec, a_min=1e-10, a_max=None))
    log_spec = np.maximum(log_spec, np.max(log_spec) - 8.0)
    log_spec = (log_spec + 4.0) / 4.0
    return log_spec
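# Usage sketch (illustrative; this mirrors how transcribe.py consumes this
# module rather than quoting it, and assumes an audio file exists on disk):
#
#   mel = log_mel_spectrogram("carmack.mp3", disable_cupy=True)  # (80, n_frames)
#   segment = pad_or_trim(mel, N_FRAMES, axis=-1)                # (80, 3000)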
--------------------------------------------------------------------------------
/whisper/decoding.py:
--------------------------------------------------------------------------------
from dataclasses import dataclass, field
from typing import Dict, List, Tuple, Iterable, Optional, Sequence, Union, TYPE_CHECKING

import numpy as np

from whisper.audio import CHUNK_LENGTH
from whisper.tokenizer import Tokenizer, get_tokenizer
from whisper.utils import compression_ratio

if TYPE_CHECKING:
    from whisper.model import Whisper


def softmax(x, dim=-1):
    e_x = np.exp(x - np.max(x, axis=dim, keepdims=True))
    return e_x / (np.sum(e_x, axis=dim, keepdims=True))

def log_softmax(x, dim=-1):
    y = softmax(x, dim=dim)
    return np.log(y)

def numpy_categorical_sample(logits, temperature):
    logits = logits / temperature  # out-of-place: do not mutate the caller's logits
    probs = softmax(logits, dim=-1)
    return np.array([np.random.choice(len(p), p=p) for p in probs])


def detect_language(model: "Whisper", mel: np.ndarray, tokenizer: Tokenizer = None) -> Tuple[np.ndarray, List[dict]]:
    """
    Detect the spoken language in the audio, and return it as a list of strings, along with the ids
    of the most probable language tokens and the probability distribution over all language tokens.
    This is performed outside the main decode loop in order to not interfere with kv-caching.

    Returns
    -------
    language_tokens : np.ndarray, shape = (n_audio,)
        ids of the most probable language tokens, which appear after the startoftranscript token.
    language_probs : List[Dict[str, float]], length = n_audio
        list of dictionaries containing the probability distribution over all languages.
    """
    if tokenizer is None:
        tokenizer = get_tokenizer(model.is_multilingual)
    if tokenizer.language is None or tokenizer.language_token not in tokenizer.sot_sequence:
        raise ValueError(f"This model doesn't have language tokens so it can't perform lang id")

    single = mel.ndim == 2
    if single:
        mel = mel[np.newaxis, ...]

    # skip encoder forward pass if already-encoded audio features were given
    if mel.shape[-2:] != (model.dims.n_audio_ctx, model.dims.n_audio_state):
        mel = model.encoder(mel)

    # forward pass using a single token, startoftranscript
    n_audio = mel.shape[0]
    x = np.array([[tokenizer.sot]] * n_audio)  # [n_audio, 1]
    logits = model.logits(x, mel)[:, 0]

    # collect detected languages; suppress all non-language tokens
    mask = np.ones(logits.shape[-1], dtype=np.bool_)
    mask[list(tokenizer.all_language_tokens)] = False
    logits[:, mask] = -np.inf
    language_tokens = logits.argmax(axis=-1)
    language_token_probs = softmax(logits, dim=-1)
    language_probs = [
        {
            c: language_token_probs[i, j].item()
            for j, c in zip(tokenizer.all_language_tokens, tokenizer.all_language_codes)
        }
        for i in range(n_audio)
    ]

    if single:
        language_tokens = language_tokens[0]
        language_probs = language_probs[0]

    return language_tokens, language_probs


@dataclass(frozen=True)
class DecodingOptions:
    task: str = "transcribe"  # whether to perform X->X "transcribe" or X->English "translate"
    language: Optional[str] = None  # language that the audio is in; uses detected language if None

    # sampling-related options
    temperature: float = 0.0
    sample_len: Optional[int] = None  # maximum number of tokens to sample
    best_of: Optional[int] = None  # number of independent samples to collect, when t > 0
    beam_size: Optional[int] = None  # number of beams in beam search, when t == 0
    patience: Optional[float] = None  # patience in beam search (https://arxiv.org/abs/2204.05424)

    # options for ranking generations (either beams or best-of-N samples)
    length_penalty: Optional[float] = None  # "alpha" in Google NMT, None defaults to length norm

    # prompt, prefix, and token suppression
    prompt: Optional[Union[str, List[int]]] = None  # text or tokens for the previous context
    prefix: Optional[Union[str, List[int]]] = None  # text or tokens to prefix the current context
    suppress_blank: bool = True  # this will suppress blank outputs

    # list of token ids (or comma-separated token ids) to suppress
    # "-1" will suppress a set of symbols as defined in `tokenizer.non_speech_tokens()`
    suppress_tokens: Optional[Union[str, Iterable[int]]] = "-1"

    # timestamp sampling options
    without_timestamps: bool = False  # use <|notimestamps|> to sample text tokens only
    max_initial_timestamp: Optional[float] = 0.0  # the initial timestamp cannot be later than this


@dataclass(frozen=True)
class DecodingResult:
    audio_features: np.ndarray
    language: str
    language_probs: Optional[Dict[str, float]] = None
    tokens: List[int] = field(default_factory=list)
    text: str = ""
    avg_logprob: float = np.nan
    no_speech_prob: float = np.nan
    temperature: float = np.nan
    compression_ratio: float = np.nan


class Inference:
    def logits(self, tokens: np.ndarray, audio_features: np.ndarray) -> np.ndarray:
        """Perform a forward pass on the decoder and return per-token logits"""
        raise NotImplementedError

    def rearrange_kv_cache(self, source_indices) -> None:
        """Update the key-value cache according to the updated beams"""
        raise NotImplementedError

    def cleanup_caching(self) -> None:
        """Clean up any resources or hooks after decoding is finished"""
        pass
class OnnxInference(Inference):
    def __init__(self, model: "Whisper", initial_token_length: int):
        self.model: "Whisper" = model
        self.initial_token_length = initial_token_length
        self.kv_cache = None

    def logits(self, tokens: np.ndarray, audio_features: np.ndarray) -> np.ndarray:
        n_group = tokens.shape[0]
        if self.kv_cache is None:
            self.kv_cache = self.model.new_kv_cache(n_group, self.initial_token_length)
            offset = 0
        else:
            # grow the cache by one step, keeping the previously computed keys/values
            offset = self.kv_cache.shape[2]
            new_kv_cache = self.model.new_kv_cache(n_group, offset + 1)
            new_kv_cache[:, :, :-1, :] = self.kv_cache
            self.kv_cache = new_kv_cache

        if tokens.shape[-1] > self.initial_token_length:
            # only need to use the last token except in the first forward pass
            tokens = tokens[:, -1:]

        output, self.kv_cache = self.model.decoder(tokens, audio_features, kv_cache=self.kv_cache, offset=offset)
        return output

    def cleanup_caching(self):
        self.kv_cache = None

    def rearrange_kv_cache(self, source_indices):
        self.kv_cache = self.kv_cache[:, source_indices]


class SequenceRanker:
    def rank(self, tokens: List[List[np.ndarray]], sum_logprobs: List[List[float]]) -> List[int]:
        """
        Given a list of groups of samples and their cumulative log probabilities,
        return the indices of the samples in each group to select as the final result
        """
        raise NotImplementedError


class MaximumLikelihoodRanker(SequenceRanker):
    """
    Select the sample with the highest log probabilities, penalized using either
    a simple length normalization or Google NMT paper's length penalty
    """

    def __init__(self, length_penalty: Optional[float]):
        self.length_penalty = length_penalty

    def rank(self, tokens: List[List[np.ndarray]], sum_logprobs: List[List[float]]):
        def scores(logprobs, lengths):
            result = []
            for logprob, length in zip(logprobs, lengths):
                if self.length_penalty is None:
                    penalty = length
                else:
                    # from the Google NMT paper
                    penalty = ((5 + length) / 6) ** self.length_penalty
                result.append(logprob / penalty)
            return result

        # get the sequence with the highest score
        lengths = [[len(t) for t in s] for s in tokens]
        return [np.argmax(scores(p, l)) for p, l in zip(sum_logprobs, lengths)]


class TokenDecoder:
    def reset(self):
        """Initialize any stateful variables for decoding a new sequence"""

    def update(self, tokens: np.ndarray, logits: np.ndarray, sum_logprobs: np.ndarray) -> Tuple[np.ndarray, bool]:
        """Specify how to select the next token, based on the current trace and logits

        Parameters
        ----------
        tokens : np.ndarray, shape = (n_batch, current_sequence_length)
            all tokens in the context so far, including the prefix and sot_sequence tokens

        logits : np.ndarray, shape = (n_batch, vocab_size)
            per-token logits of the probability distribution at the current step

        sum_logprobs : np.ndarray, shape = (n_batch)
            cumulative log probabilities for each sequence

        Returns
        -------
        tokens : np.ndarray, shape = (n_batch, current_sequence_length + 1)
            the tokens, appended with the selected next token

        completed : bool
            True if all sequences have reached the end of text

        """
        raise NotImplementedError

    def finalize(
        self, tokens: np.ndarray, sum_logprobs: np.ndarray
    ) -> Tuple[Sequence[Sequence[np.ndarray]], List[List[float]]]:
        """Finalize search and return the final candidate sequences

        Parameters
        ----------
        tokens : np.ndarray, shape = (n_audio, n_group, current_sequence_length)
            all tokens in the context so far, including the prefix and sot_sequence

        sum_logprobs : np.ndarray, shape = (n_audio, n_group)
            cumulative log probabilities for each sequence

        Returns
        -------
        tokens : Sequence[Sequence[np.ndarray]], length = n_audio
            sequence of Tensors containing candidate token sequences, for each audio input

        sum_logprobs : List[List[float]], length = n_audio
            sequence of cumulative log probabilities corresponding to the above

        """
        raise NotImplementedError


class GreedyDecoder(TokenDecoder):
    def __init__(self, temperature: float, eot: int):
        self.temperature = temperature
        self.eot = eot

    def update(self, tokens: np.ndarray, logits: np.ndarray, sum_logprobs: np.ndarray) -> Tuple[np.ndarray, bool]:
        temperature = self.temperature
        if temperature == 0:
            next_tokens = logits.argmax(axis=-1)  # NumPy uses `axis`, not torch's `dim`
        else:
            next_tokens = numpy_categorical_sample(logits, temperature)

        logprobs = log_softmax(logits, dim=-1)
        current_logprobs = logprobs[np.arange(logprobs.shape[0]), next_tokens]
        sum_logprobs += current_logprobs * (tokens[:, -1] != self.eot)

        next_tokens[tokens[:, -1] == self.eot] = self.eot
        tokens = np.concatenate([tokens, next_tokens[:, None]], axis=-1)

        completed = (tokens[:, -1] == self.eot).all()
        return tokens, completed

    def finalize(self, tokens: np.ndarray, sum_logprobs: np.ndarray):
        # make sure each sequence has at least one EOT token at the end;
        # pad only the last (sequence) axis, otherwise np.pad would also grow the batch axes
        pad_widths = [(0, 0)] * (tokens.ndim - 1) + [(0, 1)]
        tokens = np.pad(tokens, pad_widths, constant_values=self.eot)
        return tokens, sum_logprobs.tolist()


class BeamSearchDecoder(TokenDecoder):
    def __init__(self, beam_size: int, eot: int, inference: Inference, patience: Optional[float] = None):
        self.beam_size = beam_size
        self.eot = eot
        self.inference = inference
        self.patience = patience or 1.0
        self.max_candidates: int = round(beam_size * self.patience)
        self.finished_sequences = None

        assert self.max_candidates > 0, f"Invalid beam size ({beam_size}) or patience ({patience})"

    def reset(self):
        self.finished_sequences = None

    def update(self, tokens: np.ndarray, logits: np.ndarray, sum_logprobs: np.ndarray) -> Tuple[np.ndarray, bool]:
        if tokens.shape[0] % self.beam_size != 0:
            raise ValueError(f"{tokens.shape}[0] % {self.beam_size} != 0")

        n_audio = tokens.shape[0] // self.beam_size
        if self.finished_sequences is None:  # for the first update
            self.finished_sequences = [{} for _ in range(n_audio)]

        logprobs = log_softmax(logits, dim=-1)
        next_tokens, source_indices, finished_sequences = [], [], []
        for i in range(n_audio):
            scores, sources, finished = {}, {}, {}

            # STEP 1: calculate the cumulative log probabilities for possible candidates
            for j in range(self.beam_size):
                idx = i * self.beam_size + j
                prefix = list(tokens[idx])
                topk_values = -np.partition(-logprobs[idx], self.beam_size + 1)[:self.beam_size + 1]
                topk_indices = np.argpartition(-logprobs[idx], self.beam_size + 1)[:self.beam_size + 1]

                sort_indices = np.argsort(-topk_values)
                topk_values = topk_values[sort_indices]
                topk_indices = topk_indices[sort_indices]
                """ OK
                topk_values
                    array([-0.99763197, -3.107007  , -3.310132  , -3.357007  ], dtype=float32)
                topk_indices
                    array([50364, 50474, 50478, 50472])
                """
                for logprob, token in zip(topk_values, topk_indices):
                    new_logprob = (sum_logprobs[idx] + logprob)
                    sequence = tuple(prefix + [token])
                    scores[sequence] = new_logprob
                    sources[sequence] = idx

            # STEP 2: rank the candidates and keep the top beam_size sequences for each audio
            saved = 0
            for sequence in sorted(scores, key=scores.get, reverse=True):
                if sequence[-1] == self.eot:
                    finished[sequence] = scores[sequence]
                else:
                    sum_logprobs[len(next_tokens)] = scores[sequence]
                    next_tokens.append(sequence)
                    source_indices.append(sources[sequence])

                    saved += 1
                    if saved == self.beam_size:
                        break

            finished_sequences.append(finished)

        tokens = np.array(next_tokens)
        self.inference.rearrange_kv_cache(source_indices)

        # add newly finished sequences to self.finished_sequences
        assert len(self.finished_sequences) == len(finished_sequences)
        for previously_finished, newly_finished in zip(self.finished_sequences, finished_sequences):
            for seq in sorted(newly_finished, key=newly_finished.get, reverse=True):
                if len(previously_finished) >= self.max_candidates:
                    break  # the candidate list is full
                previously_finished[seq] = newly_finished[seq]

        # mark as completed if all audio has enough number of samples
        completed = all(
            len(sequences) >= self.max_candidates for sequences in self.finished_sequences
        )
        return tokens, completed

    def finalize(self, preceding_tokens: np.ndarray, sum_logprobs: np.ndarray):
        # collect all finished sequences, including patience, and add unfinished ones if not enough
        for i, sequences in enumerate(self.finished_sequences):
            if len(sequences) < self.beam_size:  # when not enough sequences are finished
                for j in list(np.argsort(sum_logprobs[i]))[::-1]:
                    sequence = list(preceding_tokens[i, j]) + [self.eot]
                    sequences[tuple(sequence)] = sum_logprobs[i][j]
                    if len(sequences) >= self.beam_size:
                        break

        tokens: List[List[np.ndarray]] = [
            [np.array(seq) for seq in sequences.keys()] for sequences in self.finished_sequences
        ]
        sum_logprobs: List[List[float]] = [
            list(sequences.values()) for sequences in self.finished_sequences
        ]
        return tokens, sum_logprobs


class LogitFilter:
    def apply(self, logits: np.ndarray, tokens: np.ndarray) -> None:
        """Apply any filtering or masking to logits in-place

        Parameters
        ----------
        logits : np.ndarray, shape = (n_batch, vocab_size)
            per-token logits of the probability distribution at the current step

        tokens : np.ndarray, shape = (n_batch, current_sequence_length)
            all tokens in the context so far, including the prefix and sot_sequence tokens

        """
        raise NotImplementedError
class SuppressBlank(LogitFilter):
    def __init__(self, tokenizer: Tokenizer, sample_begin: int):
        self.tokenizer = tokenizer
        self.sample_begin = sample_begin

    def apply(self, logits: np.ndarray, tokens: np.ndarray):
        if tokens.shape[1] == self.sample_begin:
            logits[:, self.tokenizer.encode(" ") + [self.tokenizer.eot]] = -np.inf


class SuppressTokens(LogitFilter):
    def __init__(self, suppress_tokens: Sequence[int]):
        self.suppress_tokens = list(suppress_tokens)

    def apply(self, logits: np.ndarray, tokens: np.ndarray):
        logits[:, self.suppress_tokens] = -np.inf


class ApplyTimestampRules(LogitFilter):
    def __init__(
        self, tokenizer: Tokenizer, sample_begin: int, max_initial_timestamp_index: Optional[int]
    ):
        self.tokenizer = tokenizer
        self.sample_begin = sample_begin
        self.max_initial_timestamp_index = max_initial_timestamp_index

    def apply(self, logits: np.ndarray, tokens: np.ndarray):
        # suppress <|notimestamps|> which is handled by without_timestamps
        if self.tokenizer.no_timestamps is not None:
            logits[:, self.tokenizer.no_timestamps] = -np.inf

        # timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
        for k in range(tokens.shape[0]):
            seq = [t for t in tokens[k, self.sample_begin :].tolist()]
            last_was_timestamp = len(seq) >= 1 and seq[-1] >= self.tokenizer.timestamp_begin
            penultimate_was_timestamp = len(seq) < 2 or seq[-2] >= self.tokenizer.timestamp_begin

            if last_was_timestamp:
                if penultimate_was_timestamp:  # has to be non-timestamp
                    logits[k, self.tokenizer.timestamp_begin :] = -np.inf
                else:  # cannot be normal text tokens
                    logits[k, : self.tokenizer.eot] = -np.inf

        # apply the `max_initial_timestamp` option
        if tokens.shape[1] == self.sample_begin and self.max_initial_timestamp_index is not None:
            last_allowed = self.tokenizer.timestamp_begin + self.max_initial_timestamp_index
            logits[:, last_allowed + 1 :] = -np.inf

        # if sum of probability over timestamps is above any other token, sample timestamp
        logprobs = log_softmax(logits, dim=-1)
        for k in range(tokens.shape[0]):
            # log-sum-exp over the timestamp tokens, computed stably via the max trick
            max_val = np.max(logprobs[k, self.tokenizer.timestamp_begin :], axis=-1, keepdims=True)
            timestamp_logprob = np.squeeze(max_val) + np.log(np.sum(np.exp(logprobs[k, self.tokenizer.timestamp_begin :] - max_val), axis=-1))
            max_text_token_logprob = np.max(logprobs[k, : self.tokenizer.timestamp_begin])
            if timestamp_logprob > max_text_token_logprob:
                logits[k, : self.tokenizer.timestamp_begin] = -np.inf
len(suppress_tokens) == 0: 560 | suppress_tokens = [] # interpret empty string as an empty list 561 | else: 562 | assert isinstance(suppress_tokens, list), "suppress_tokens must be a list" 563 | 564 | suppress_tokens.extend( 565 | [self.tokenizer.sot, self.tokenizer.sot_prev, self.tokenizer.sot_lm] 566 | ) 567 | if self.tokenizer.no_speech is not None: 568 | # no-speech probability is collected separately 569 | suppress_tokens.append(self.tokenizer.no_speech) 570 | 571 | return tuple(sorted(set(suppress_tokens))) 572 | 573 | def _get_audio_features(self, mel: np.ndarray): 574 | if mel.shape[-2:] == (self.model.dims.n_audio_ctx, self.model.dims.n_audio_state): 575 | # encoded audio features are given; skip audio encoding 576 | audio_features = mel 577 | else: 578 | audio_features = self.model.encoder(mel) 579 | 580 | return audio_features 581 | 582 | def _detect_language(self, audio_features: np.ndarray, tokens: np.ndarray): 583 | languages = [self.options.language] * audio_features.shape[0] 584 | lang_probs = None 585 | 586 | if self.options.language is None or self.options.task == "lang_id": 587 | lang_tokens, lang_probs = self.model.detect_language(audio_features, self.tokenizer) 588 | languages = [max(probs, key=probs.get) for probs in lang_probs] 589 | if self.options.language is None: 590 | tokens[:, self.sot_index + 1] = lang_tokens # write language tokens 591 | 592 | return languages, lang_probs 593 | 594 | def _main_loop(self, audio_features: np.ndarray, tokens: np.ndarray): 595 | assert audio_features.shape[0] == tokens.shape[0] 596 | n_batch = tokens.shape[0] 597 | sum_logprobs: np.ndarray = np.zeros(n_batch) 598 | no_speech_probs = [np.nan] * n_batch 599 | 600 | try: 601 | for i in range(self.sample_len): 602 | logits = self.inference.logits(tokens, audio_features) 603 | 604 | if i == 0 and self.tokenizer.no_speech is not None: # save no_speech_probs 605 | probs_at_sot = softmax(logits[:, self.sot_index], dim=-1) 606 | no_speech_probs = list(probs_at_sot[:, self.tokenizer.no_speech]) 607 | 608 | # now we need to consider the logits at the last token only 609 | logits = logits[:, -1] 610 | 611 | # apply the logit filters, e.g. 
594 |     def _main_loop(self, audio_features: np.ndarray, tokens: np.ndarray):
595 |         assert audio_features.shape[0] == tokens.shape[0]
596 |         n_batch = tokens.shape[0]
597 |         sum_logprobs: np.ndarray = np.zeros(n_batch)
598 |         no_speech_probs = [np.nan] * n_batch
599 | 
600 |         try:
601 |             for i in range(self.sample_len):
602 |                 logits = self.inference.logits(tokens, audio_features)
603 | 
604 |                 if i == 0 and self.tokenizer.no_speech is not None:  # save no_speech_probs
605 |                     probs_at_sot = softmax(logits[:, self.sot_index], dim=-1)
606 |                     no_speech_probs = list(probs_at_sot[:, self.tokenizer.no_speech])
607 | 
608 |                 # now we need to consider the logits at the last token only
609 |                 logits = logits[:, -1]
610 | 
611 |                 # apply the logit filters, e.g. for suppressing or penalizing certain tokens
612 |                 for logit_filter in self.logit_filters:
613 |                     logit_filter.apply(logits, tokens)
614 | 
615 |                 # expand the tokens tensor with the selected next tokens
616 |                 tokens, completed = self.decoder.update(tokens, logits, sum_logprobs)
617 | 
618 |                 if completed or tokens.shape[-1] > self.n_ctx:
619 |                     break
620 |         finally:
621 |             self.inference.cleanup_caching()
622 | 
623 |         return tokens, sum_logprobs, no_speech_probs
624 | 
625 |     def run(self, mel: np.ndarray) -> List[DecodingResult]:
626 |         self.decoder.reset()
627 |         tokenizer: Tokenizer = self.tokenizer
628 |         n_audio: int = mel.shape[0]
629 | 
630 |         audio_features: np.ndarray = self._get_audio_features(mel)  # encoder forward pass
631 |         token = np.array([self.initial_tokens])
632 |         tokens: np.ndarray = np.broadcast_to(token, (n_audio, token.shape[1]))
633 | 
634 |         # detect language if requested, overwriting the language token
635 |         languages, language_probs = self._detect_language(audio_features, tokens)
636 |         if self.options.task == "lang_id":
637 |             return [
638 |                 DecodingResult(audio_features=features, language=language, language_probs=probs)
639 |                 for features, language, probs in zip(audio_features, languages, language_probs)
640 |             ]
641 | 
642 |         # repeat the audio & text tensors by the group size, for beam search or best-of-n sampling
643 |         audio_features = np.repeat(a=audio_features, repeats=self.n_group, axis=0)
644 |         tokens = np.repeat(a=tokens, repeats=self.n_group, axis=0)
645 | 
646 |         # call the main sampling loop
647 |         tokens, sum_logprobs, no_speech_probs = self._main_loop(audio_features, tokens)
648 | 
649 |         # reshape the tensors to have (n_audio, n_group) as the first two dimensions
650 |         audio_features = audio_features[:: self.n_group]
651 |         no_speech_probs = no_speech_probs[:: self.n_group]
652 |         assert audio_features.shape[0] == len(no_speech_probs) == n_audio
653 | 
654 |         tokens = tokens.reshape(n_audio, self.n_group, -1)
655 |         sum_logprobs = sum_logprobs.reshape(n_audio, self.n_group)
656 | 
657 |         # get the final candidates for each group, and slice between the first sampled token and EOT
658 |         tokens, sum_logprobs = self.decoder.finalize(tokens, sum_logprobs)
659 |         tokens: List[List[np.ndarray]] = [
660 |             [t[self.sample_begin : (t == tokenizer.eot).nonzero()[0][0]] for t in s] for s in tokens
661 |         ]
662 | 
663 |         # select the top-ranked sample in each group
664 |         selected = self.sequence_ranker.rank(tokens, sum_logprobs)
665 |         tokens: List[List[int]] = [list(t[i]) for i, t in zip(selected, tokens)]
666 |         texts: List[str] = [tokenizer.decode(t).strip() for t in tokens]
667 | 
668 |         sum_logprobs: List[float] = [lp[i] for i, lp in zip(selected, sum_logprobs)]
669 |         avg_logprobs: List[float] = [lp / (len(t) + 1) for t, lp in zip(tokens, sum_logprobs)]
670 | 
671 |         fields = (texts, languages, tokens, audio_features, avg_logprobs, no_speech_probs)
672 |         if len(set(map(len, fields))) != 1:
673 |             raise RuntimeError(f"inconsistent result lengths: {list(map(len, fields))}")
674 | 
675 |         return [
676 |             DecodingResult(
677 |                 audio_features=features,
678 |                 language=language,
679 |                 tokens=tokens,
680 |                 text=text,
681 |                 avg_logprob=avg_logprob,
682 |                 no_speech_prob=no_speech_prob,
683 |                 temperature=self.options.temperature,
684 |                 compression_ratio=compression_ratio(text),
685 |             )
686 |             for text, language, tokens, features, avg_logprob, no_speech_prob in zip(*fields)
687 |         ]
688 | 
689 | 
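The stride/reshape bookkeeping in run() is the subtle part: every segment is duplicated n_group times (beam_size or best_of), decoded as one flat batch, then regrouped. The same index arithmetic in isolation, with illustrative shapes only:

import numpy as np

n_audio, n_group = 2, 3                            # 2 segments, 3 candidates per segment
features = np.arange(n_audio)[:, None]             # stand-in for encoded audio features, shape (2, 1)
expanded = np.repeat(features, n_group, axis=0)    # shape (6, 1): rows [0, 0, 0, 1, 1, 1]
# ... decoding runs on the flat batch of 6 candidate sequences ...
per_segment = expanded[::n_group]                  # stride recovers one row per segment: [0, 1]
grouped = expanded.reshape(n_audio, n_group, -1)   # (2, 3, 1): candidates grouped by segment
assert (per_segment.ravel() == features.ravel()).all()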
690 | def decode(model: "Whisper", mel: np.ndarray, options: DecodingOptions = DecodingOptions()) -> Union[DecodingResult, List[DecodingResult]]:
691 |     """
692 |     Performs decoding of 30-second audio segment(s), provided as Mel spectrogram(s).
693 | 
694 |     Parameters
695 |     ----------
696 |     model: Whisper
697 |         the Whisper model instance
698 | 
699 |     mel: np.ndarray, shape = (80, 3000) or (*, 80, 3000)
700 |         A tensor containing the Mel spectrogram(s)
701 | 
702 |     options: DecodingOptions
703 |         A dataclass that contains all necessary options for decoding 30-second segments
704 | 
705 |     Returns
706 |     -------
707 |     result: Union[DecodingResult, List[DecodingResult]]
708 |         The result(s) of decoding contained in `DecodingResult` dataclass instance(s)
709 |     """
710 |     single = mel.ndim == 2
711 |     if single:
712 |         mel = mel[np.newaxis, ...]
713 | 
714 |     result = DecodingTask(model, options).run(mel)
715 | 
716 |     if single:
717 |         result = result[0]
718 | 
719 |     return result
720 | 
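With decode() in place, the typical end-to-end call looks like the sketch below. It assumes whisper/__init__.py re-exports load_model, load_audio, pad_or_trim, log_mel_spectrogram, DecodingOptions, and decode as the upstream Whisper package does; audio.mp3 is a placeholder filename:

import whisper

model = whisper.load_model("tiny.en")          # ONNX Runtime sessions, defined in model.py below
audio = whisper.load_audio("audio.mp3")        # placeholder file; decoded to 16 kHz mono
audio = whisper.pad_or_trim(audio)             # exactly one 30-second window of samples
mel = whisper.log_mel_spectrogram(audio)       # the (80, 3000) input that decode() expects
result = whisper.decode(model, mel, whisper.DecodingOptions(beam_size=2))
print(result.text, result.avg_logprob, result.no_speech_prob)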
--------------------------------------------------------------------------------
/whisper/model.py:
--------------------------------------------------------------------------------
1 | import io
2 | import os
3 | from dataclasses import dataclass
4 | from typing import List, Dict, Tuple
5 | import numpy as np
6 | import requests
7 | import onnx
8 | import onnxruntime as ort
9 | from whisper.decoding import detect_language as detect_language_function, decode as decode_function
10 | from whisper.utils import onnx_dtype_to_np_dtype_convert
11 | 
12 | 
13 | _MODELS = {
14 |     "tiny.en": "https://s3.ap-northeast-2.wasabisys.com/pinto-model-zoo/381_Whisper/pt/tiny.en.pt",
15 |     "tiny": "https://s3.ap-northeast-2.wasabisys.com/pinto-model-zoo/381_Whisper/pt/tiny.pt",
16 |     "base.en": "https://s3.ap-northeast-2.wasabisys.com/pinto-model-zoo/381_Whisper/pt/base.en.pt",
17 |     "base": "https://s3.ap-northeast-2.wasabisys.com/pinto-model-zoo/381_Whisper/pt/base.pt",
18 |     "small.en": "https://s3.ap-northeast-2.wasabisys.com/pinto-model-zoo/381_Whisper/pt/small.en.pt",
19 |     "small": "https://s3.ap-northeast-2.wasabisys.com/pinto-model-zoo/381_Whisper/pt/small.pt",
20 |     "medium.en": "https://s3.ap-northeast-2.wasabisys.com/pinto-model-zoo/381_Whisper/pt/medium.en.pt",
21 |     "medium": "https://s3.ap-northeast-2.wasabisys.com/pinto-model-zoo/381_Whisper/pt/medium.pt",
22 |     "large-v1": "https://s3.ap-northeast-2.wasabisys.com/pinto-model-zoo/381_Whisper/pt/large-v1.pt",
23 |     "large-v2": "https://s3.ap-northeast-2.wasabisys.com/pinto-model-zoo/381_Whisper/pt/large-v2.pt",
24 | }
25 | 
26 | def model_download(name: str, onnx_file_save_path: str='.') -> bytes:
27 |     onnx_file = f'{name}_11_layer_fused_optimization_float16.onnx'
28 |     onnx_file_path = f'{onnx_file_save_path}/{onnx_file}'
29 |     onnx_serialized_graph = None
30 |     if not os.path.exists(onnx_file_path):
31 |         try:
32 |             url = f'https://s3.ap-northeast-2.wasabisys.com/pinto-model-zoo/381_Whisper/onnx/whisper-onnx-xxx/float16/layer_fused_optimization_float16/{onnx_file}'
33 |             onnx_serialized_graph = requests.get(url).content
34 |             with io.BytesIO(onnx_serialized_graph) as f:
35 |                 onnx_graph: onnx.ModelProto = onnx.load(f)
36 |                 onnx.save(onnx_graph, f'{onnx_file_path}')
37 |         except Exception:
38 |             # fall back to the non-optimized model if the fused variant is unavailable
39 |             onnx_file = f'{name}_11_float16.onnx'
40 |             onnx_file_path = f'{onnx_file_save_path}/{onnx_file}'
41 |             if not os.path.exists(onnx_file_path):
42 |                 url = f'https://s3.ap-northeast-2.wasabisys.com/pinto-model-zoo/381_Whisper/onnx/whisper-onnx-xxx/float16/no_optimization/{onnx_file}'
43 |                 onnx_serialized_graph = requests.get(url).content
44 |                 with io.BytesIO(onnx_serialized_graph) as f:
45 |                     onnx_graph: onnx.ModelProto = onnx.load(f)
46 |                     onnx.save(onnx_graph, f'{onnx_file_path}')
47 |             else:
48 |                 onnx_graph: onnx.ModelProto = onnx.load(onnx_file_path)
49 |                 onnx_serialized_graph = onnx._serialize(onnx_graph)  # private onnx helper; returns bytes
50 |     else:
51 |         onnx_graph: onnx.ModelProto = onnx.load(onnx_file_path)
52 |         onnx_serialized_graph = onnx._serialize(onnx_graph)
53 |     return onnx_serialized_graph
54 | 
55 | def load_model(name: str):
56 |     """
57 |     Load a Whisper ASR model
58 | 
59 |     Parameters
60 |     ----------
61 |     name : str
62 |         one of the official model names listed by `whisper.available_models()`;
63 |         checkpoint paths are not supported by this ONNX port
64 | 
65 |     Returns
66 |     -------
67 |     model : Whisper
68 |         The Whisper ASR model instance
69 |     """
70 | 
71 |     if name == "tiny":
72 |         dims_config = {'n_mels': 80, 'n_vocab': 51865, 'n_audio_ctx': 1500, 'n_audio_state': 384, 'n_audio_head': 6, 'n_audio_layer': 4, 'n_text_ctx': 448, 'n_text_state': 384, 'n_text_head': 6, 'n_text_layer': 4}
73 |     elif name == "tiny.en":
74 |         dims_config = {'n_mels': 80, 'n_vocab': 51864, 'n_audio_ctx': 1500, 'n_audio_state': 384, 'n_audio_head': 6, 'n_audio_layer': 4, 'n_text_ctx': 448, 'n_text_state': 384, 'n_text_head': 6, 'n_text_layer': 4}
75 |     elif name == "base":
76 |         dims_config = {'n_mels': 80, 'n_vocab': 51865, 'n_audio_ctx': 1500, 'n_audio_state': 512, 'n_audio_head': 8, 'n_audio_layer': 6, 'n_text_ctx': 448, 'n_text_state': 512, 'n_text_head': 8, 'n_text_layer': 6}
77 |     elif name == "base.en":
78 |         dims_config = {'n_mels': 80, 'n_vocab': 51864, 'n_audio_ctx': 1500, 'n_audio_state': 512, 'n_audio_head': 8, 'n_audio_layer': 6, 'n_text_ctx': 448, 'n_text_state': 512, 'n_text_head': 8, 'n_text_layer': 6}
79 |     elif name == "small":
80 |         dims_config = {'n_mels': 80, 'n_vocab': 51865, 'n_audio_ctx': 1500, 'n_audio_state': 768, 'n_audio_head': 12, 'n_audio_layer': 12, 'n_text_ctx': 448, 'n_text_state': 768, 'n_text_head': 12, 'n_text_layer': 12}
81 |     elif name == "small.en":
82 |         dims_config = {'n_mels': 80, 'n_vocab': 51864, 'n_audio_ctx': 1500, 'n_audio_state': 768, 'n_audio_head': 12, 'n_audio_layer': 12, 'n_text_ctx': 448, 'n_text_state': 768, 'n_text_head': 12, 'n_text_layer': 12}
83 |     elif name == "medium":
84 |         dims_config = {'n_mels': 80, 'n_vocab': 51865, 'n_audio_ctx': 1500, 'n_audio_state': 1024, 'n_audio_head': 16, 'n_audio_layer': 24, 'n_text_ctx': 448, 'n_text_state': 1024, 'n_text_head': 16, 'n_text_layer': 24}
85 |     elif name == "medium.en":
86 |         dims_config = {'n_mels': 80, 'n_vocab': 51864, 'n_audio_ctx': 1500, 'n_audio_state': 1024, 'n_audio_head': 16, 'n_audio_layer': 24, 'n_text_ctx': 448, 'n_text_state': 1024, 'n_text_head': 16, 'n_text_layer': 24}
87 |     elif name == "large-v1":
88 |         dims_config = {'n_mels': 80, 'n_vocab': 51865, 'n_audio_ctx': 1500, 'n_audio_state': 1280, 'n_audio_head': 20, 'n_audio_layer': 32, 'n_text_ctx': 448, 'n_text_state': 1280, 'n_text_head': 20, 'n_text_layer': 32}
89 |     elif name == "large-v2":
90 |         dims_config = {'n_mels': 80, 'n_vocab': 51865, 'n_audio_ctx': 1500, 'n_audio_state': 1280, 'n_audio_head': 20, 'n_audio_layer': 32, 'n_text_ctx': 448, 'n_text_state': 1280, 'n_text_head': 20, 'n_text_layer': 32}
91 |     else:
92 |         raise ValueError(f"model type {name} not supported")
93 | 
94 |     dims = ModelDimensions(**dims_config)
95 |     model = Whisper(dims=dims, model_name=name)
96 |     return model
97 | 
98 | def available_models() -> List[str]:
99 |     """Returns the names of available models"""
100 |     return list(_MODELS.keys())
101 | 
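A short sketch of the loader above; note that the ONNX files themselves are fetched from the model zoo by model_download() on first use (when the encoder/decoder sessions are built) and cached in the current directory:

from whisper.model import available_models, load_model

print(available_models())        # ['tiny.en', 'tiny', 'base.en', ..., 'large-v2']
model = load_model("base.en")    # instantiates the ONNX encoder/decoder sessions defined below
print(model.dims.n_text_state)   # 512 for the base-size models
print(model.is_multilingual)     # False: the English-only vocabulary has 51864 tokens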
102 | @dataclass
103 | class ModelDimensions:
104 |     n_mels: int
105 |     n_audio_ctx: int
106 |     n_audio_state: int
107 |     n_audio_head: int
108 |     n_audio_layer: int
109 |     n_vocab: int
110 |     n_text_ctx: int
111 |     n_text_state: int
112 |     n_text_head: int
113 |     n_text_layer: int
114 | 
115 | 
116 | class OnnxAudioEncoder():
117 |     def __init__(
118 |         self,
119 |         model: str,
120 |     ):
121 |         super().__init__()
122 | 
123 |         self.sess = \
124 |             ort.InferenceSession(
125 |                 path_or_bytes=model_download(name=f'{model}_encoder'),
126 |                 providers=[
127 |                     # (
128 |                     #     'TensorrtExecutionProvider', {
129 |                     #         'trt_engine_cache_enable': True,
130 |                     #         'trt_engine_cache_path': '.',
131 |                     #         'trt_fp16_enable': True,
132 |                     #     }
133 |                     # ),
134 |                     'CUDAExecutionProvider'
135 |                 ],
136 |             )
137 |         self.inputs = {
138 |             input.name: onnx_dtype_to_np_dtype_convert(input.type) \
139 |                 for input in self.sess.get_inputs()
140 |         }
141 | 
142 |     def __call__(
143 |         self,
144 |         mel: np.ndarray
145 |     ) -> np.ndarray:
146 |         result: np.ndarray = \
147 |             self.sess.run(
148 |                 output_names=[
149 |                     "output",
150 |                 ],
151 |                 input_feed={
152 |                     "mel": mel.astype(self.inputs["mel"]),
153 |                 }
154 |             )[0]
155 |         return result
156 | 
157 | 
158 | class OnnxTextDecoder():
159 |     def __init__(
160 |         self,
161 |         model: str,
162 |     ):
163 |         super().__init__()
164 | 
165 |         self.sess = \
166 |             ort.InferenceSession(
167 |                 path_or_bytes=model_download(name=f'{model}_decoder'),
168 |                 providers=[
169 |                     'CUDAExecutionProvider'
170 |                 ],
171 |             )
172 |         self.inputs = {
173 |             input.name: onnx_dtype_to_np_dtype_convert(input.type) \
174 |                 for input in self.sess.get_inputs()
175 |         }
176 | 
177 |     def __call__(
178 |         self,
179 |         x: np.ndarray,
180 |         xa: np.ndarray,
181 |         kv_cache: np.ndarray,
182 |         offset: int,
183 |     ) -> Tuple[np.ndarray, np.ndarray]:
184 |         outputs = \
185 |             self.sess.run(
186 |                 output_names=[
187 |                     "logits",
188 |                     "output_kv_cache",
189 |                     "cross_attention_qks",
190 |                 ],
191 |                 input_feed={
192 |                     "tokens": x.astype(self.inputs["tokens"]),
193 |                     "audio_features": xa.astype(self.inputs["audio_features"]),
194 |                     "kv_cache": kv_cache.astype(self.inputs["kv_cache"]),
195 |                     "offset": np.array([offset], dtype=self.inputs["offset"]),
196 |                 }
197 |             )
198 |         logits: np.ndarray = outputs[0]
199 |         output_kv_cache: np.ndarray = outputs[1]
200 |         cross_attention_qks: np.ndarray = outputs[2]  # fetched but unused by the decoding loop
201 |         return logits.astype(np.float32), output_kv_cache.astype(np.float32)
202 | 
203 | 
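The decoder session's I/O contract is easiest to see with concrete shapes. A hedged sketch for the tiny model: the dummy arrays mirror new_kv_cache() in the Whisper class below, whose leading dimension appears to hold two entries (K and V) per text layer, and the session calls are commented out because they download the model and require a CUDA-enabled onnxruntime:

import numpy as np

n_group, n_tokens = 1, 3
tokens = np.zeros((n_group, n_tokens), dtype=np.int64)              # decoded token IDs so far
audio_features = np.zeros((n_group, 1500, 384), dtype=np.float16)   # tiny: n_audio_ctx x n_audio_state
kv_cache = np.zeros((8, n_group, n_tokens, 384), dtype=np.float16)  # tiny: 4 text layers x (K, V)
# decoder = OnnxTextDecoder(model="tiny")
# logits, new_kv = decoder(tokens, audio_features, kv_cache=kv_cache, offset=0)
# logits.shape -> (n_group, n_tokens, 51865); new_kv seeds the next incremental call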
204 | class Whisper():
205 |     def __init__(
206 |         self,
207 |         dims: ModelDimensions,
208 |         model_name: str,
209 |     ):
210 |         super().__init__()
211 |         self.model_name = model_name
212 |         self.dims = dims
213 |         self.encoder = OnnxAudioEncoder(model=model_name)
214 |         self.decoder = OnnxTextDecoder(model=model_name)
215 | 
216 |     def embed_audio(
217 |         self,
218 |         mel: np.ndarray,
219 |     ):
220 |         return self.encoder(mel)
221 | 
222 |     def logits(
223 |         self,
224 |         tokens: np.ndarray,
225 |         audio_features: np.ndarray,
226 |     ):
227 |         kv_cache = self.new_kv_cache(tokens.shape[0], tokens.shape[-1])
228 |         output, _ = self.decoder(tokens, audio_features, kv_cache=kv_cache, offset=0)
229 |         return output
230 | 
231 |     def __call__(
232 |         self,
233 |         mel: np.ndarray,
234 |         tokens: np.ndarray,
235 |     ) -> np.ndarray:
236 |         kv_cache = self.new_kv_cache(tokens.shape[0], tokens.shape[-1])
237 |         output, _ = self.decoder(tokens, self.encoder(mel), kv_cache=kv_cache, offset=0)
238 |         return output
239 | 
240 |     @property
241 |     def is_multilingual(self):
242 |         return self.dims.n_vocab == 51865
243 | 
244 |     def new_kv_cache(
245 |         self,
246 |         n_group: int,
247 |         length: int,
248 |     ):
249 |         if self.model_name == "tiny.en" or self.model_name == "tiny":
250 |             size = [8, n_group, length, 384]
251 |         elif self.model_name == "base.en" or self.model_name == "base":
252 |             size = [12, n_group, length, 512]
253 |         elif self.model_name == "small.en" or self.model_name == "small":
254 |             size = [24, n_group, length, 768]
255 |         elif self.model_name == "medium.en" or self.model_name == "medium":
256 |             size = [48, n_group, length, 1024]
257 |         elif self.model_name == "large-v1" or self.model_name == "large-v2":
258 |             size = [64, n_group, length, 1280]
259 |         else:
260 |             raise ValueError(f"Unsupported model type: {self.model_name}")
261 |         return np.zeros(size, dtype=np.float16)
262 | 
263 |     detect_language = detect_language_function
264 |     # transcribe = transcribe_function
265 |     decode = decode_function
266 | 
--------------------------------------------------------------------------------
/whisper/normalizers/__init__.py:
--------------------------------------------------------------------------------
1 | from whisper.normalizers.basic import BasicTextNormalizer
2 | from whisper.normalizers.english import EnglishTextNormalizer
--------------------------------------------------------------------------------
/whisper/normalizers/basic.py:
--------------------------------------------------------------------------------
1 | import re
2 | import unicodedata
3 | 
4 | import regex
5 | 
6 | # non-ASCII letters that are not separated by "NFKD" normalization
7 | ADDITIONAL_DIACRITICS = {
8 |     "œ": "oe",
9 |     "Œ": "OE",
10 |     "ø": "o",
11 |     "Ø": "O",
12 |     "æ": "ae",
13 |     "Æ": "AE",
14 |     "ß": "ss",
15 |     "ẞ": "SS",
16 |     "đ": "d",
17 |     "Đ": "D",
18 |     "ð": "d",
19 |     "Ð": "D",
20 |     "þ": "th",
21 |     "Þ": "th",
22 |     "ł": "l",
23 |     "Ł": "L",
24 | }
25 | 
26 | 
27 | def remove_symbols_and_diacritics(s: str, keep=""):
28 |     """
29 |     Replace any other markers, symbols, and punctuations with a space,
30 |     and drop any diacritics (category 'Mn' and some manual mappings)
31 |     """
32 |     return "".join(
33 |         c
34 |         if c in keep
35 |         else ADDITIONAL_DIACRITICS[c]
36 |         if c in ADDITIONAL_DIACRITICS
37 |         else ""
38 |         if unicodedata.category(c) == "Mn"
39 |         else " "
40 |         if unicodedata.category(c)[0] in "MSP"
41 |         else c
42 |         for c in unicodedata.normalize("NFKD", s)
43 |     )
44 | 
45 | 
46 | def remove_symbols(s: str):
47 |     """
48 |     Replace any other markers, symbols, punctuations with a space, keeping diacritics
49 |     """
50 |     return "".join(
51 |         " " if unicodedata.category(c)[0] in "MSP" else c for c in unicodedata.normalize("NFKC", s)
52 |     )
53 | 
54 | 
55 | class BasicTextNormalizer:
56 |     def __init__(self, remove_diacritics: bool = False, split_letters: bool = False):
57 |         self.clean = remove_symbols_and_diacritics if remove_diacritics else remove_symbols
58 |         self.split_letters = split_letters
59 | 
60 |     def __call__(self, s: str):
61 |         s = s.lower()
62 |         s = re.sub(r"[<\[][^>\]]*[>\]]", "", s)  # remove words between brackets
63 |         s = re.sub(r"\(([^)]+?)\)", "", s)  # remove words between parentheses
64 |         s = self.clean(s).lower()
65 | 
66 |         if self.split_letters:
67 |             s = " ".join(regex.findall(r"\X", s, regex.U))
68 | 
69 |         s = re.sub(r"\s+", " ", s)  # replace any successive whitespace characters with a space
70 | 
71 |         return s
72 | 
--------------------------------------------------------------------------------
/whisper/normalizers/english.json:
--------------------------------------------------------------------------------
1 | { 2 | "accessorise": "accessorize", 3 | "accessorised": "accessorized", 4 | "accessorises": "accessorizes", 5 | 
"accessorising": "accessorizing", 6 | "acclimatisation": "acclimatization", 7 | "acclimatise": "acclimatize", 8 | "acclimatised": "acclimatized", 9 | "acclimatises": "acclimatizes", 10 | "acclimatising": "acclimatizing", 11 | "accoutrements": "accouterments", 12 | "aeon": "eon", 13 | "aeons": "eons", 14 | "aerogramme": "aerogram", 15 | "aerogrammes": "aerograms", 16 | "aeroplane": "airplane", 17 | "aeroplanes": "airplanes", 18 | "aesthete": "esthete", 19 | "aesthetes": "esthetes", 20 | "aesthetic": "esthetic", 21 | "aesthetically": "esthetically", 22 | "aesthetics": "esthetics", 23 | "aetiology": "etiology", 24 | "ageing": "aging", 25 | "aggrandisement": "aggrandizement", 26 | "agonise": "agonize", 27 | "agonised": "agonized", 28 | "agonises": "agonizes", 29 | "agonising": "agonizing", 30 | "agonisingly": "agonizingly", 31 | "almanack": "almanac", 32 | "almanacks": "almanacs", 33 | "aluminium": "aluminum", 34 | "amortisable": "amortizable", 35 | "amortisation": "amortization", 36 | "amortisations": "amortizations", 37 | "amortise": "amortize", 38 | "amortised": "amortized", 39 | "amortises": "amortizes", 40 | "amortising": "amortizing", 41 | "amphitheatre": "amphitheater", 42 | "amphitheatres": "amphitheaters", 43 | "anaemia": "anemia", 44 | "anaemic": "anemic", 45 | "anaesthesia": "anesthesia", 46 | "anaesthetic": "anesthetic", 47 | "anaesthetics": "anesthetics", 48 | "anaesthetise": "anesthetize", 49 | "anaesthetised": "anesthetized", 50 | "anaesthetises": "anesthetizes", 51 | "anaesthetising": "anesthetizing", 52 | "anaesthetist": "anesthetist", 53 | "anaesthetists": "anesthetists", 54 | "anaesthetize": "anesthetize", 55 | "anaesthetized": "anesthetized", 56 | "anaesthetizes": "anesthetizes", 57 | "anaesthetizing": "anesthetizing", 58 | "analogue": "analog", 59 | "analogues": "analogs", 60 | "analyse": "analyze", 61 | "analysed": "analyzed", 62 | "analyses": "analyzes", 63 | "analysing": "analyzing", 64 | "anglicise": "anglicize", 65 | "anglicised": "anglicized", 66 | "anglicises": "anglicizes", 67 | "anglicising": "anglicizing", 68 | "annualised": "annualized", 69 | "antagonise": "antagonize", 70 | "antagonised": "antagonized", 71 | "antagonises": "antagonizes", 72 | "antagonising": "antagonizing", 73 | "apologise": "apologize", 74 | "apologised": "apologized", 75 | "apologises": "apologizes", 76 | "apologising": "apologizing", 77 | "appal": "appall", 78 | "appals": "appalls", 79 | "appetiser": "appetizer", 80 | "appetisers": "appetizers", 81 | "appetising": "appetizing", 82 | "appetisingly": "appetizingly", 83 | "arbour": "arbor", 84 | "arbours": "arbors", 85 | "archeological": "archaeological", 86 | "archaeologically": "archeologically", 87 | "archaeologist": "archeologist", 88 | "archaeologists": "archeologists", 89 | "archaeology": "archeology", 90 | "ardour": "ardor", 91 | "armour": "armor", 92 | "armoured": "armored", 93 | "armourer": "armorer", 94 | "armourers": "armorers", 95 | "armouries": "armories", 96 | "armoury": "armory", 97 | "artefact": "artifact", 98 | "artefacts": "artifacts", 99 | "authorise": "authorize", 100 | "authorised": "authorized", 101 | "authorises": "authorizes", 102 | "authorising": "authorizing", 103 | "axe": "ax", 104 | "backpedalled": "backpedaled", 105 | "backpedalling": "backpedaling", 106 | "bannister": "banister", 107 | "bannisters": "banisters", 108 | "baptise": "baptize", 109 | "baptised": "baptized", 110 | "baptises": "baptizes", 111 | "baptising": "baptizing", 112 | "bastardise": "bastardize", 113 | "bastardised": "bastardized", 114 | 
"bastardises": "bastardizes", 115 | "bastardising": "bastardizing", 116 | "battleax": "battleaxe", 117 | "baulk": "balk", 118 | "baulked": "balked", 119 | "baulking": "balking", 120 | "baulks": "balks", 121 | "bedevilled": "bedeviled", 122 | "bedevilling": "bedeviling", 123 | "behaviour": "behavior", 124 | "behavioural": "behavioral", 125 | "behaviourism": "behaviorism", 126 | "behaviourist": "behaviorist", 127 | "behaviourists": "behaviorists", 128 | "behaviours": "behaviors", 129 | "behove": "behoove", 130 | "behoved": "behooved", 131 | "behoves": "behooves", 132 | "bejewelled": "bejeweled", 133 | "belabour": "belabor", 134 | "belaboured": "belabored", 135 | "belabouring": "belaboring", 136 | "belabours": "belabors", 137 | "bevelled": "beveled", 138 | "bevvies": "bevies", 139 | "bevvy": "bevy", 140 | "biassed": "biased", 141 | "biassing": "biasing", 142 | "bingeing": "binging", 143 | "bougainvillaea": "bougainvillea", 144 | "bougainvillaeas": "bougainvilleas", 145 | "bowdlerise": "bowdlerize", 146 | "bowdlerised": "bowdlerized", 147 | "bowdlerises": "bowdlerizes", 148 | "bowdlerising": "bowdlerizing", 149 | "breathalyse": "breathalyze", 150 | "breathalysed": "breathalyzed", 151 | "breathalyser": "breathalyzer", 152 | "breathalysers": "breathalyzers", 153 | "breathalyses": "breathalyzes", 154 | "breathalysing": "breathalyzing", 155 | "brutalise": "brutalize", 156 | "brutalised": "brutalized", 157 | "brutalises": "brutalizes", 158 | "brutalising": "brutalizing", 159 | "busses": "buses", 160 | "bussing": "busing", 161 | "caesarean": "cesarean", 162 | "caesareans": "cesareans", 163 | "calibre": "caliber", 164 | "calibres": "calibers", 165 | "calliper": "caliper", 166 | "callipers": "calipers", 167 | "callisthenics": "calisthenics", 168 | "canalise": "canalize", 169 | "canalised": "canalized", 170 | "canalises": "canalizes", 171 | "canalising": "canalizing", 172 | "cancelation": "cancellation", 173 | "cancelations": "cancellations", 174 | "cancelled": "canceled", 175 | "cancelling": "canceling", 176 | "candour": "candor", 177 | "cannibalise": "cannibalize", 178 | "cannibalised": "cannibalized", 179 | "cannibalises": "cannibalizes", 180 | "cannibalising": "cannibalizing", 181 | "canonise": "canonize", 182 | "canonised": "canonized", 183 | "canonises": "canonizes", 184 | "canonising": "canonizing", 185 | "capitalise": "capitalize", 186 | "capitalised": "capitalized", 187 | "capitalises": "capitalizes", 188 | "capitalising": "capitalizing", 189 | "caramelise": "caramelize", 190 | "caramelised": "caramelized", 191 | "caramelises": "caramelizes", 192 | "caramelising": "caramelizing", 193 | "carbonise": "carbonize", 194 | "carbonised": "carbonized", 195 | "carbonises": "carbonizes", 196 | "carbonising": "carbonizing", 197 | "carolled": "caroled", 198 | "carolling": "caroling", 199 | "catalogue": "catalog", 200 | "catalogued": "cataloged", 201 | "catalogues": "catalogs", 202 | "cataloguing": "cataloging", 203 | "catalyse": "catalyze", 204 | "catalysed": "catalyzed", 205 | "catalyses": "catalyzes", 206 | "catalysing": "catalyzing", 207 | "categorise": "categorize", 208 | "categorised": "categorized", 209 | "categorises": "categorizes", 210 | "categorising": "categorizing", 211 | "cauterise": "cauterize", 212 | "cauterised": "cauterized", 213 | "cauterises": "cauterizes", 214 | "cauterising": "cauterizing", 215 | "cavilled": "caviled", 216 | "cavilling": "caviling", 217 | "centigramme": "centigram", 218 | "centigrammes": "centigrams", 219 | "centilitre": "centiliter", 220 | "centilitres": 
"centiliters", 221 | "centimetre": "centimeter", 222 | "centimetres": "centimeters", 223 | "centralise": "centralize", 224 | "centralised": "centralized", 225 | "centralises": "centralizes", 226 | "centralising": "centralizing", 227 | "centre": "center", 228 | "centred": "centered", 229 | "centrefold": "centerfold", 230 | "centrefolds": "centerfolds", 231 | "centrepiece": "centerpiece", 232 | "centrepieces": "centerpieces", 233 | "centres": "centers", 234 | "channelled": "channeled", 235 | "channelling": "channeling", 236 | "characterise": "characterize", 237 | "characterised": "characterized", 238 | "characterises": "characterizes", 239 | "characterising": "characterizing", 240 | "cheque": "check", 241 | "chequebook": "checkbook", 242 | "chequebooks": "checkbooks", 243 | "chequered": "checkered", 244 | "cheques": "checks", 245 | "chilli": "chili", 246 | "chimaera": "chimera", 247 | "chimaeras": "chimeras", 248 | "chiselled": "chiseled", 249 | "chiselling": "chiseling", 250 | "circularise": "circularize", 251 | "circularised": "circularized", 252 | "circularises": "circularizes", 253 | "circularising": "circularizing", 254 | "civilise": "civilize", 255 | "civilised": "civilized", 256 | "civilises": "civilizes", 257 | "civilising": "civilizing", 258 | "clamour": "clamor", 259 | "clamoured": "clamored", 260 | "clamouring": "clamoring", 261 | "clamours": "clamors", 262 | "clangour": "clangor", 263 | "clarinettist": "clarinetist", 264 | "clarinettists": "clarinetists", 265 | "collectivise": "collectivize", 266 | "collectivised": "collectivized", 267 | "collectivises": "collectivizes", 268 | "collectivising": "collectivizing", 269 | "colonisation": "colonization", 270 | "colonise": "colonize", 271 | "colonised": "colonized", 272 | "coloniser": "colonizer", 273 | "colonisers": "colonizers", 274 | "colonises": "colonizes", 275 | "colonising": "colonizing", 276 | "colour": "color", 277 | "colourant": "colorant", 278 | "colourants": "colorants", 279 | "coloured": "colored", 280 | "coloureds": "coloreds", 281 | "colourful": "colorful", 282 | "colourfully": "colorfully", 283 | "colouring": "coloring", 284 | "colourize": "colorize", 285 | "colourized": "colorized", 286 | "colourizes": "colorizes", 287 | "colourizing": "colorizing", 288 | "colourless": "colorless", 289 | "colours": "colors", 290 | "commercialise": "commercialize", 291 | "commercialised": "commercialized", 292 | "commercialises": "commercializes", 293 | "commercialising": "commercializing", 294 | "compartmentalise": "compartmentalize", 295 | "compartmentalised": "compartmentalized", 296 | "compartmentalises": "compartmentalizes", 297 | "compartmentalising": "compartmentalizing", 298 | "computerise": "computerize", 299 | "computerised": "computerized", 300 | "computerises": "computerizes", 301 | "computerising": "computerizing", 302 | "conceptualise": "conceptualize", 303 | "conceptualised": "conceptualized", 304 | "conceptualises": "conceptualizes", 305 | "conceptualising": "conceptualizing", 306 | "connexion": "connection", 307 | "connexions": "connections", 308 | "contextualise": "contextualize", 309 | "contextualised": "contextualized", 310 | "contextualises": "contextualizes", 311 | "contextualising": "contextualizing", 312 | "cosier": "cozier", 313 | "cosies": "cozies", 314 | "cosiest": "coziest", 315 | "cosily": "cozily", 316 | "cosiness": "coziness", 317 | "cosy": "cozy", 318 | "councillor": "councilor", 319 | "councillors": "councilors", 320 | "counselled": "counseled", 321 | "counselling": "counseling", 322 | "counsellor": 
"counselor", 323 | "counsellors": "counselors", 324 | "crenelated": "crenellated", 325 | "criminalise": "criminalize", 326 | "criminalised": "criminalized", 327 | "criminalises": "criminalizes", 328 | "criminalising": "criminalizing", 329 | "criticise": "criticize", 330 | "criticised": "criticized", 331 | "criticises": "criticizes", 332 | "criticising": "criticizing", 333 | "crueller": "crueler", 334 | "cruellest": "cruelest", 335 | "crystallisation": "crystallization", 336 | "crystallise": "crystallize", 337 | "crystallised": "crystallized", 338 | "crystallises": "crystallizes", 339 | "crystallising": "crystallizing", 340 | "cudgelled": "cudgeled", 341 | "cudgelling": "cudgeling", 342 | "customise": "customize", 343 | "customised": "customized", 344 | "customises": "customizes", 345 | "customising": "customizing", 346 | "cypher": "cipher", 347 | "cyphers": "ciphers", 348 | "decentralisation": "decentralization", 349 | "decentralise": "decentralize", 350 | "decentralised": "decentralized", 351 | "decentralises": "decentralizes", 352 | "decentralising": "decentralizing", 353 | "decriminalisation": "decriminalization", 354 | "decriminalise": "decriminalize", 355 | "decriminalised": "decriminalized", 356 | "decriminalises": "decriminalizes", 357 | "decriminalising": "decriminalizing", 358 | "defence": "defense", 359 | "defenceless": "defenseless", 360 | "defences": "defenses", 361 | "dehumanisation": "dehumanization", 362 | "dehumanise": "dehumanize", 363 | "dehumanised": "dehumanized", 364 | "dehumanises": "dehumanizes", 365 | "dehumanising": "dehumanizing", 366 | "demeanour": "demeanor", 367 | "demilitarisation": "demilitarization", 368 | "demilitarise": "demilitarize", 369 | "demilitarised": "demilitarized", 370 | "demilitarises": "demilitarizes", 371 | "demilitarising": "demilitarizing", 372 | "demobilisation": "demobilization", 373 | "demobilise": "demobilize", 374 | "demobilised": "demobilized", 375 | "demobilises": "demobilizes", 376 | "demobilising": "demobilizing", 377 | "democratisation": "democratization", 378 | "democratise": "democratize", 379 | "democratised": "democratized", 380 | "democratises": "democratizes", 381 | "democratising": "democratizing", 382 | "demonise": "demonize", 383 | "demonised": "demonized", 384 | "demonises": "demonizes", 385 | "demonising": "demonizing", 386 | "demoralisation": "demoralization", 387 | "demoralise": "demoralize", 388 | "demoralised": "demoralized", 389 | "demoralises": "demoralizes", 390 | "demoralising": "demoralizing", 391 | "denationalisation": "denationalization", 392 | "denationalise": "denationalize", 393 | "denationalised": "denationalized", 394 | "denationalises": "denationalizes", 395 | "denationalising": "denationalizing", 396 | "deodorise": "deodorize", 397 | "deodorised": "deodorized", 398 | "deodorises": "deodorizes", 399 | "deodorising": "deodorizing", 400 | "depersonalise": "depersonalize", 401 | "depersonalised": "depersonalized", 402 | "depersonalises": "depersonalizes", 403 | "depersonalising": "depersonalizing", 404 | "deputise": "deputize", 405 | "deputised": "deputized", 406 | "deputises": "deputizes", 407 | "deputising": "deputizing", 408 | "desensitisation": "desensitization", 409 | "desensitise": "desensitize", 410 | "desensitised": "desensitized", 411 | "desensitises": "desensitizes", 412 | "desensitising": "desensitizing", 413 | "destabilisation": "destabilization", 414 | "destabilise": "destabilize", 415 | "destabilised": "destabilized", 416 | "destabilises": "destabilizes", 417 | "destabilising": 
"destabilizing", 418 | "dialled": "dialed", 419 | "dialling": "dialing", 420 | "dialogue": "dialog", 421 | "dialogues": "dialogs", 422 | "diarrhoea": "diarrhea", 423 | "digitise": "digitize", 424 | "digitised": "digitized", 425 | "digitises": "digitizes", 426 | "digitising": "digitizing", 427 | "disc": "disk", 428 | "discolour": "discolor", 429 | "discoloured": "discolored", 430 | "discolouring": "discoloring", 431 | "discolours": "discolors", 432 | "discs": "disks", 433 | "disembowelled": "disemboweled", 434 | "disembowelling": "disemboweling", 435 | "disfavour": "disfavor", 436 | "dishevelled": "disheveled", 437 | "dishonour": "dishonor", 438 | "dishonourable": "dishonorable", 439 | "dishonourably": "dishonorably", 440 | "dishonoured": "dishonored", 441 | "dishonouring": "dishonoring", 442 | "dishonours": "dishonors", 443 | "disorganisation": "disorganization", 444 | "disorganised": "disorganized", 445 | "distil": "distill", 446 | "distils": "distills", 447 | "dramatisation": "dramatization", 448 | "dramatisations": "dramatizations", 449 | "dramatise": "dramatize", 450 | "dramatised": "dramatized", 451 | "dramatises": "dramatizes", 452 | "dramatising": "dramatizing", 453 | "draught": "draft", 454 | "draughtboard": "draftboard", 455 | "draughtboards": "draftboards", 456 | "draughtier": "draftier", 457 | "draughtiest": "draftiest", 458 | "draughts": "drafts", 459 | "draughtsman": "draftsman", 460 | "draughtsmanship": "draftsmanship", 461 | "draughtsmen": "draftsmen", 462 | "draughtswoman": "draftswoman", 463 | "draughtswomen": "draftswomen", 464 | "draughty": "drafty", 465 | "drivelled": "driveled", 466 | "drivelling": "driveling", 467 | "duelled": "dueled", 468 | "duelling": "dueling", 469 | "economise": "economize", 470 | "economised": "economized", 471 | "economises": "economizes", 472 | "economising": "economizing", 473 | "edoema": "edema", 474 | "editorialise": "editorialize", 475 | "editorialised": "editorialized", 476 | "editorialises": "editorializes", 477 | "editorialising": "editorializing", 478 | "empathise": "empathize", 479 | "empathised": "empathized", 480 | "empathises": "empathizes", 481 | "empathising": "empathizing", 482 | "emphasise": "emphasize", 483 | "emphasised": "emphasized", 484 | "emphasises": "emphasizes", 485 | "emphasising": "emphasizing", 486 | "enamelled": "enameled", 487 | "enamelling": "enameling", 488 | "enamoured": "enamored", 489 | "encyclopaedia": "encyclopedia", 490 | "encyclopaedias": "encyclopedias", 491 | "encyclopaedic": "encyclopedic", 492 | "endeavour": "endeavor", 493 | "endeavoured": "endeavored", 494 | "endeavouring": "endeavoring", 495 | "endeavours": "endeavors", 496 | "energise": "energize", 497 | "energised": "energized", 498 | "energises": "energizes", 499 | "energising": "energizing", 500 | "enrol": "enroll", 501 | "enrols": "enrolls", 502 | "enthral": "enthrall", 503 | "enthrals": "enthralls", 504 | "epaulette": "epaulet", 505 | "epaulettes": "epaulets", 506 | "epicentre": "epicenter", 507 | "epicentres": "epicenters", 508 | "epilogue": "epilog", 509 | "epilogues": "epilogs", 510 | "epitomise": "epitomize", 511 | "epitomised": "epitomized", 512 | "epitomises": "epitomizes", 513 | "epitomising": "epitomizing", 514 | "equalisation": "equalization", 515 | "equalise": "equalize", 516 | "equalised": "equalized", 517 | "equaliser": "equalizer", 518 | "equalisers": "equalizers", 519 | "equalises": "equalizes", 520 | "equalising": "equalizing", 521 | "eulogise": "eulogize", 522 | "eulogised": "eulogized", 523 | "eulogises": "eulogizes", 524 | 
"eulogising": "eulogizing", 525 | "evangelise": "evangelize", 526 | "evangelised": "evangelized", 527 | "evangelises": "evangelizes", 528 | "evangelising": "evangelizing", 529 | "exorcise": "exorcize", 530 | "exorcised": "exorcized", 531 | "exorcises": "exorcizes", 532 | "exorcising": "exorcizing", 533 | "extemporisation": "extemporization", 534 | "extemporise": "extemporize", 535 | "extemporised": "extemporized", 536 | "extemporises": "extemporizes", 537 | "extemporising": "extemporizing", 538 | "externalisation": "externalization", 539 | "externalisations": "externalizations", 540 | "externalise": "externalize", 541 | "externalised": "externalized", 542 | "externalises": "externalizes", 543 | "externalising": "externalizing", 544 | "factorise": "factorize", 545 | "factorised": "factorized", 546 | "factorises": "factorizes", 547 | "factorising": "factorizing", 548 | "faecal": "fecal", 549 | "faeces": "feces", 550 | "familiarisation": "familiarization", 551 | "familiarise": "familiarize", 552 | "familiarised": "familiarized", 553 | "familiarises": "familiarizes", 554 | "familiarising": "familiarizing", 555 | "fantasise": "fantasize", 556 | "fantasised": "fantasized", 557 | "fantasises": "fantasizes", 558 | "fantasising": "fantasizing", 559 | "favour": "favor", 560 | "favourable": "favorable", 561 | "favourably": "favorably", 562 | "favoured": "favored", 563 | "favouring": "favoring", 564 | "favourite": "favorite", 565 | "favourites": "favorites", 566 | "favouritism": "favoritism", 567 | "favours": "favors", 568 | "feminise": "feminize", 569 | "feminised": "feminized", 570 | "feminises": "feminizes", 571 | "feminising": "feminizing", 572 | "fertilisation": "fertilization", 573 | "fertilise": "fertilize", 574 | "fertilised": "fertilized", 575 | "fertiliser": "fertilizer", 576 | "fertilisers": "fertilizers", 577 | "fertilises": "fertilizes", 578 | "fertilising": "fertilizing", 579 | "fervour": "fervor", 580 | "fibre": "fiber", 581 | "fibreglass": "fiberglass", 582 | "fibres": "fibers", 583 | "fictionalisation": "fictionalization", 584 | "fictionalisations": "fictionalizations", 585 | "fictionalise": "fictionalize", 586 | "fictionalised": "fictionalized", 587 | "fictionalises": "fictionalizes", 588 | "fictionalising": "fictionalizing", 589 | "fillet": "filet", 590 | "filleted": "fileted", 591 | "filleting": "fileting", 592 | "fillets": "filets", 593 | "finalisation": "finalization", 594 | "finalise": "finalize", 595 | "finalised": "finalized", 596 | "finalises": "finalizes", 597 | "finalising": "finalizing", 598 | "flautist": "flutist", 599 | "flautists": "flutists", 600 | "flavour": "flavor", 601 | "flavoured": "flavored", 602 | "flavouring": "flavoring", 603 | "flavourings": "flavorings", 604 | "flavourless": "flavorless", 605 | "flavours": "flavors", 606 | "flavoursome": "flavorsome", 607 | "flyer / flier": "flier / flyer", 608 | "foetal": "fetal", 609 | "foetid": "fetid", 610 | "foetus": "fetus", 611 | "foetuses": "fetuses", 612 | "formalisation": "formalization", 613 | "formalise": "formalize", 614 | "formalised": "formalized", 615 | "formalises": "formalizes", 616 | "formalising": "formalizing", 617 | "fossilisation": "fossilization", 618 | "fossilise": "fossilize", 619 | "fossilised": "fossilized", 620 | "fossilises": "fossilizes", 621 | "fossilising": "fossilizing", 622 | "fraternisation": "fraternization", 623 | "fraternise": "fraternize", 624 | "fraternised": "fraternized", 625 | "fraternises": "fraternizes", 626 | "fraternising": "fraternizing", 627 | "fulfil": "fulfill", 628 | 
"fulfilment": "fulfillment", 629 | "fulfils": "fulfills", 630 | "funnelled": "funneled", 631 | "funnelling": "funneling", 632 | "galvanise": "galvanize", 633 | "galvanised": "galvanized", 634 | "galvanises": "galvanizes", 635 | "galvanising": "galvanizing", 636 | "gambolled": "gamboled", 637 | "gambolling": "gamboling", 638 | "gaol": "jail", 639 | "gaolbird": "jailbird", 640 | "gaolbirds": "jailbirds", 641 | "gaolbreak": "jailbreak", 642 | "gaolbreaks": "jailbreaks", 643 | "gaoled": "jailed", 644 | "gaoler": "jailer", 645 | "gaolers": "jailers", 646 | "gaoling": "jailing", 647 | "gaols": "jails", 648 | "gasses": "gases", 649 | "gage": "gauge", 650 | "gaged": "gauged", 651 | "gages": "gauges", 652 | "gaging": "gauging", 653 | "generalisation": "generalization", 654 | "generalisations": "generalizations", 655 | "generalise": "generalize", 656 | "generalised": "generalized", 657 | "generalises": "generalizes", 658 | "generalising": "generalizing", 659 | "ghettoise": "ghettoize", 660 | "ghettoised": "ghettoized", 661 | "ghettoises": "ghettoizes", 662 | "ghettoising": "ghettoizing", 663 | "gipsies": "gypsies", 664 | "glamorise": "glamorize", 665 | "glamorised": "glamorized", 666 | "glamorises": "glamorizes", 667 | "glamorising": "glamorizing", 668 | "glamor": "glamour", 669 | "globalisation": "globalization", 670 | "globalise": "globalize", 671 | "globalised": "globalized", 672 | "globalises": "globalizes", 673 | "globalising": "globalizing", 674 | "glueing": "gluing", 675 | "goitre": "goiter", 676 | "goitres": "goiters", 677 | "gonorrhoea": "gonorrhea", 678 | "gramme": "gram", 679 | "grammes": "grams", 680 | "gravelled": "graveled", 681 | "grey": "gray", 682 | "greyed": "grayed", 683 | "greying": "graying", 684 | "greyish": "grayish", 685 | "greyness": "grayness", 686 | "greys": "grays", 687 | "grovelled": "groveled", 688 | "grovelling": "groveling", 689 | "groyne": "groin", 690 | "groynes": "groins", 691 | "gruelling": "grueling", 692 | "gruellingly": "gruelingly", 693 | "gryphon": "griffin", 694 | "gryphons": "griffins", 695 | "gynaecological": "gynecological", 696 | "gynaecologist": "gynecologist", 697 | "gynaecologists": "gynecologists", 698 | "gynaecology": "gynecology", 699 | "haematological": "hematological", 700 | "haematologist": "hematologist", 701 | "haematologists": "hematologists", 702 | "haematology": "hematology", 703 | "haemoglobin": "hemoglobin", 704 | "haemophilia": "hemophilia", 705 | "haemophiliac": "hemophiliac", 706 | "haemophiliacs": "hemophiliacs", 707 | "haemorrhage": "hemorrhage", 708 | "haemorrhaged": "hemorrhaged", 709 | "haemorrhages": "hemorrhages", 710 | "haemorrhaging": "hemorrhaging", 711 | "haemorrhoids": "hemorrhoids", 712 | "harbour": "harbor", 713 | "harboured": "harbored", 714 | "harbouring": "harboring", 715 | "harbours": "harbors", 716 | "harmonisation": "harmonization", 717 | "harmonise": "harmonize", 718 | "harmonised": "harmonized", 719 | "harmonises": "harmonizes", 720 | "harmonising": "harmonizing", 721 | "homoeopath": "homeopath", 722 | "homoeopathic": "homeopathic", 723 | "homoeopaths": "homeopaths", 724 | "homoeopathy": "homeopathy", 725 | "homogenise": "homogenize", 726 | "homogenised": "homogenized", 727 | "homogenises": "homogenizes", 728 | "homogenising": "homogenizing", 729 | "honour": "honor", 730 | "honourable": "honorable", 731 | "honourably": "honorably", 732 | "honoured": "honored", 733 | "honouring": "honoring", 734 | "honours": "honors", 735 | "hospitalisation": "hospitalization", 736 | "hospitalise": "hospitalize", 737 | 
"hospitalised": "hospitalized", 738 | "hospitalises": "hospitalizes", 739 | "hospitalising": "hospitalizing", 740 | "humanise": "humanize", 741 | "humanised": "humanized", 742 | "humanises": "humanizes", 743 | "humanising": "humanizing", 744 | "humour": "humor", 745 | "humoured": "humored", 746 | "humouring": "humoring", 747 | "humourless": "humorless", 748 | "humours": "humors", 749 | "hybridise": "hybridize", 750 | "hybridised": "hybridized", 751 | "hybridises": "hybridizes", 752 | "hybridising": "hybridizing", 753 | "hypnotise": "hypnotize", 754 | "hypnotised": "hypnotized", 755 | "hypnotises": "hypnotizes", 756 | "hypnotising": "hypnotizing", 757 | "hypothesise": "hypothesize", 758 | "hypothesised": "hypothesized", 759 | "hypothesises": "hypothesizes", 760 | "hypothesising": "hypothesizing", 761 | "idealisation": "idealization", 762 | "idealise": "idealize", 763 | "idealised": "idealized", 764 | "idealises": "idealizes", 765 | "idealising": "idealizing", 766 | "idolise": "idolize", 767 | "idolised": "idolized", 768 | "idolises": "idolizes", 769 | "idolising": "idolizing", 770 | "immobilisation": "immobilization", 771 | "immobilise": "immobilize", 772 | "immobilised": "immobilized", 773 | "immobiliser": "immobilizer", 774 | "immobilisers": "immobilizers", 775 | "immobilises": "immobilizes", 776 | "immobilising": "immobilizing", 777 | "immortalise": "immortalize", 778 | "immortalised": "immortalized", 779 | "immortalises": "immortalizes", 780 | "immortalising": "immortalizing", 781 | "immunisation": "immunization", 782 | "immunise": "immunize", 783 | "immunised": "immunized", 784 | "immunises": "immunizes", 785 | "immunising": "immunizing", 786 | "impanelled": "impaneled", 787 | "impanelling": "impaneling", 788 | "imperilled": "imperiled", 789 | "imperilling": "imperiling", 790 | "individualise": "individualize", 791 | "individualised": "individualized", 792 | "individualises": "individualizes", 793 | "individualising": "individualizing", 794 | "industrialise": "industrialize", 795 | "industrialised": "industrialized", 796 | "industrialises": "industrializes", 797 | "industrialising": "industrializing", 798 | "inflexion": "inflection", 799 | "inflexions": "inflections", 800 | "initialise": "initialize", 801 | "initialised": "initialized", 802 | "initialises": "initializes", 803 | "initialising": "initializing", 804 | "initialled": "initialed", 805 | "initialling": "initialing", 806 | "instal": "install", 807 | "instalment": "installment", 808 | "instalments": "installments", 809 | "instals": "installs", 810 | "instil": "instill", 811 | "instils": "instills", 812 | "institutionalisation": "institutionalization", 813 | "institutionalise": "institutionalize", 814 | "institutionalised": "institutionalized", 815 | "institutionalises": "institutionalizes", 816 | "institutionalising": "institutionalizing", 817 | "intellectualise": "intellectualize", 818 | "intellectualised": "intellectualized", 819 | "intellectualises": "intellectualizes", 820 | "intellectualising": "intellectualizing", 821 | "internalisation": "internalization", 822 | "internalise": "internalize", 823 | "internalised": "internalized", 824 | "internalises": "internalizes", 825 | "internalising": "internalizing", 826 | "internationalisation": "internationalization", 827 | "internationalise": "internationalize", 828 | "internationalised": "internationalized", 829 | "internationalises": "internationalizes", 830 | "internationalising": "internationalizing", 831 | "ionisation": "ionization", 832 | "ionise": "ionize", 833 | 
"ionised": "ionized", 834 | "ioniser": "ionizer", 835 | "ionisers": "ionizers", 836 | "ionises": "ionizes", 837 | "ionising": "ionizing", 838 | "italicise": "italicize", 839 | "italicised": "italicized", 840 | "italicises": "italicizes", 841 | "italicising": "italicizing", 842 | "itemise": "itemize", 843 | "itemised": "itemized", 844 | "itemises": "itemizes", 845 | "itemising": "itemizing", 846 | "jeopardise": "jeopardize", 847 | "jeopardised": "jeopardized", 848 | "jeopardises": "jeopardizes", 849 | "jeopardising": "jeopardizing", 850 | "jewelled": "jeweled", 851 | "jeweller": "jeweler", 852 | "jewellers": "jewelers", 853 | "jewellery": "jewelry", 854 | "judgement": "judgment", 855 | "kilogramme": "kilogram", 856 | "kilogrammes": "kilograms", 857 | "kilometre": "kilometer", 858 | "kilometres": "kilometers", 859 | "labelled": "labeled", 860 | "labelling": "labeling", 861 | "labour": "labor", 862 | "laboured": "labored", 863 | "labourer": "laborer", 864 | "labourers": "laborers", 865 | "labouring": "laboring", 866 | "labours": "labors", 867 | "lacklustre": "lackluster", 868 | "legalisation": "legalization", 869 | "legalise": "legalize", 870 | "legalised": "legalized", 871 | "legalises": "legalizes", 872 | "legalising": "legalizing", 873 | "legitimise": "legitimize", 874 | "legitimised": "legitimized", 875 | "legitimises": "legitimizes", 876 | "legitimising": "legitimizing", 877 | "leukaemia": "leukemia", 878 | "levelled": "leveled", 879 | "leveller": "leveler", 880 | "levellers": "levelers", 881 | "levelling": "leveling", 882 | "libelled": "libeled", 883 | "libelling": "libeling", 884 | "libellous": "libelous", 885 | "liberalisation": "liberalization", 886 | "liberalise": "liberalize", 887 | "liberalised": "liberalized", 888 | "liberalises": "liberalizes", 889 | "liberalising": "liberalizing", 890 | "licence": "license", 891 | "licenced": "licensed", 892 | "licences": "licenses", 893 | "licencing": "licensing", 894 | "likeable": "likable", 895 | "lionisation": "lionization", 896 | "lionise": "lionize", 897 | "lionised": "lionized", 898 | "lionises": "lionizes", 899 | "lionising": "lionizing", 900 | "liquidise": "liquidize", 901 | "liquidised": "liquidized", 902 | "liquidiser": "liquidizer", 903 | "liquidisers": "liquidizers", 904 | "liquidises": "liquidizes", 905 | "liquidising": "liquidizing", 906 | "litre": "liter", 907 | "litres": "liters", 908 | "localise": "localize", 909 | "localised": "localized", 910 | "localises": "localizes", 911 | "localising": "localizing", 912 | "louvre": "louver", 913 | "louvred": "louvered", 914 | "louvres": "louvers", 915 | "lustre": "luster", 916 | "magnetise": "magnetize", 917 | "magnetised": "magnetized", 918 | "magnetises": "magnetizes", 919 | "magnetising": "magnetizing", 920 | "manoeuvrability": "maneuverability", 921 | "manoeuvrable": "maneuverable", 922 | "manoeuvre": "maneuver", 923 | "manoeuvred": "maneuvered", 924 | "manoeuvres": "maneuvers", 925 | "manoeuvring": "maneuvering", 926 | "manoeuvrings": "maneuverings", 927 | "marginalisation": "marginalization", 928 | "marginalise": "marginalize", 929 | "marginalised": "marginalized", 930 | "marginalises": "marginalizes", 931 | "marginalising": "marginalizing", 932 | "marshalled": "marshaled", 933 | "marshalling": "marshaling", 934 | "marvelled": "marveled", 935 | "marvelling": "marveling", 936 | "marvellous": "marvelous", 937 | "marvellously": "marvelously", 938 | "materialisation": "materialization", 939 | "materialise": "materialize", 940 | "materialised": "materialized", 941 | "materialises": 
"materializes", 942 | "materialising": "materializing", 943 | "maximisation": "maximization", 944 | "maximise": "maximize", 945 | "maximised": "maximized", 946 | "maximises": "maximizes", 947 | "maximising": "maximizing", 948 | "meagre": "meager", 949 | "mechanisation": "mechanization", 950 | "mechanise": "mechanize", 951 | "mechanised": "mechanized", 952 | "mechanises": "mechanizes", 953 | "mechanising": "mechanizing", 954 | "mediaeval": "medieval", 955 | "memorialise": "memorialize", 956 | "memorialised": "memorialized", 957 | "memorialises": "memorializes", 958 | "memorialising": "memorializing", 959 | "memorise": "memorize", 960 | "memorised": "memorized", 961 | "memorises": "memorizes", 962 | "memorising": "memorizing", 963 | "mesmerise": "mesmerize", 964 | "mesmerised": "mesmerized", 965 | "mesmerises": "mesmerizes", 966 | "mesmerising": "mesmerizing", 967 | "metabolise": "metabolize", 968 | "metabolised": "metabolized", 969 | "metabolises": "metabolizes", 970 | "metabolising": "metabolizing", 971 | "metre": "meter", 972 | "metres": "meters", 973 | "micrometre": "micrometer", 974 | "micrometres": "micrometers", 975 | "militarise": "militarize", 976 | "militarised": "militarized", 977 | "militarises": "militarizes", 978 | "militarising": "militarizing", 979 | "milligramme": "milligram", 980 | "milligrammes": "milligrams", 981 | "millilitre": "milliliter", 982 | "millilitres": "milliliters", 983 | "millimetre": "millimeter", 984 | "millimetres": "millimeters", 985 | "miniaturisation": "miniaturization", 986 | "miniaturise": "miniaturize", 987 | "miniaturised": "miniaturized", 988 | "miniaturises": "miniaturizes", 989 | "miniaturising": "miniaturizing", 990 | "minibusses": "minibuses", 991 | "minimise": "minimize", 992 | "minimised": "minimized", 993 | "minimises": "minimizes", 994 | "minimising": "minimizing", 995 | "misbehaviour": "misbehavior", 996 | "misdemeanour": "misdemeanor", 997 | "misdemeanours": "misdemeanors", 998 | "misspelt": "misspelled", 999 | "mitre": "miter", 1000 | "mitres": "miters", 1001 | "mobilisation": "mobilization", 1002 | "mobilise": "mobilize", 1003 | "mobilised": "mobilized", 1004 | "mobilises": "mobilizes", 1005 | "mobilising": "mobilizing", 1006 | "modelled": "modeled", 1007 | "modeller": "modeler", 1008 | "modellers": "modelers", 1009 | "modelling": "modeling", 1010 | "modernise": "modernize", 1011 | "modernised": "modernized", 1012 | "modernises": "modernizes", 1013 | "modernising": "modernizing", 1014 | "moisturise": "moisturize", 1015 | "moisturised": "moisturized", 1016 | "moisturiser": "moisturizer", 1017 | "moisturisers": "moisturizers", 1018 | "moisturises": "moisturizes", 1019 | "moisturising": "moisturizing", 1020 | "monologue": "monolog", 1021 | "monologues": "monologs", 1022 | "monopolisation": "monopolization", 1023 | "monopolise": "monopolize", 1024 | "monopolised": "monopolized", 1025 | "monopolises": "monopolizes", 1026 | "monopolising": "monopolizing", 1027 | "moralise": "moralize", 1028 | "moralised": "moralized", 1029 | "moralises": "moralizes", 1030 | "moralising": "moralizing", 1031 | "motorised": "motorized", 1032 | "mould": "mold", 1033 | "moulded": "molded", 1034 | "moulder": "molder", 1035 | "mouldered": "moldered", 1036 | "mouldering": "moldering", 1037 | "moulders": "molders", 1038 | "mouldier": "moldier", 1039 | "mouldiest": "moldiest", 1040 | "moulding": "molding", 1041 | "mouldings": "moldings", 1042 | "moulds": "molds", 1043 | "mouldy": "moldy", 1044 | "moult": "molt", 1045 | "moulted": "molted", 1046 | "moulting": "molting", 
1047 | "moults": "molts", 1048 | "moustache": "mustache", 1049 | "moustached": "mustached", 1050 | "moustaches": "mustaches", 1051 | "moustachioed": "mustachioed", 1052 | "multicoloured": "multicolored", 1053 | "nationalisation": "nationalization", 1054 | "nationalisations": "nationalizations", 1055 | "nationalise": "nationalize", 1056 | "nationalised": "nationalized", 1057 | "nationalises": "nationalizes", 1058 | "nationalising": "nationalizing", 1059 | "naturalisation": "naturalization", 1060 | "naturalise": "naturalize", 1061 | "naturalised": "naturalized", 1062 | "naturalises": "naturalizes", 1063 | "naturalising": "naturalizing", 1064 | "neighbour": "neighbor", 1065 | "neighbourhood": "neighborhood", 1066 | "neighbourhoods": "neighborhoods", 1067 | "neighbouring": "neighboring", 1068 | "neighbourliness": "neighborliness", 1069 | "neighbourly": "neighborly", 1070 | "neighbours": "neighbors", 1071 | "neutralisation": "neutralization", 1072 | "neutralise": "neutralize", 1073 | "neutralised": "neutralized", 1074 | "neutralises": "neutralizes", 1075 | "neutralising": "neutralizing", 1076 | "normalisation": "normalization", 1077 | "normalise": "normalize", 1078 | "normalised": "normalized", 1079 | "normalises": "normalizes", 1080 | "normalising": "normalizing", 1081 | "odour": "odor", 1082 | "odourless": "odorless", 1083 | "odours": "odors", 1084 | "oesophagus": "esophagus", 1085 | "oesophaguses": "esophaguses", 1086 | "oestrogen": "estrogen", 1087 | "offence": "offense", 1088 | "offences": "offenses", 1089 | "omelette": "omelet", 1090 | "omelettes": "omelets", 1091 | "optimise": "optimize", 1092 | "optimised": "optimized", 1093 | "optimises": "optimizes", 1094 | "optimising": "optimizing", 1095 | "organisation": "organization", 1096 | "organisational": "organizational", 1097 | "organisations": "organizations", 1098 | "organise": "organize", 1099 | "organised": "organized", 1100 | "organiser": "organizer", 1101 | "organisers": "organizers", 1102 | "organises": "organizes", 1103 | "organising": "organizing", 1104 | "orthopaedic": "orthopedic", 1105 | "orthopaedics": "orthopedics", 1106 | "ostracise": "ostracize", 1107 | "ostracised": "ostracized", 1108 | "ostracises": "ostracizes", 1109 | "ostracising": "ostracizing", 1110 | "outmanoeuvre": "outmaneuver", 1111 | "outmanoeuvred": "outmaneuvered", 1112 | "outmanoeuvres": "outmaneuvers", 1113 | "outmanoeuvring": "outmaneuvering", 1114 | "overemphasise": "overemphasize", 1115 | "overemphasised": "overemphasized", 1116 | "overemphasises": "overemphasizes", 1117 | "overemphasising": "overemphasizing", 1118 | "oxidisation": "oxidization", 1119 | "oxidise": "oxidize", 1120 | "oxidised": "oxidized", 1121 | "oxidises": "oxidizes", 1122 | "oxidising": "oxidizing", 1123 | "paederast": "pederast", 1124 | "paederasts": "pederasts", 1125 | "paediatric": "pediatric", 1126 | "paediatrician": "pediatrician", 1127 | "paediatricians": "pediatricians", 1128 | "paediatrics": "pediatrics", 1129 | "paedophile": "pedophile", 1130 | "paedophiles": "pedophiles", 1131 | "paedophilia": "pedophilia", 1132 | "palaeolithic": "paleolithic", 1133 | "palaeontologist": "paleontologist", 1134 | "palaeontologists": "paleontologists", 1135 | "palaeontology": "paleontology", 1136 | "panelled": "paneled", 1137 | "panelling": "paneling", 1138 | "panellist": "panelist", 1139 | "panellists": "panelists", 1140 | "paralyse": "paralyze", 1141 | "paralysed": "paralyzed", 1142 | "paralyses": "paralyzes", 1143 | "paralysing": "paralyzing", 1144 | "parcelled": "parceled", 1145 | 
"parcelling": "parceling", 1146 | "parlour": "parlor", 1147 | "parlours": "parlors", 1148 | "particularise": "particularize", 1149 | "particularised": "particularized", 1150 | "particularises": "particularizes", 1151 | "particularising": "particularizing", 1152 | "passivisation": "passivization", 1153 | "passivise": "passivize", 1154 | "passivised": "passivized", 1155 | "passivises": "passivizes", 1156 | "passivising": "passivizing", 1157 | "pasteurisation": "pasteurization", 1158 | "pasteurise": "pasteurize", 1159 | "pasteurised": "pasteurized", 1160 | "pasteurises": "pasteurizes", 1161 | "pasteurising": "pasteurizing", 1162 | "patronise": "patronize", 1163 | "patronised": "patronized", 1164 | "patronises": "patronizes", 1165 | "patronising": "patronizing", 1166 | "patronisingly": "patronizingly", 1167 | "pedalled": "pedaled", 1168 | "pedalling": "pedaling", 1169 | "pedestrianisation": "pedestrianization", 1170 | "pedestrianise": "pedestrianize", 1171 | "pedestrianised": "pedestrianized", 1172 | "pedestrianises": "pedestrianizes", 1173 | "pedestrianising": "pedestrianizing", 1174 | "penalise": "penalize", 1175 | "penalised": "penalized", 1176 | "penalises": "penalizes", 1177 | "penalising": "penalizing", 1178 | "pencilled": "penciled", 1179 | "pencilling": "penciling", 1180 | "personalise": "personalize", 1181 | "personalised": "personalized", 1182 | "personalises": "personalizes", 1183 | "personalising": "personalizing", 1184 | "pharmacopoeia": "pharmacopeia", 1185 | "pharmacopoeias": "pharmacopeias", 1186 | "philosophise": "philosophize", 1187 | "philosophised": "philosophized", 1188 | "philosophises": "philosophizes", 1189 | "philosophising": "philosophizing", 1190 | "philtre": "filter", 1191 | "philtres": "filters", 1192 | "phoney": "phony", 1193 | "plagiarise": "plagiarize", 1194 | "plagiarised": "plagiarized", 1195 | "plagiarises": "plagiarizes", 1196 | "plagiarising": "plagiarizing", 1197 | "plough": "plow", 1198 | "ploughed": "plowed", 1199 | "ploughing": "plowing", 1200 | "ploughman": "plowman", 1201 | "ploughmen": "plowmen", 1202 | "ploughs": "plows", 1203 | "ploughshare": "plowshare", 1204 | "ploughshares": "plowshares", 1205 | "polarisation": "polarization", 1206 | "polarise": "polarize", 1207 | "polarised": "polarized", 1208 | "polarises": "polarizes", 1209 | "polarising": "polarizing", 1210 | "politicisation": "politicization", 1211 | "politicise": "politicize", 1212 | "politicised": "politicized", 1213 | "politicises": "politicizes", 1214 | "politicising": "politicizing", 1215 | "popularisation": "popularization", 1216 | "popularise": "popularize", 1217 | "popularised": "popularized", 1218 | "popularises": "popularizes", 1219 | "popularising": "popularizing", 1220 | "pouffe": "pouf", 1221 | "pouffes": "poufs", 1222 | "practise": "practice", 1223 | "practised": "practiced", 1224 | "practises": "practices", 1225 | "practising": "practicing", 1226 | "praesidium": "presidium", 1227 | "praesidiums": "presidiums", 1228 | "pressurisation": "pressurization", 1229 | "pressurise": "pressurize", 1230 | "pressurised": "pressurized", 1231 | "pressurises": "pressurizes", 1232 | "pressurising": "pressurizing", 1233 | "pretence": "pretense", 1234 | "pretences": "pretenses", 1235 | "primaeval": "primeval", 1236 | "prioritisation": "prioritization", 1237 | "prioritise": "prioritize", 1238 | "prioritised": "prioritized", 1239 | "prioritises": "prioritizes", 1240 | "prioritising": "prioritizing", 1241 | "privatisation": "privatization", 1242 | "privatisations": "privatizations", 1243 | 
"privatise": "privatize", 1244 | "privatised": "privatized", 1245 | "privatises": "privatizes", 1246 | "privatising": "privatizing", 1247 | "professionalisation": "professionalization", 1248 | "professionalise": "professionalize", 1249 | "professionalised": "professionalized", 1250 | "professionalises": "professionalizes", 1251 | "professionalising": "professionalizing", 1252 | "programme": "program", 1253 | "programmes": "programs", 1254 | "prologue": "prolog", 1255 | "prologues": "prologs", 1256 | "propagandise": "propagandize", 1257 | "propagandised": "propagandized", 1258 | "propagandises": "propagandizes", 1259 | "propagandising": "propagandizing", 1260 | "proselytise": "proselytize", 1261 | "proselytised": "proselytized", 1262 | "proselytiser": "proselytizer", 1263 | "proselytisers": "proselytizers", 1264 | "proselytises": "proselytizes", 1265 | "proselytising": "proselytizing", 1266 | "psychoanalyse": "psychoanalyze", 1267 | "psychoanalysed": "psychoanalyzed", 1268 | "psychoanalyses": "psychoanalyzes", 1269 | "psychoanalysing": "psychoanalyzing", 1270 | "publicise": "publicize", 1271 | "publicised": "publicized", 1272 | "publicises": "publicizes", 1273 | "publicising": "publicizing", 1274 | "pulverisation": "pulverization", 1275 | "pulverise": "pulverize", 1276 | "pulverised": "pulverized", 1277 | "pulverises": "pulverizes", 1278 | "pulverising": "pulverizing", 1279 | "pummelled": "pummel", 1280 | "pummelling": "pummeled", 1281 | "pyjama": "pajama", 1282 | "pyjamas": "pajamas", 1283 | "pzazz": "pizzazz", 1284 | "quarrelled": "quarreled", 1285 | "quarrelling": "quarreling", 1286 | "radicalise": "radicalize", 1287 | "radicalised": "radicalized", 1288 | "radicalises": "radicalizes", 1289 | "radicalising": "radicalizing", 1290 | "rancour": "rancor", 1291 | "randomise": "randomize", 1292 | "randomised": "randomized", 1293 | "randomises": "randomizes", 1294 | "randomising": "randomizing", 1295 | "rationalisation": "rationalization", 1296 | "rationalisations": "rationalizations", 1297 | "rationalise": "rationalize", 1298 | "rationalised": "rationalized", 1299 | "rationalises": "rationalizes", 1300 | "rationalising": "rationalizing", 1301 | "ravelled": "raveled", 1302 | "ravelling": "raveling", 1303 | "realisable": "realizable", 1304 | "realisation": "realization", 1305 | "realisations": "realizations", 1306 | "realise": "realize", 1307 | "realised": "realized", 1308 | "realises": "realizes", 1309 | "realising": "realizing", 1310 | "recognisable": "recognizable", 1311 | "recognisably": "recognizably", 1312 | "recognisance": "recognizance", 1313 | "recognise": "recognize", 1314 | "recognised": "recognized", 1315 | "recognises": "recognizes", 1316 | "recognising": "recognizing", 1317 | "reconnoitre": "reconnoiter", 1318 | "reconnoitred": "reconnoitered", 1319 | "reconnoitres": "reconnoiters", 1320 | "reconnoitring": "reconnoitering", 1321 | "refuelled": "refueled", 1322 | "refuelling": "refueling", 1323 | "regularisation": "regularization", 1324 | "regularise": "regularize", 1325 | "regularised": "regularized", 1326 | "regularises": "regularizes", 1327 | "regularising": "regularizing", 1328 | "remodelled": "remodeled", 1329 | "remodelling": "remodeling", 1330 | "remould": "remold", 1331 | "remoulded": "remolded", 1332 | "remoulding": "remolding", 1333 | "remoulds": "remolds", 1334 | "reorganisation": "reorganization", 1335 | "reorganisations": "reorganizations", 1336 | "reorganise": "reorganize", 1337 | "reorganised": "reorganized", 1338 | "reorganises": "reorganizes", 1339 | "reorganising": 
"reorganizing", 1340 | "revelled": "reveled", 1341 | "reveller": "reveler", 1342 | "revellers": "revelers", 1343 | "revelling": "reveling", 1344 | "revitalise": "revitalize", 1345 | "revitalised": "revitalized", 1346 | "revitalises": "revitalizes", 1347 | "revitalising": "revitalizing", 1348 | "revolutionise": "revolutionize", 1349 | "revolutionised": "revolutionized", 1350 | "revolutionises": "revolutionizes", 1351 | "revolutionising": "revolutionizing", 1352 | "rhapsodise": "rhapsodize", 1353 | "rhapsodised": "rhapsodized", 1354 | "rhapsodises": "rhapsodizes", 1355 | "rhapsodising": "rhapsodizing", 1356 | "rigour": "rigor", 1357 | "rigours": "rigors", 1358 | "ritualised": "ritualized", 1359 | "rivalled": "rivaled", 1360 | "rivalling": "rivaling", 1361 | "romanticise": "romanticize", 1362 | "romanticised": "romanticized", 1363 | "romanticises": "romanticizes", 1364 | "romanticising": "romanticizing", 1365 | "rumour": "rumor", 1366 | "rumoured": "rumored", 1367 | "rumours": "rumors", 1368 | "sabre": "saber", 1369 | "sabres": "sabers", 1370 | "saltpetre": "saltpeter", 1371 | "sanitise": "sanitize", 1372 | "sanitised": "sanitized", 1373 | "sanitises": "sanitizes", 1374 | "sanitising": "sanitizing", 1375 | "satirise": "satirize", 1376 | "satirised": "satirized", 1377 | "satirises": "satirizes", 1378 | "satirising": "satirizing", 1379 | "saviour": "savior", 1380 | "saviours": "saviors", 1381 | "savour": "savor", 1382 | "savoured": "savored", 1383 | "savouries": "savories", 1384 | "savouring": "savoring", 1385 | "savours": "savors", 1386 | "savoury": "savory", 1387 | "scandalise": "scandalize", 1388 | "scandalised": "scandalized", 1389 | "scandalises": "scandalizes", 1390 | "scandalising": "scandalizing", 1391 | "sceptic": "skeptic", 1392 | "sceptical": "skeptical", 1393 | "sceptically": "skeptically", 1394 | "scepticism": "skepticism", 1395 | "sceptics": "skeptics", 1396 | "sceptre": "scepter", 1397 | "sceptres": "scepters", 1398 | "scrutinise": "scrutinize", 1399 | "scrutinised": "scrutinized", 1400 | "scrutinises": "scrutinizes", 1401 | "scrutinising": "scrutinizing", 1402 | "secularisation": "secularization", 1403 | "secularise": "secularize", 1404 | "secularised": "secularized", 1405 | "secularises": "secularizes", 1406 | "secularising": "secularizing", 1407 | "sensationalise": "sensationalize", 1408 | "sensationalised": "sensationalized", 1409 | "sensationalises": "sensationalizes", 1410 | "sensationalising": "sensationalizing", 1411 | "sensitise": "sensitize", 1412 | "sensitised": "sensitized", 1413 | "sensitises": "sensitizes", 1414 | "sensitising": "sensitizing", 1415 | "sentimentalise": "sentimentalize", 1416 | "sentimentalised": "sentimentalized", 1417 | "sentimentalises": "sentimentalizes", 1418 | "sentimentalising": "sentimentalizing", 1419 | "sepulchre": "sepulcher", 1420 | "sepulchres": "sepulchers", 1421 | "serialisation": "serialization", 1422 | "serialisations": "serializations", 1423 | "serialise": "serialize", 1424 | "serialised": "serialized", 1425 | "serialises": "serializes", 1426 | "serialising": "serializing", 1427 | "sermonise": "sermonize", 1428 | "sermonised": "sermonized", 1429 | "sermonises": "sermonizes", 1430 | "sermonising": "sermonizing", 1431 | "sheikh": "sheik", 1432 | "shovelled": "shoveled", 1433 | "shovelling": "shoveling", 1434 | "shrivelled": "shriveled", 1435 | "shrivelling": "shriveling", 1436 | "signalise": "signalize", 1437 | "signalised": "signalized", 1438 | "signalises": "signalizes", 1439 | "signalising": "signalizing", 1440 | "signalled": 
"signaled", 1441 | "signalling": "signaling", 1442 | "smoulder": "smolder", 1443 | "smouldered": "smoldered", 1444 | "smouldering": "smoldering", 1445 | "smoulders": "smolders", 1446 | "snivelled": "sniveled", 1447 | "snivelling": "sniveling", 1448 | "snorkelled": "snorkeled", 1449 | "snorkelling": "snorkeling", 1450 | "snowplough": "snowplow", 1451 | "snowploughs": "snowplow", 1452 | "socialisation": "socialization", 1453 | "socialise": "socialize", 1454 | "socialised": "socialized", 1455 | "socialises": "socializes", 1456 | "socialising": "socializing", 1457 | "sodomise": "sodomize", 1458 | "sodomised": "sodomized", 1459 | "sodomises": "sodomizes", 1460 | "sodomising": "sodomizing", 1461 | "solemnise": "solemnize", 1462 | "solemnised": "solemnized", 1463 | "solemnises": "solemnizes", 1464 | "solemnising": "solemnizing", 1465 | "sombre": "somber", 1466 | "specialisation": "specialization", 1467 | "specialisations": "specializations", 1468 | "specialise": "specialize", 1469 | "specialised": "specialized", 1470 | "specialises": "specializes", 1471 | "specialising": "specializing", 1472 | "spectre": "specter", 1473 | "spectres": "specters", 1474 | "spiralled": "spiraled", 1475 | "spiralling": "spiraling", 1476 | "splendour": "splendor", 1477 | "splendours": "splendors", 1478 | "squirrelled": "squirreled", 1479 | "squirrelling": "squirreling", 1480 | "stabilisation": "stabilization", 1481 | "stabilise": "stabilize", 1482 | "stabilised": "stabilized", 1483 | "stabiliser": "stabilizer", 1484 | "stabilisers": "stabilizers", 1485 | "stabilises": "stabilizes", 1486 | "stabilising": "stabilizing", 1487 | "standardisation": "standardization", 1488 | "standardise": "standardize", 1489 | "standardised": "standardized", 1490 | "standardises": "standardizes", 1491 | "standardising": "standardizing", 1492 | "stencilled": "stenciled", 1493 | "stencilling": "stenciling", 1494 | "sterilisation": "sterilization", 1495 | "sterilisations": "sterilizations", 1496 | "sterilise": "sterilize", 1497 | "sterilised": "sterilized", 1498 | "steriliser": "sterilizer", 1499 | "sterilisers": "sterilizers", 1500 | "sterilises": "sterilizes", 1501 | "sterilising": "sterilizing", 1502 | "stigmatisation": "stigmatization", 1503 | "stigmatise": "stigmatize", 1504 | "stigmatised": "stigmatized", 1505 | "stigmatises": "stigmatizes", 1506 | "stigmatising": "stigmatizing", 1507 | "storey": "story", 1508 | "storeys": "stories", 1509 | "subsidisation": "subsidization", 1510 | "subsidise": "subsidize", 1511 | "subsidised": "subsidized", 1512 | "subsidiser": "subsidizer", 1513 | "subsidisers": "subsidizers", 1514 | "subsidises": "subsidizes", 1515 | "subsidising": "subsidizing", 1516 | "succour": "succor", 1517 | "succoured": "succored", 1518 | "succouring": "succoring", 1519 | "succours": "succors", 1520 | "sulphate": "sulfate", 1521 | "sulphates": "sulfates", 1522 | "sulphide": "sulfide", 1523 | "sulphides": "sulfides", 1524 | "sulphur": "sulfur", 1525 | "sulphurous": "sulfurous", 1526 | "summarise": "summarize", 1527 | "summarised": "summarized", 1528 | "summarises": "summarizes", 1529 | "summarising": "summarizing", 1530 | "swivelled": "swiveled", 1531 | "swivelling": "swiveling", 1532 | "symbolise": "symbolize", 1533 | "symbolised": "symbolized", 1534 | "symbolises": "symbolizes", 1535 | "symbolising": "symbolizing", 1536 | "sympathise": "sympathize", 1537 | "sympathised": "sympathized", 1538 | "sympathiser": "sympathizer", 1539 | "sympathisers": "sympathizers", 1540 | "sympathises": "sympathizes", 1541 | "sympathising": 
"sympathizing", 1542 | "synchronisation": "synchronization", 1543 | "synchronise": "synchronize", 1544 | "synchronised": "synchronized", 1545 | "synchronises": "synchronizes", 1546 | "synchronising": "synchronizing", 1547 | "synthesise": "synthesize", 1548 | "synthesised": "synthesized", 1549 | "synthesiser": "synthesizer", 1550 | "synthesisers": "synthesizers", 1551 | "synthesises": "synthesizes", 1552 | "synthesising": "synthesizing", 1553 | "syphon": "siphon", 1554 | "syphoned": "siphoned", 1555 | "syphoning": "siphoning", 1556 | "syphons": "siphons", 1557 | "systematisation": "systematization", 1558 | "systematise": "systematize", 1559 | "systematised": "systematized", 1560 | "systematises": "systematizes", 1561 | "systematising": "systematizing", 1562 | "tantalise": "tantalize", 1563 | "tantalised": "tantalized", 1564 | "tantalises": "tantalizes", 1565 | "tantalising": "tantalizing", 1566 | "tantalisingly": "tantalizingly", 1567 | "tasselled": "tasseled", 1568 | "technicolour": "technicolor", 1569 | "temporise": "temporize", 1570 | "temporised": "temporized", 1571 | "temporises": "temporizes", 1572 | "temporising": "temporizing", 1573 | "tenderise": "tenderize", 1574 | "tenderised": "tenderized", 1575 | "tenderises": "tenderizes", 1576 | "tenderising": "tenderizing", 1577 | "terrorise": "terrorize", 1578 | "terrorised": "terrorized", 1579 | "terrorises": "terrorizes", 1580 | "terrorising": "terrorizing", 1581 | "theatre": "theater", 1582 | "theatregoer": "theatergoer", 1583 | "theatregoers": "theatergoers", 1584 | "theatres": "theaters", 1585 | "theorise": "theorize", 1586 | "theorised": "theorized", 1587 | "theorises": "theorizes", 1588 | "theorising": "theorizing", 1589 | "tonne": "ton", 1590 | "tonnes": "tons", 1591 | "towelled": "toweled", 1592 | "towelling": "toweling", 1593 | "toxaemia": "toxemia", 1594 | "tranquillise": "tranquilize", 1595 | "tranquillised": "tranquilized", 1596 | "tranquilliser": "tranquilizer", 1597 | "tranquillisers": "tranquilizers", 1598 | "tranquillises": "tranquilizes", 1599 | "tranquillising": "tranquilizing", 1600 | "tranquillity": "tranquility", 1601 | "tranquillize": "tranquilize", 1602 | "tranquillized": "tranquilized", 1603 | "tranquillizer": "tranquilizer", 1604 | "tranquillizers": "tranquilizers", 1605 | "tranquillizes": "tranquilizes", 1606 | "tranquillizing": "tranquilizing", 1607 | "tranquilly": "tranquility", 1608 | "transistorised": "transistorized", 1609 | "traumatise": "traumatize", 1610 | "traumatised": "traumatized", 1611 | "traumatises": "traumatizes", 1612 | "traumatising": "traumatizing", 1613 | "travelled": "traveled", 1614 | "traveller": "traveler", 1615 | "travellers": "travelers", 1616 | "travelling": "traveling", 1617 | "travelog": "travelogue", 1618 | "travelogs": "travelogues", 1619 | "trialled": "trialed", 1620 | "trialling": "trialing", 1621 | "tricolour": "tricolor", 1622 | "tricolours": "tricolors", 1623 | "trivialise": "trivialize", 1624 | "trivialised": "trivialized", 1625 | "trivialises": "trivializes", 1626 | "trivialising": "trivializing", 1627 | "tumour": "tumor", 1628 | "tumours": "tumors", 1629 | "tunnelled": "tunneled", 1630 | "tunnelling": "tunneling", 1631 | "tyrannise": "tyrannize", 1632 | "tyrannised": "tyrannized", 1633 | "tyrannises": "tyrannizes", 1634 | "tyrannising": "tyrannizing", 1635 | "tyre": "tire", 1636 | "tyres": "tires", 1637 | "unauthorised": "unauthorized", 1638 | "uncivilised": "uncivilized", 1639 | "underutilised": "underutilized", 1640 | "unequalled": "unequaled", 1641 | "unfavourable": 
"unfavorable", 1642 | "unfavourably": "unfavorably", 1643 | "unionisation": "unionization", 1644 | "unionise": "unionize", 1645 | "unionised": "unionized", 1646 | "unionises": "unionizes", 1647 | "unionising": "unionizing", 1648 | "unorganised": "unorganized", 1649 | "unravelled": "unraveled", 1650 | "unravelling": "unraveling", 1651 | "unrecognisable": "unrecognizable", 1652 | "unrecognised": "unrecognized", 1653 | "unrivalled": "unrivaled", 1654 | "unsavoury": "unsavory", 1655 | "untrammelled": "untrammeled", 1656 | "urbanisation": "urbanization", 1657 | "urbanise": "urbanize", 1658 | "urbanised": "urbanized", 1659 | "urbanises": "urbanizes", 1660 | "urbanising": "urbanizing", 1661 | "utilisable": "utilizable", 1662 | "utilisation": "utilization", 1663 | "utilise": "utilize", 1664 | "utilised": "utilized", 1665 | "utilises": "utilizes", 1666 | "utilising": "utilizing", 1667 | "valour": "valor", 1668 | "vandalise": "vandalize", 1669 | "vandalised": "vandalized", 1670 | "vandalises": "vandalizes", 1671 | "vandalising": "vandalizing", 1672 | "vaporisation": "vaporization", 1673 | "vaporise": "vaporize", 1674 | "vaporised": "vaporized", 1675 | "vaporises": "vaporizes", 1676 | "vaporising": "vaporizing", 1677 | "vapour": "vapor", 1678 | "vapours": "vapors", 1679 | "verbalise": "verbalize", 1680 | "verbalised": "verbalized", 1681 | "verbalises": "verbalizes", 1682 | "verbalising": "verbalizing", 1683 | "victimisation": "victimization", 1684 | "victimise": "victimize", 1685 | "victimised": "victimized", 1686 | "victimises": "victimizes", 1687 | "victimising": "victimizing", 1688 | "videodisc": "videodisk", 1689 | "videodiscs": "videodisks", 1690 | "vigour": "vigor", 1691 | "visualisation": "visualization", 1692 | "visualisations": "visualizations", 1693 | "visualise": "visualize", 1694 | "visualised": "visualized", 1695 | "visualises": "visualizes", 1696 | "visualising": "visualizing", 1697 | "vocalisation": "vocalization", 1698 | "vocalisations": "vocalizations", 1699 | "vocalise": "vocalize", 1700 | "vocalised": "vocalized", 1701 | "vocalises": "vocalizes", 1702 | "vocalising": "vocalizing", 1703 | "vulcanised": "vulcanized", 1704 | "vulgarisation": "vulgarization", 1705 | "vulgarise": "vulgarize", 1706 | "vulgarised": "vulgarized", 1707 | "vulgarises": "vulgarizes", 1708 | "vulgarising": "vulgarizing", 1709 | "waggon": "wagon", 1710 | "waggons": "wagons", 1711 | "watercolour": "watercolor", 1712 | "watercolours": "watercolors", 1713 | "weaselled": "weaseled", 1714 | "weaselling": "weaseling", 1715 | "westernisation": "westernization", 1716 | "westernise": "westernize", 1717 | "westernised": "westernized", 1718 | "westernises": "westernizes", 1719 | "westernising": "westernizing", 1720 | "womanise": "womanize", 1721 | "womanised": "womanized", 1722 | "womaniser": "womanizer", 1723 | "womanisers": "womanizers", 1724 | "womanises": "womanizes", 1725 | "womanising": "womanizing", 1726 | "woollen": "woolen", 1727 | "woollens": "woolens", 1728 | "woollies": "woolies", 1729 | "woolly": "wooly", 1730 | "worshipped": "worshiped", 1731 | "worshipping": "worshiping", 1732 | "worshipper": "worshiper", 1733 | "yodelled": "yodeled", 1734 | "yodelling": "yodeling", 1735 | "yoghourt": "yogurt", 1736 | "yoghourts": "yogurts", 1737 | "yoghurt": "yogurt", 1738 | "yoghurts": "yogurts", 1739 | "mhm": "hmm", 1740 | "mm": "hmm", 1741 | "mmm": "hmm" 1742 | } -------------------------------------------------------------------------------- /whisper/normalizers/english.py: 
-------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import re 4 | from fractions import Fraction 5 | from typing import Iterator, List, Match, Optional, Union 6 | 7 | from more_itertools import windowed 8 | 9 | from whisper.normalizers.basic import remove_symbols_and_diacritics 10 | 11 | 12 | class EnglishNumberNormalizer: 13 | """ 14 | Convert any spelled-out numbers into arabic numbers, while handling: 15 | 16 | - remove any commas 17 | - keep the suffixes such as: `1960s`, `274th`, `32nd`, etc. 18 | - spell out currency symbols after the number. e.g. `$20 million` -> `20000000 dollars` 19 | - spell out `one` and `ones` 20 | - interpret successive single-digit numbers as nominal: `one oh one` -> `101` 21 | """ 22 | 23 | def __init__(self): 24 | super().__init__() 25 | 26 | self.zeros = {"o", "oh", "zero"} 27 | self.ones = { 28 | name: i 29 | for i, name in enumerate( 30 | [ 31 | "one", 32 | "two", 33 | "three", 34 | "four", 35 | "five", 36 | "six", 37 | "seven", 38 | "eight", 39 | "nine", 40 | "ten", 41 | "eleven", 42 | "twelve", 43 | "thirteen", 44 | "fourteen", 45 | "fifteen", 46 | "sixteen", 47 | "seventeen", 48 | "eighteen", 49 | "nineteen", 50 | ], 51 | start=1, 52 | ) 53 | } 54 | self.ones_plural = { 55 | "sixes" if name == "six" else name + "s": (value, "s") 56 | for name, value in self.ones.items() 57 | } 58 | self.ones_ordinal = { 59 | "zeroth": (0, "th"), 60 | "first": (1, "st"), 61 | "second": (2, "nd"), 62 | "third": (3, "rd"), 63 | "fifth": (5, "th"), 64 | "twelfth": (12, "th"), 65 | **{ 66 | name + ("h" if name.endswith("t") else "th"): (value, "th") 67 | for name, value in self.ones.items() 68 | if value > 3 and value != 5 and value != 12 69 | }, 70 | } 71 | self.ones_suffixed = {**self.ones_plural, **self.ones_ordinal} 72 | 73 | self.tens = { 74 | "twenty": 20, 75 | "thirty": 30, 76 | "forty": 40, 77 | "fifty": 50, 78 | "sixty": 60, 79 | "seventy": 70, 80 | "eighty": 80, 81 | "ninety": 90, 82 | } 83 | self.tens_plural = { 84 | name.replace("y", "ies"): (value, "s") for name, value in self.tens.items() 85 | } 86 | self.tens_ordinal = { 87 | name.replace("y", "ieth"): (value, "th") for name, value in self.tens.items() 88 | } 89 | self.tens_suffixed = {**self.tens_plural, **self.tens_ordinal} 90 | 91 | self.multipliers = { 92 | "hundred": 100, 93 | "thousand": 1_000, 94 | "million": 1_000_000, 95 | "billion": 1_000_000_000, 96 | "trillion": 1_000_000_000_000, 97 | "quadrillion": 1_000_000_000_000_000, 98 | "quintillion": 1_000_000_000_000_000_000, 99 | "sextillion": 1_000_000_000_000_000_000_000, 100 | "septillion": 1_000_000_000_000_000_000_000_000, 101 | "octillion": 1_000_000_000_000_000_000_000_000_000, 102 | "nonillion": 1_000_000_000_000_000_000_000_000_000_000, 103 | "decillion": 1_000_000_000_000_000_000_000_000_000_000_000, 104 | } 105 | self.multipliers_plural = { 106 | name + "s": (value, "s") for name, value in self.multipliers.items() 107 | } 108 | self.multipliers_ordinal = { 109 | name + "th": (value, "th") for name, value in self.multipliers.items() 110 | } 111 | self.multipliers_suffixed = {**self.multipliers_plural, **self.multipliers_ordinal} 112 | self.decimals = {*self.ones, *self.tens, *self.zeros} 113 | 114 | self.preceding_prefixers = { 115 | "minus": "-", 116 | "negative": "-", 117 | "plus": "+", 118 | "positive": "+", 119 | } 120 | self.following_prefixers = { 121 | "pound": "£", 122 | "pounds": "£", 123 | "euro": "€", 124 | "euros": "€", 125 | "dollar": "$", 126 | "dollars": "$", 127 | 
"cent": "¢", 128 | "cents": "¢", 129 | } 130 | self.prefixes = set( 131 | list(self.preceding_prefixers.values()) + list(self.following_prefixers.values()) 132 | ) 133 | self.suffixers = { 134 | "per": {"cent": "%"}, 135 | "percent": "%", 136 | } 137 | self.specials = {"and", "double", "triple", "point"} 138 | 139 | self.words = set( 140 | [ 141 | key 142 | for mapping in [ 143 | self.zeros, 144 | self.ones, 145 | self.ones_suffixed, 146 | self.tens, 147 | self.tens_suffixed, 148 | self.multipliers, 149 | self.multipliers_suffixed, 150 | self.preceding_prefixers, 151 | self.following_prefixers, 152 | self.suffixers, 153 | self.specials, 154 | ] 155 | for key in mapping 156 | ] 157 | ) 158 | self.literal_words = {"one", "ones"} 159 | 160 | def process_words(self, words: List[str]) -> Iterator[str]: 161 | prefix: Optional[str] = None 162 | value: Optional[Union[str, int]] = None 163 | skip = False 164 | 165 | def to_fraction(s: str): 166 | try: 167 | return Fraction(s) 168 | except ValueError: 169 | return None 170 | 171 | def output(result: Union[str, int]): 172 | nonlocal prefix, value 173 | result = str(result) 174 | if prefix is not None: 175 | result = prefix + result 176 | value = None 177 | prefix = None 178 | return result 179 | 180 | if len(words) == 0: 181 | return 182 | 183 | for prev, current, next in windowed([None] + words + [None], 3): 184 | if skip: 185 | skip = False 186 | continue 187 | 188 | next_is_numeric = next is not None and re.match(r"^\d+(\.\d+)?$", next) 189 | has_prefix = current[0] in self.prefixes 190 | current_without_prefix = current[1:] if has_prefix else current 191 | if re.match(r"^\d+(\.\d+)?$", current_without_prefix): 192 | # arabic numbers (potentially with signs and fractions) 193 | f = to_fraction(current_without_prefix) 194 | assert f is not None 195 | if value is not None: 196 | if isinstance(value, str) and value.endswith("."): 197 | # concatenate decimals / ip address components 198 | value = str(value) + str(current) 199 | continue 200 | else: 201 | yield output(value) 202 | 203 | prefix = current[0] if has_prefix else prefix 204 | if f.denominator == 1: 205 | value = f.numerator # store integers as int 206 | else: 207 | value = current_without_prefix 208 | elif current not in self.words: 209 | # non-numeric words 210 | if value is not None: 211 | yield output(value) 212 | yield output(current) 213 | elif current in self.zeros: 214 | value = str(value or "") + "0" 215 | elif current in self.ones: 216 | ones = self.ones[current] 217 | 218 | if value is None: 219 | value = ones 220 | elif isinstance(value, str) or prev in self.ones: 221 | if prev in self.tens and ones < 10: # replace the last zero with the digit 222 | assert value[-1] == "0" 223 | value = value[:-1] + str(ones) 224 | else: 225 | value = str(value) + str(ones) 226 | elif ones < 10: 227 | if value % 10 == 0: 228 | value += ones 229 | else: 230 | value = str(value) + str(ones) 231 | else: # eleven to nineteen 232 | if value % 100 == 0: 233 | value += ones 234 | else: 235 | value = str(value) + str(ones) 236 | elif current in self.ones_suffixed: 237 | # ordinal or cardinal; yield the number right away 238 | ones, suffix = self.ones_suffixed[current] 239 | if value is None: 240 | yield output(str(ones) + suffix) 241 | elif isinstance(value, str) or prev in self.ones: 242 | if prev in self.tens and ones < 10: 243 | assert value[-1] == "0" 244 | yield output(value[:-1] + str(ones) + suffix) 245 | else: 246 | yield output(str(value) + str(ones) + suffix) 247 | elif ones < 10: 248 | if 
value % 10 == 0: 249 | yield output(str(value + ones) + suffix) 250 | else: 251 | yield output(str(value) + str(ones) + suffix) 252 | else: # eleven to nineteen 253 | if value % 100 == 0: 254 | yield output(str(value + ones) + suffix) 255 | else: 256 | yield output(str(value) + str(ones) + suffix) 257 | value = None 258 | elif current in self.tens: 259 | tens = self.tens[current] 260 | if value is None: 261 | value = tens 262 | elif isinstance(value, str): 263 | value = str(value) + str(tens) 264 | else: 265 | if value % 100 == 0: 266 | value += tens 267 | else: 268 | value = str(value) + str(tens) 269 | elif current in self.tens_suffixed: 270 | # ordinal or cardinal; yield the number right away 271 | tens, suffix = self.tens_suffixed[current] 272 | if value is None: 273 | yield output(str(tens) + suffix) 274 | elif isinstance(value, str): 275 | yield output(str(value) + str(tens) + suffix) 276 | else: 277 | if value % 100 == 0: 278 | yield output(str(value + tens) + suffix) 279 | else: 280 | yield output(str(value) + str(tens) + suffix) 281 | elif current in self.multipliers: 282 | multiplier = self.multipliers[current] 283 | if value is None: 284 | value = multiplier 285 | elif isinstance(value, str) or value == 0: 286 | f = to_fraction(value) 287 | p = f * multiplier if f is not None else None 288 | if f is not None and p.denominator == 1: 289 | value = p.numerator 290 | else: 291 | yield output(value) 292 | value = multiplier 293 | else: 294 | before = value // 1000 * 1000 295 | residual = value % 1000 296 | value = before + residual * multiplier 297 | elif current in self.multipliers_suffixed: 298 | multiplier, suffix = self.multipliers_suffixed[current] 299 | if value is None: 300 | yield output(str(multiplier) + suffix) 301 | elif isinstance(value, str): 302 | f = to_fraction(value) 303 | p = f * multiplier if f is not None else None 304 | if f is not None and p.denominator == 1: 305 | yield output(str(p.numerator) + suffix) 306 | else: 307 | yield output(value) 308 | yield output(str(multiplier) + suffix) 309 | else: # int 310 | before = value // 1000 * 1000 311 | residual = value % 1000 312 | value = before + residual * multiplier 313 | yield output(str(value) + suffix) 314 | value = None 315 | elif current in self.preceding_prefixers: 316 | # apply prefix (positive, minus, etc.) if it precedes a number 317 | if value is not None: 318 | yield output(value) 319 | 320 | if next in self.words or next_is_numeric: 321 | prefix = self.preceding_prefixers[current] 322 | else: 323 | yield output(current) 324 | elif current in self.following_prefixers: 325 | # apply prefix (dollars, cents, etc.) 
only after a number 326 | if value is not None: 327 | prefix = self.following_prefixers[current] 328 | yield output(value) 329 | else: 330 | yield output(current) 331 | elif current in self.suffixers: 332 | # apply suffix symbols (percent -> '%') 333 | if value is not None: 334 | suffix = self.suffixers[current] 335 | if isinstance(suffix, dict): 336 | if next in suffix: 337 | yield output(str(value) + suffix[next]) 338 | skip = True 339 | else: 340 | yield output(value) 341 | yield output(current) 342 | else: 343 | yield output(str(value) + suffix) 344 | else: 345 | yield output(current) 346 | elif current in self.specials: 347 | if next not in self.words and not next_is_numeric: 348 | # the next word is not numeric; flush any pending value and emit the word as-is 349 | if value is not None: 350 | yield output(value) 351 | yield output(current) 352 | elif current == "and": 353 | # ignore "and" after hundreds, thousands, etc. 354 | if prev not in self.multipliers: 355 | if value is not None: 356 | yield output(value) 357 | yield output(current) 358 | elif current == "double" or current == "triple": 359 | if next in self.ones or next in self.zeros: 360 | repeats = 2 if current == "double" else 3 361 | ones = self.ones.get(next, 0) 362 | value = str(value or "") + str(ones) * repeats 363 | skip = True 364 | else: 365 | if value is not None: 366 | yield output(value) 367 | yield output(current) 368 | elif current == "point": 369 | if next in self.decimals or next_is_numeric: 370 | value = str(value or "") + "." 371 | else: 372 | # should all have been covered at this point 373 | raise ValueError(f"Unexpected token: {current}") 374 | else: 375 | # all should have been covered at this point 376 | raise ValueError(f"Unexpected token: {current}") 377 | 378 | if value is not None: 379 | yield output(value) 380 | 381 | def preprocess(self, s: str): 382 | # replace " and a half" with " point five" 383 | results = [] 384 | 385 | segments = re.split(r"\band\s+a\s+half\b", s) 386 | for i, segment in enumerate(segments): 387 | if len(segment.strip()) == 0: 388 | continue 389 | if i == len(segments) - 1: 390 | results.append(segment) 391 | else: 392 | results.append(segment) 393 | last_word = segment.rsplit(maxsplit=2)[-1] 394 | if last_word in self.decimals or last_word in self.multipliers: 395 | results.append("point five") 396 | else: 397 | results.append("and a half") 398 | 399 | s = " ".join(results) 400 | 401 | # put a space at number/letter boundary 402 | s = re.sub(r"([a-z])([0-9])", r"\1 \2", s) 403 | s = re.sub(r"([0-9])([a-z])", r"\1 \2", s) 404 | 405 | # but remove spaces before ordinal/plural suffixes ("1 st" -> "1st") 406 | s = re.sub(r"([0-9])\s+(st|nd|rd|th|s)\b", r"\1\2", s) 407 | 408 | return s 409 | 410 | def postprocess(self, s: str): 411 | def combine_cents(m: Match): 412 | try: 413 | currency = m.group(1) 414 | integer = m.group(2) 415 | cents = int(m.group(3)) 416 | return f"{currency}{integer}.{cents:02d}" 417 | except ValueError: 418 | return m.string 419 | 420 | def extract_cents(m: Match): 421 | try: 422 | return f"¢{int(m.group(1))}" 423 | except ValueError: 424 | return m.string 425 | 426 | # apply currency postprocessing; "$2 and ¢7" -> "$2.07" 427 | s = re.sub(r"([€£$])([0-9]+) (?:and )?¢([0-9]{1,2})\b", combine_cents, s) 428 | s = re.sub(r"[€£$]0\.([0-9]{1,2})\b", extract_cents, s) 429 | 430 | # write "one(s)" instead of "1(s)", just for readability 431 | s = re.sub(r"\b1(s?)\b", r"one\1", s) 432 | 433 | return s 434 | 435 | def __call__(self, s: str): 436 | s = self.preprocess(s) 437 | s = " ".join(word for
word in self.process_words(s.split()) if word is not None) 438 | s = self.postprocess(s) 439 | 440 | return s 441 | 442 | 443 | class EnglishSpellingNormalizer: 444 | """ 445 | Applies British-American spelling mappings as listed in [1]. 446 | 447 | [1] https://www.tysto.com/uk-us-spelling-list.html 448 | """ 449 | 450 | def __init__(self): 451 | mapping_path = os.path.join(os.path.dirname(__file__), "english.json") 452 | self.mapping = json.load(open(mapping_path)) 453 | 454 | def __call__(self, s: str): 455 | return " ".join(self.mapping.get(word, word) for word in s.split()) 456 | 457 | 458 | class EnglishTextNormalizer: 459 | def __init__(self): 460 | self.ignore_patterns = r"\b(hmm|mm|mhm|mmm|uh|um)\b" 461 | self.replacers = { 462 | # common contractions 463 | r"\bwon't\b": "will not", 464 | r"\bcan't\b": "can not", 465 | r"\blet's\b": "let us", 466 | r"\bain't\b": "aint", 467 | r"\by'all\b": "you all", 468 | r"\bwanna\b": "want to", 469 | r"\bgotta\b": "got to", 470 | r"\bgonna\b": "going to", 471 | r"\bi'ma\b": "i am going to", 472 | r"\bimma\b": "i am going to", 473 | r"\bwoulda\b": "would have", 474 | r"\bcoulda\b": "could have", 475 | r"\bshoulda\b": "should have", 476 | r"\bma'am\b": "madam", 477 | # abbreviated titles/prefixes 478 | r"\bmr\b": "mister ", 479 | r"\bmrs\b": "missus ", 480 | r"\bst\b": "saint ", 481 | r"\bdr\b": "doctor ", 482 | r"\bprof\b": "professor ", 483 | r"\bcapt\b": "captain ", 484 | r"\bgov\b": "governor ", 485 | r"\bald\b": "alderman ", 486 | r"\bgen\b": "general ", 487 | r"\bsen\b": "senator ", 488 | r"\brep\b": "representative ", 489 | r"\bpres\b": "president ", 490 | r"\brev\b": "reverend ", 491 | r"\bhon\b": "honorable ", 492 | r"\basst\b": "assistant ", 493 | r"\bassoc\b": "associate ", 494 | r"\blt\b": "lieutenant ", 495 | r"\bcol\b": "colonel ", 496 | r"\bjr\b": "junior ", 497 | r"\bsr\b": "senior ", 498 | r"\besq\b": "esquire ", 499 | # perfect tenses; ideally this would cover any past participle, but that's harder
500 | r"'d been\b": " had been", 501 | r"'s been\b": " has been", 502 | r"'d gone\b": " had gone", 503 | r"'s gone\b": " has gone", 504 | r"'d done\b": " had done", # "'s done" is ambiguous 505 | r"'s got\b": " has got", 506 | # general contractions 507 | r"n't\b": " not", 508 | r"'re\b": " are", 509 | r"'s\b": " is", 510 | r"'d\b": " would", 511 | r"'ll\b": " will", 512 | r"'t\b": " not", 513 | r"'ve\b": " have", 514 | r"'m\b": " am", 515 | } 516 | self.standardize_numbers = EnglishNumberNormalizer() 517 | self.standardize_spellings = EnglishSpellingNormalizer() 518 | 519 | def __call__(self, s: str): 520 | s = s.lower() 521 | 522 | s = re.sub(r"[<\[][^>\]]*[>\]]", "", s) # remove words between brackets 523 | s = re.sub(r"\(([^)]+?)\)", "", s) # remove words between parenthesis 524 | s = re.sub(self.ignore_patterns, "", s) 525 | s = re.sub(r"\s+'", "'", s) # standardize when there's a space before an apostrophe 526 | 527 | for pattern, replacement in self.replacers.items(): 528 | s = re.sub(pattern, replacement, s) 529 | 530 | s = re.sub(r"(\d),(\d)", r"\1\2", s) # remove commas between digits 531 | s = re.sub(r"\.([^0-9]|$)", r" \1", s) # remove periods not followed by numbers 532 | s = remove_symbols_and_diacritics(s, keep=".%$¢€£") # keep some symbols for numerics 533 | 534 | s = self.standardize_numbers(s) 535 | s = self.standardize_spellings(s) 536 | 537 | # now remove prefix/suffix symbols that are not preceded/followed by numbers 538 | s = re.sub(r"[.$¢€£]([^0-9])", r" \1", s) 539 | s = re.sub(r"([^0-9])%", r"\1 ", s) 540 | 541 | s = re.sub(r"\s+", " ", s) # replace any successive whitespace characters with a space 542 | 543 | return s 544 | -------------------------------------------------------------------------------- /whisper/tokenizer.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import dataclass 3 | from functools import lru_cache 4 | from typing import List, Optional, Tuple, Union 5 | 6 | import numpy as np 7 | from transformers import GPT2TokenizerFast 8 | 9 | LANGUAGES = { 10 | "en": "english", 11 | "zh": "chinese", 12 | "de": "german", 13 | "es": "spanish", 14 | "ru": "russian", 15 | "ko": "korean", 16 | "fr": "french", 17 | "ja": "japanese", 18 | "pt": "portuguese", 19 | "tr": "turkish", 20 | "pl": "polish", 21 | "ca": "catalan", 22 | "nl": "dutch", 23 | "ar": "arabic", 24 | "sv": "swedish", 25 | "it": "italian", 26 | "id": "indonesian", 27 | "hi": "hindi", 28 | "fi": "finnish", 29 | "vi": "vietnamese", 30 | "iw": "hebrew", 31 | "uk": "ukrainian", 32 | "el": "greek", 33 | "ms": "malay", 34 | "cs": "czech", 35 | "ro": "romanian", 36 | "da": "danish", 37 | "hu": "hungarian", 38 | "ta": "tamil", 39 | "no": "norwegian", 40 | "th": "thai", 41 | "ur": "urdu", 42 | "hr": "croatian", 43 | "bg": "bulgarian", 44 | "lt": "lithuanian", 45 | "la": "latin", 46 | "mi": "maori", 47 | "ml": "malayalam", 48 | "cy": "welsh", 49 | "sk": "slovak", 50 | "te": "telugu", 51 | "fa": "persian", 52 | "lv": "latvian", 53 | "bn": "bengali", 54 | "sr": "serbian", 55 | "az": "azerbaijani", 56 | "sl": "slovenian", 57 | "kn": "kannada", 58 | "et": "estonian", 59 | "mk": "macedonian", 60 | "br": "breton", 61 | "eu": "basque", 62 | "is": "icelandic", 63 | "hy": "armenian", 64 | "ne": "nepali", 65 | "mn": "mongolian", 66 | "bs": "bosnian", 67 | "kk": "kazakh", 68 | "sq": "albanian", 69 | "sw": "swahili", 70 | "gl": "galician", 71 | "mr": "marathi", 72 | "pa": "punjabi", 73 | "si": "sinhala", 74 | "km": "khmer", 75 | "sn": "shona", 76 
| "yo": "yoruba", 77 | "so": "somali", 78 | "af": "afrikaans", 79 | "oc": "occitan", 80 | "ka": "georgian", 81 | "be": "belarusian", 82 | "tg": "tajik", 83 | "sd": "sindhi", 84 | "gu": "gujarati", 85 | "am": "amharic", 86 | "yi": "yiddish", 87 | "lo": "lao", 88 | "uz": "uzbek", 89 | "fo": "faroese", 90 | "ht": "haitian creole", 91 | "ps": "pashto", 92 | "tk": "turkmen", 93 | "nn": "nynorsk", 94 | "mt": "maltese", 95 | "sa": "sanskrit", 96 | "lb": "luxembourgish", 97 | "my": "myanmar", 98 | "bo": "tibetan", 99 | "tl": "tagalog", 100 | "mg": "malagasy", 101 | "as": "assamese", 102 | "tt": "tatar", 103 | "haw": "hawaiian", 104 | "ln": "lingala", 105 | "ha": "hausa", 106 | "ba": "bashkir", 107 | "jw": "javanese", 108 | "su": "sundanese", 109 | } 110 | 111 | # language code lookup by name, with a few language aliases 112 | TO_LANGUAGE_CODE = { 113 | **{language: code for code, language in LANGUAGES.items()}, 114 | "burmese": "my", 115 | "valencian": "ca", 116 | "flemish": "nl", 117 | "haitian": "ht", 118 | "letzeburgesch": "lb", 119 | "pushto": "ps", 120 | "panjabi": "pa", 121 | "moldavian": "ro", 122 | "moldovan": "ro", 123 | "sinhalese": "si", 124 | "castilian": "es", 125 | } 126 | 127 | 128 | @dataclass(frozen=True) 129 | class Tokenizer: 130 | """A thin wrapper around `GPT2TokenizerFast` providing quick access to special tokens""" 131 | 132 | tokenizer: "GPT2TokenizerFast" 133 | language: Optional[str] 134 | sot_sequence: Tuple[int] 135 | 136 | def encode(self, text, **kwargs): 137 | return self.tokenizer.encode(text, **kwargs) 138 | 139 | def decode(self, token_ids: Union[int, List[int], np.ndarray], **kwargs): 140 | return self.tokenizer.decode(token_ids, **kwargs) 141 | 142 | def decode_with_timestamps(self, tokens) -> str: 143 | """ 144 | Timestamp tokens are above the special tokens' id range and are ignored by `decode()`. 145 | This method decodes given tokens with timestamps tokens annotated, e.g. "<|1.08|>". 
146 | """ 147 | outputs = [[]] 148 | for token in tokens: 149 | if token >= self.timestamp_begin: 150 | timestamp = f"<|{(token - self.timestamp_begin) * 0.02:.2f}|>" 151 | outputs.append(timestamp) 152 | outputs.append([]) 153 | else: 154 | outputs[-1].append(token) 155 | outputs = [s if isinstance(s, str) else self.tokenizer.decode(s) for s in outputs] 156 | return "".join(outputs) 157 | 158 | @property 159 | @lru_cache() 160 | def eot(self) -> int: 161 | return self.tokenizer.eos_token_id 162 | 163 | @property 164 | @lru_cache() 165 | def sot(self) -> int: 166 | return self._get_single_token_id("<|startoftranscript|>") 167 | 168 | @property 169 | @lru_cache() 170 | def sot_lm(self) -> int: 171 | return self._get_single_token_id("<|startoflm|>") 172 | 173 | @property 174 | @lru_cache() 175 | def sot_prev(self) -> int: 176 | return self._get_single_token_id("<|startofprev|>") 177 | 178 | @property 179 | @lru_cache() 180 | def no_speech(self) -> int: 181 | return self._get_single_token_id("<|nospeech|>") 182 | 183 | @property 184 | @lru_cache() 185 | def no_timestamps(self) -> int: 186 | return self._get_single_token_id("<|notimestamps|>") 187 | 188 | @property 189 | @lru_cache() 190 | def timestamp_begin(self) -> int: 191 | return self.tokenizer.all_special_ids[-1] + 1 192 | 193 | @property 194 | @lru_cache() 195 | def language_token(self) -> int: 196 | """Returns the token id corresponding to the value of the `language` field""" 197 | if self.language is None: 198 | raise ValueError(f"This tokenizer does not have language token configured") 199 | 200 | additional_tokens = dict( 201 | zip( 202 | self.tokenizer.additional_special_tokens, 203 | self.tokenizer.additional_special_tokens_ids, 204 | ) 205 | ) 206 | candidate = f"<|{self.language}|>" 207 | if candidate in additional_tokens: 208 | return additional_tokens[candidate] 209 | 210 | raise KeyError(f"Language {self.language} not found in tokenizer.") 211 | 212 | @property 213 | @lru_cache() 214 | def all_language_tokens(self) -> Tuple[int]: 215 | result = [] 216 | for token, token_id in zip( 217 | self.tokenizer.additional_special_tokens, 218 | self.tokenizer.additional_special_tokens_ids, 219 | ): 220 | if token.strip("<|>") in LANGUAGES: 221 | result.append(token_id) 222 | return tuple(result) 223 | 224 | @property 225 | @lru_cache() 226 | def all_language_codes(self) -> Tuple[str]: 227 | return tuple(self.decode([l]).strip("<|>") for l in self.all_language_tokens) 228 | 229 | @property 230 | @lru_cache() 231 | def sot_sequence_including_notimestamps(self) -> Tuple[int]: 232 | return tuple(list(self.sot_sequence) + [self.no_timestamps]) 233 | 234 | @property 235 | @lru_cache() 236 | def non_speech_tokens(self) -> Tuple[int]: 237 | """ 238 | Returns the list of tokens to suppress in order to avoid any speaker tags or non-speech 239 | annotations, to prevent sampling texts that are not actually spoken in the audio, e.g. 240 | 241 | - ♪♪♪ 242 | - ( SPEAKING FOREIGN LANGUAGE ) 243 | - [DAVID] Hey there, 244 | 245 | keeping basic punctuations like commas, periods, question marks, exclamation points, etc. 246 | """ 247 | symbols = list("\"#()*+/:;<=>@[\\]^_`{|}~「」『』") 248 | symbols += "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split() 249 | 250 | # symbols that may be a single token or multiple tokens depending on the tokenizer. 
251 | # In case they're multiple tokens, suppress the first token, which is safe because: 252 | # These are between U+2640 and U+267F miscellaneous symbols that are okay to suppress 253 | # in generations, and in the 3-byte UTF-8 representation they share the first two bytes. 254 | miscellaneous = set("♩♪♫♬♭♮♯") 255 | assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous) 256 | 257 | # allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word 258 | result = {self.tokenizer.encode(" -")[0], self.tokenizer.encode(" '")[0]} 259 | for symbol in symbols + list(miscellaneous): 260 | for tokens in [self.tokenizer.encode(symbol), self.tokenizer.encode(" " + symbol)]: 261 | if len(tokens) == 1 or symbol in miscellaneous: 262 | result.add(tokens[0]) 263 | 264 | return tuple(sorted(result)) 265 | 266 | def _get_single_token_id(self, text) -> int: 267 | tokens = self.tokenizer.encode(text) 268 | assert len(tokens) == 1, f"{text} is not encoded as a single token" 269 | return tokens[0] 270 | 271 | 272 | @lru_cache(maxsize=None) 273 | def build_tokenizer(name: str = "gpt2"): 274 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 275 | path = os.path.join(os.path.dirname(__file__), "assets", name) 276 | tokenizer = GPT2TokenizerFast.from_pretrained(path) 277 | 278 | specials = [ 279 | "<|startoftranscript|>", 280 | *[f"<|{lang}|>" for lang in LANGUAGES.keys()], 281 | "<|translate|>", 282 | "<|transcribe|>", 283 | "<|startoflm|>", 284 | "<|startofprev|>", 285 | "<|nospeech|>", 286 | "<|notimestamps|>", 287 | ] 288 | 289 | tokenizer.add_special_tokens(dict(additional_special_tokens=specials)) 290 | return tokenizer 291 | 292 | 293 | @lru_cache(maxsize=None) 294 | def get_tokenizer( 295 | multilingual: bool, 296 | *, 297 | task: Optional[str] = None, # Literal["transcribe", "translate", None] 298 | language: Optional[str] = None, 299 | ) -> Tokenizer: 300 | if language is not None: 301 | language = language.lower() 302 | if language not in LANGUAGES: 303 | if language in TO_LANGUAGE_CODE: 304 | language = TO_LANGUAGE_CODE[language] 305 | else: 306 | raise ValueError(f"Unsupported language: {language}") 307 | 308 | if multilingual: 309 | tokenizer_name = "multilingual" 310 | task = task or "transcribe" 311 | language = language or "en" 312 | else: 313 | tokenizer_name = "gpt2" 314 | task = None 315 | language = None 316 | 317 | tokenizer = build_tokenizer(name=tokenizer_name) 318 | all_special_ids: List[int] = tokenizer.all_special_ids 319 | sot: int = all_special_ids[1] 320 | translate: int = all_special_ids[-6] 321 | transcribe: int = all_special_ids[-5] 322 | 323 | langs = tuple(LANGUAGES.keys()) 324 | sot_sequence = [sot] 325 | if language is not None: 326 | sot_sequence.append(sot + 1 + langs.index(language)) 327 | if task is not None: 328 | sot_sequence.append(transcribe if task == "transcribe" else translate) 329 | 330 | return Tokenizer(tokenizer=tokenizer, language=language, sot_sequence=tuple(sot_sequence)) 331 | -------------------------------------------------------------------------------- /whisper/transcribe.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.append(os.getcwd()) 4 | import argparse 5 | import warnings 6 | from typing import List, Optional, Tuple, Union, TYPE_CHECKING 7 | 8 | import warnings 9 | warnings.simplefilter(action='ignore', category=FutureWarning) 10 | warnings.simplefilter(action='ignore', category=Warning) 11 | warnings.simplefilter(action='ignore', 
category=DeprecationWarning) 12 | warnings.simplefilter(action='ignore', category=RuntimeWarning) 13 | 14 | import numpy as np 15 | import tqdm 16 | 17 | from whisper.model import load_model, available_models 18 | from whisper.audio import SAMPLE_RATE, N_FRAMES, HOP_LENGTH, pad_or_trim, log_mel_spectrogram 19 | from whisper.decoding import DecodingOptions, DecodingResult 20 | from whisper.tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer 21 | from whisper.utils import exact_div, format_timestamp, optional_int, optional_float, str2bool, write_txt, write_vtt, write_srt 22 | 23 | if TYPE_CHECKING: 24 | from whisper.model import Whisper 25 | 26 | 27 | def transcribe( 28 | model: "Whisper", 29 | audio: Union[str, np.ndarray], 30 | *, 31 | verbose: Optional[bool] = None, 32 | temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0), 33 | compression_ratio_threshold: Optional[float] = 2.4, 34 | logprob_threshold: Optional[float] = -1.0, 35 | no_speech_threshold: Optional[float] = 0.6, 36 | condition_on_previous_text: bool = True, 37 | **decode_options, 38 | ): 39 | """ 40 | Transcribe an audio file using Whisper 41 | 42 | Parameters 43 | ---------- 44 | model: Whisper 45 | The Whisper model instance 46 | 47 | audio: Union[str, np.ndarray] 48 | The path to the audio file to open, or the audio waveform 49 | 50 | verbose: bool 51 | Whether to display the text being decoded to the console. If True, displays all the details. 52 | If False, displays minimal details. If None, does not display anything 53 | 54 | temperature: Union[float, Tuple[float, ...]] 55 | Temperature for sampling. It can be a tuple of temperatures, which will be successively used 56 | upon failures according to either `compression_ratio_threshold` or `logprob_threshold`. 57 | 58 | compression_ratio_threshold: float 59 | If the gzip compression ratio is above this value, treat as failed 60 | 61 | logprob_threshold: float 62 | If the average log probability over sampled tokens is below this value, treat as failed 63 | 64 | no_speech_threshold: float 65 | If the no_speech probability is higher than this value AND the average log probability 66 | over sampled tokens is below `logprob_threshold`, consider the segment as silent 67 | 68 | condition_on_previous_text: bool 69 | if True, the previous output of the model is provided as a prompt for the next window; 70 | disabling may make the text inconsistent across windows, but the model becomes less prone to 71 | getting stuck in a failure loop, such as repetition looping or timestamps going out of sync. 72 | 73 | decode_options: dict 74 | Keyword arguments to construct `DecodingOptions` instances 75 | 76 | Returns 77 | ------- 78 | A dictionary containing the resulting text ("text") and segment-level details ("segments"), and 79 | the spoken language ("language"), which is detected when `decode_options["language"]` is None. 80 | """ 81 | mel: np.ndarray = log_mel_spectrogram(audio, decode_options.pop("disable_cupy", False)) 82 | 83 | if decode_options.get("language", None) is None: 84 | if verbose: 85 | print("Detecting language using up to the first 30 seconds. Use `--language` to specify the language") 86 | segment = pad_or_trim(mel, N_FRAMES) 87 | _, probs = model.detect_language(segment) 88 | decode_options["language"] = max(probs, key=probs.get) 89 | if verbose is not None: 90 | print(f"Detected language: {LANGUAGES[decode_options['language']].title()}") 91 | 92 | mel = mel[np.newaxis, ...]
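# prepend a batch axis so the mel spectrogram is (1, n_mels, n_frames); the windowing loop below slices it as mel[:, :, seek:]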
93 | language = decode_options["language"] 94 | task = decode_options.get("task", "transcribe") 95 | tokenizer = get_tokenizer(model.is_multilingual, language=language, task=task) 96 | 97 | def decode_with_fallback(segment: np.ndarray) -> List[DecodingResult]: 98 | temperatures = [temperature] if isinstance(temperature, (int, float)) else temperature 99 | kwargs = {**decode_options} 100 | t = temperatures[0] 101 | if t == 0: 102 | best_of = kwargs.pop("best_of", None) 103 | else: 104 | best_of = kwargs.get("best_of", None) 105 | 106 | options = DecodingOptions(**kwargs, temperature=t) 107 | results = model.decode(segment, options) 108 | 109 | kwargs.pop("beam_size", None) # no beam search for t > 0 110 | kwargs.pop("patience", None) # no patience for t > 0 111 | kwargs["best_of"] = best_of # enable best_of for t > 0 112 | for t in temperatures[1:]: 113 | needs_fallback = [ 114 | compression_ratio_threshold is not None 115 | and result.compression_ratio > compression_ratio_threshold 116 | or logprob_threshold is not None 117 | and result.avg_logprob < logprob_threshold 118 | for result in results 119 | ] 120 | if any(needs_fallback): 121 | options = DecodingOptions(**kwargs, temperature=t) 122 | retries = model.decode(segment[needs_fallback], options) 123 | for retry_index, original_index in enumerate(np.nonzero(needs_fallback)[0]): 124 | results[original_index] = retries[retry_index] 125 | 126 | return results 127 | 128 | seek = 0 129 | input_stride = exact_div( 130 | N_FRAMES, model.dims.n_audio_ctx 131 | ) # mel frames per output token: 2 132 | time_precision = ( 133 | input_stride * HOP_LENGTH / SAMPLE_RATE 134 | ) # time per output token: 0.02 (seconds) 135 | all_tokens = [] 136 | all_segments = [] 137 | prompt_reset_since = 0 138 | 139 | initial_prompt = decode_options.pop("initial_prompt", None) or [] 140 | if initial_prompt: 141 | initial_prompt = tokenizer.encode(" " + initial_prompt.strip()) 142 | all_tokens.extend(initial_prompt) 143 | 144 | def add_segment( 145 | *, start: float, end: float, text_tokens: np.ndarray, result: DecodingResult 146 | ): 147 | text = tokenizer.decode([token for token in text_tokens if token < tokenizer.eot]) 148 | if len(text.strip()) == 0: # skip empty text output 149 | return 150 | 151 | all_segments.append( 152 | { 153 | "id": len(all_segments), 154 | "seek": seek, 155 | "start": start, 156 | "end": end, 157 | "text": text, 158 | "tokens": result.tokens, 159 | "temperature": result.temperature, 160 | "avg_logprob": result.avg_logprob, 161 | "compression_ratio": result.compression_ratio, 162 | "no_speech_prob": result.no_speech_prob, 163 | } 164 | ) 165 | if verbose: 166 | print(f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}", flush=True) 167 | 168 | # show the progress bar when verbose is False (otherwise the transcribed text will be printed) 169 | num_frames = mel.shape[-1] 170 | previous_seek_value = seek 171 | 172 | with tqdm.tqdm(total=num_frames, unit='frames', disable=verbose is not False) as pbar: 173 | while seek < num_frames: 174 | timestamp_offset = float(seek * HOP_LENGTH / SAMPLE_RATE) 175 | segment = pad_or_trim(mel[:, :, seek:], N_FRAMES) 176 | segment_duration = segment.shape[-1] * HOP_LENGTH / SAMPLE_RATE 177 | 178 | decode_options["prompt"] = all_tokens[prompt_reset_since:] 179 | result = decode_with_fallback(segment)[0] 180 | tokens = result.tokens 181 | 182 | if no_speech_threshold is not None: 183 | # no voice activity check 184 | should_skip = result.no_speech_prob > no_speech_threshold 185 | if 
logprob_threshold is not None and result.avg_logprob > logprob_threshold: 186 | # don't skip if the logprob is high enough, despite the no_speech_prob 187 | should_skip = False 188 | 189 | if should_skip: 190 | seek += segment.shape[-1] # fast-forward to the next segment boundary 191 | continue 192 | 193 | timestamp_tokens: np.ndarray = np.greater_equal(tokens, tokenizer.timestamp_begin) 194 | consecutive = np.add(np.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0], 1) 195 | if len(consecutive) > 0: # if the output contains two consecutive timestamp tokens 196 | last_slice = 0 197 | for current_slice in consecutive: 198 | sliced_tokens = tokens[last_slice:current_slice] 199 | start_timestamp_position = ( 200 | sliced_tokens[0] - tokenizer.timestamp_begin 201 | ) 202 | end_timestamp_position = ( 203 | sliced_tokens[-1] - tokenizer.timestamp_begin 204 | ) 205 | add_segment( 206 | start=timestamp_offset + start_timestamp_position * time_precision, 207 | end=timestamp_offset + end_timestamp_position * time_precision, 208 | text_tokens=sliced_tokens[1:-1], 209 | result=result, 210 | ) 211 | last_slice = current_slice 212 | last_timestamp_position = ( 213 | tokens[last_slice - 1] - tokenizer.timestamp_begin 214 | ) 215 | seek += last_timestamp_position * input_stride 216 | all_tokens.extend(list(tokens[: last_slice + 1])) 217 | else: 218 | duration = segment_duration 219 | tokens = np.asarray(tokens) if isinstance(tokens, list) else tokens 220 | timestamps = tokens[ 221 | np.ravel_multi_index(np.nonzero(timestamp_tokens), timestamp_tokens.shape) 222 | ] 223 | if len(timestamps) > 0: 224 | # no consecutive timestamps but it has a timestamp; use the last one. 225 | # single timestamp at the end means no speech after the last timestamp. 226 | last_timestamp_position = timestamps[-1] - tokenizer.timestamp_begin 227 | duration = last_timestamp_position * time_precision 228 | 229 | add_segment( 230 | start=timestamp_offset, 231 | end=timestamp_offset + duration, 232 | text_tokens=tokens, 233 | result=result, 234 | ) 235 | 236 | seek += segment.shape[-1] 237 | all_tokens.extend(list(tokens)) 238 | 239 | if not condition_on_previous_text or result.temperature > 0.5: 240 | # do not feed the prompt tokens if a high temperature was used 241 | prompt_reset_since = len(all_tokens) 242 | 243 | # update progress bar 244 | pbar.update(min(num_frames, seek) - previous_seek_value) 245 | previous_seek_value = seek 246 | 247 | return dict(text=tokenizer.decode(all_tokens[len(initial_prompt):]), segments=all_segments, language=language) 248 | 249 | 250 | def cli(): 251 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 252 | parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe") 253 | parser.add_argument("--model", default="small", choices=available_models(), help="name of the Whisper model to use") 254 | parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs") 255 | parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages") 256 | parser.add_argument("--disable_cupy", action='store_true', help='When Out of Memory occurs due to insufficient GPU RAM, this option suppresses GPU RAM consumption.') 257 | 258 | parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')") 259 | 
parser.add_argument("--language", type=str, default=None, choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]), help="language spoken in the audio, specify None to perform language detection") 260 | 261 | parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling") 262 | parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature") 263 | parser.add_argument("--beam_size", type=optional_int, default=5, help="number of beams in beam search, only applicable when temperature is zero") 264 | parser.add_argument("--patience", type=float, default=None, help="optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search") 265 | parser.add_argument("--length_penalty", type=float, default=None, help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple lengt normalization by default") 266 | 267 | parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations") 268 | parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.") 269 | parser.add_argument("--condition_on_previous_text", type=str2bool, default=True, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop") 270 | 271 | parser.add_argument("--temperature_increment_on_fallback", type=optional_float, default=0.2, help="temperature to increase when falling back when the decoding fails to meet either of the thresholds below") 272 | parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, help="if the gzip compression ratio is higher than this value, treat the decoding as failed") 273 | parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed") 274 | parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence") 275 | 276 | args = parser.parse_args().__dict__ 277 | model_name: str = args.pop("model") 278 | output_dir: str = args.pop("output_dir") 279 | os.makedirs(output_dir, exist_ok=True) 280 | 281 | if model_name.endswith(".en") and args["language"] not in {"en", "English"}: 282 | warnings.warn(f"{model_name} is an English-only model but receipted '{args['language']}'; using English instead.") 283 | args["language"] = "en" 284 | 285 | temperature = args.pop("temperature") 286 | temperature_increment_on_fallback = args.pop("temperature_increment_on_fallback") 287 | if temperature_increment_on_fallback is not None: 288 | temperature = tuple(np.arange(temperature, 1.0 + 1e-6, temperature_increment_on_fallback)) 289 | else: 290 | temperature = [temperature] 291 | 292 | model = load_model(model_name) 293 | 294 | for audio_path in args.pop("audio"): 295 | result = transcribe(model, audio_path, temperature=temperature, **args) 296 | 297 | audio_basename = os.path.basename(audio_path) 298 | 
299 |         # save TXT
300 |         with open(os.path.join(output_dir, audio_basename + ".txt"), "w", encoding="utf-8") as txt:
301 |             write_txt(result["segments"], file=txt)
302 | 
303 |         # save VTT
304 |         with open(os.path.join(output_dir, audio_basename + ".vtt"), "w", encoding="utf-8") as vtt:
305 |             write_vtt(result["segments"], file=vtt)
306 | 
307 |         # save SRT
308 |         with open(os.path.join(output_dir, audio_basename + ".srt"), "w", encoding="utf-8") as srt:
309 |             write_srt(result["segments"], file=srt)
310 | 
311 | 
312 | if __name__ == '__main__':
313 |     cli()
314 | 
--------------------------------------------------------------------------------
/whisper/utils.py:
--------------------------------------------------------------------------------
1 | import zlib
2 | import numpy as np
3 | from typing import Iterator, TextIO
4 | 
5 | 
6 | def exact_div(x, y):
7 |     assert x % y == 0
8 |     return x // y
9 | 
10 | 
11 | def str2bool(string):
12 |     str2val = {"True": True, "False": False}
13 |     if string in str2val:
14 |         return str2val[string]
15 |     else:
16 |         raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}")
17 | 
18 | 
19 | def optional_int(string):
20 |     return None if string == "None" else int(string)
21 | 
22 | 
23 | def optional_float(string):
24 |     return None if string == "None" else float(string)
25 | 
26 | 
27 | def compression_ratio(text) -> float:
28 |     return len(text) / len(zlib.compress(text.encode("utf-8")))
29 | 
30 | 
31 | def format_timestamp(seconds: float, always_include_hours: bool = False, decimal_marker: str = "."):
32 |     assert seconds >= 0, "non-negative timestamp expected"
33 |     milliseconds = round(seconds * 1000.0)
34 | 
35 |     hours = milliseconds // 3_600_000
36 |     milliseconds -= hours * 3_600_000
37 | 
38 |     minutes = milliseconds // 60_000
39 |     milliseconds -= minutes * 60_000
40 | 
41 |     seconds = milliseconds // 1_000
42 |     milliseconds -= seconds * 1_000
43 | 
44 |     hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
45 |     return f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
46 | 
47 | 
48 | def write_txt(transcript: Iterator[dict], file: TextIO):
49 |     for segment in transcript:
50 |         print(segment['text'].strip(), file=file, flush=True)
51 | 
52 | 
53 | def write_vtt(transcript: Iterator[dict], file: TextIO):
54 |     print("WEBVTT\n", file=file)
55 |     for segment in transcript:
56 |         print(
57 |             f"{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}\n"
58 |             f"{segment['text'].replace('-->', '->')}\n",
59 |             file=file,
60 |             flush=True,
61 |         )
62 | 
63 | 
64 | def write_srt(transcript: Iterator[dict], file: TextIO):
65 |     """
66 |     Write a transcript to a file in SRT format.
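    Each cue consists of a sequential index, an "HH:MM:SS,mmm --> HH:MM:SS,mmm"
    time range, and the caption text, with a blank line separating cues.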
67 | 
68 |     Example usage:
69 |         from pathlib import Path
70 |         from whisper.utils import write_srt
71 | 
72 |         result = transcribe(model, audio_path, temperature=temperature, **args)
73 | 
74 |         # save SRT
75 |         audio_basename = Path(audio_path).stem
76 |         with open(Path(output_dir) / (audio_basename + ".srt"), "w", encoding="utf-8") as srt:
77 |             write_srt(result["segments"], file=srt)
78 |     """
79 |     for i, segment in enumerate(transcript, start=1):
80 |         # write srt lines (SRT requires a comma as the decimal separator in timestamps)
81 |         print(
82 |             f"{i}\n"
83 |             f"{format_timestamp(segment['start'], always_include_hours=True, decimal_marker=',')} --> "
84 |             f"{format_timestamp(segment['end'], always_include_hours=True, decimal_marker=',')}\n"
85 |             f"{segment['text'].strip().replace('-->', '->')}\n",
86 |             file=file,
87 |             flush=True,
88 |         )
89 | 
90 | ONNX_DTYPE_NP_DTYPE = {
91 |     "tensor(int64)": np.int64,
92 |     "tensor(float)": np.float32,
93 |     "tensor(float16)": np.float16,
94 | }
95 | 
96 | def onnx_dtype_to_np_dtype_convert(onnx_dtype: str):
97 |     return ONNX_DTYPE_NP_DTYPE[onnx_dtype]
--------------------------------------------------------------------------------
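A minimal usage sketch for the two helpers above, hedged as an illustration rather than
repository code: the model path is a hypothetical placeholder, only standard onnxruntime
API calls are used, and onnx_dtype_to_np_dtype_convert raises KeyError for any dtype
string outside the three entries it maps.

    import numpy as np
    import onnxruntime as ort
    from whisper.utils import onnx_dtype_to_np_dtype_convert, format_timestamp

    # "decoder_model.onnx" is a placeholder path, not a file shipped in this repo.
    sess = ort.InferenceSession("decoder_model.onnx", providers=["CPUExecutionProvider"])

    # Allocate zero-filled dummy inputs: map each ONNX dtype string
    # (e.g. "tensor(float)") to its NumPy dtype; symbolic dimensions become 1.
    dummy_inputs = {
        inp.name: np.zeros(
            [d if isinstance(d, int) else 1 for d in inp.shape],
            dtype=onnx_dtype_to_np_dtype_convert(inp.type),
        )
        for inp in sess.get_inputs()
    }
    outputs = sess.run(None, dummy_inputs)

    # format_timestamp renders a position in seconds as a subtitle timestamp;
    # write_srt passes decimal_marker="," because SRT requires a comma separator.
    print(format_timestamp(3723.5, always_include_hours=True))                      # 01:02:03.500
    print(format_timestamp(3723.5, always_include_hours=True, decimal_marker=","))  # 01:02:03,500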