├── training_code ├── __init__.py ├── ds_dict.py ├── collators.py ├── utils.py ├── train.py ├── datasets_classes.py └── whisper_module.py ├── .gitignore ├── careless_whisper_stream ├── version.py ├── __main__.py ├── assets │ └── mel_filters.npz ├── normalizers │ ├── __init__.py │ ├── basic.py │ └── english.py ├── triton_ops.py ├── streaming_transcribe.py ├── utils.py ├── model.py ├── __init__.py ├── tokenizer.py ├── audio.py ├── timing.py └── streaming_model.py ├── tests ├── jfk.wav └── streaming_transcription.py ├── transcribe.py ├── requirements.txt ├── environment.yml ├── README.md └── LICENSE /training_code/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.ipynb 2 | __pycache__/ 3 | notebooks/ 4 | static/ -------------------------------------------------------------------------------- /careless_whisper_stream/version.py: -------------------------------------------------------------------------------- 1 | __version__ = "20231117" 2 | -------------------------------------------------------------------------------- /careless_whisper_stream/__main__.py: -------------------------------------------------------------------------------- 1 | from .streaming_transcribe import cli 2 | 3 | cli() 4 | -------------------------------------------------------------------------------- /tests/jfk.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tomer9080/CarelessWhisper-Streaming/HEAD/tests/jfk.wav -------------------------------------------------------------------------------- /transcribe.py: -------------------------------------------------------------------------------- 1 | from careless_whisper_stream.streaming_transcribe import cli 2 | 3 | if __name__ == "__main__": 4 | cli() -------------------------------------------------------------------------------- /careless_whisper_stream/assets/mel_filters.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tomer9080/CarelessWhisper-Streaming/HEAD/careless_whisper_stream/assets/mel_filters.npz -------------------------------------------------------------------------------- /careless_whisper_stream/normalizers/__init__.py: -------------------------------------------------------------------------------- 1 | from .basic import BasicTextNormalizer as BasicTextNormalizer 2 | from .english import EnglishTextNormalizer as EnglishTextNormalizer 3 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | aiohttp 2 | evaluate 3 | huggingface_hub 4 | jiwer 5 | more_itertools 6 | numba 7 | numpy 8 | openai_whisper==20240930 9 | pandas 10 | praatio 11 | pyaudio 12 | pytorch_lightning==2.5.0.post0 13 | regex 14 | soundfile 15 | tiktoken 16 | tqdm 17 | triton 18 | websockets 19 | wandb 20 | -------------------------------------------------------------------------------- /tests/streaming_transcription.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.append("./") 4 | 5 | import torch 6 | import careless_whisper_stream 7 | 8 | device = "cuda" if torch.cuda.is_available() else "cpu" 9 | model = 
careless_whisper_stream.load_streaming_model("small", 300, False, device) 10 | texts = model.transcribe(simulate_stream=True, wav_file=os.path.join(os.path.dirname(os.path.abspath(__file__)), "jfk.wav"), beam_size=5, ca_kv_cache=True) 11 | print(texts) -------------------------------------------------------------------------------- /training_code/ds_dict.py: -------------------------------------------------------------------------------- 1 | import os 2 | """ 3 | Taken from README.md: 4 | ### Dataset Structure 5 | 6 | Before starting model training using the command-line interface provided below, you must first configure your dataset dictionary file located at `training_code/ds_dict.py`. 7 | 8 | This file defines a Python dictionary named `ds_paths`, where you should specify paths to the `train`, `val`, and `test` partitions of your dataset. Each partition should be a CSV file with the following three columns: 9 | 10 | 1. `wav_path` — Path to the WAV audio file. 11 | 2. `tg_path` — Path to the corresponding `.TextGrid` file containing forced alignment. 12 | 3. `raw_text` — Ground truth transcription. 13 | 14 | > **Note:** The dictionary key (i.e., the name of the dataset) will be used by the training script to identify and load the dataset correctly. 15 | """ 16 | 17 | ds_paths = { 18 | 'LIBRI-960-ALIGNED': { 19 | 'train': f"{os.environ.get('HOME')}/path/to/datasets/libri_train_960_train.csv", # used on training steps. 20 | 'val': f"{os.environ.get('HOME')}/path/to/datasets/libri_train_960_val.csv", # used on validation steps. 21 | 'test': f"{os.environ.get('HOME')}/path/to/datasets/libri_train_960_test.csv" # used on evaluation. 22 | }, 23 | # Add you entries below. 24 | } 25 | -------------------------------------------------------------------------------- /careless_whisper_stream/normalizers/basic.py: -------------------------------------------------------------------------------- 1 | import re 2 | import unicodedata 3 | 4 | import regex 5 | 6 | # non-ASCII letters that are not separated by "NFKD" normalization 7 | ADDITIONAL_DIACRITICS = { 8 | "œ": "oe", 9 | "Œ": "OE", 10 | "ø": "o", 11 | "Ø": "O", 12 | "æ": "ae", 13 | "Æ": "AE", 14 | "ß": "ss", 15 | "ẞ": "SS", 16 | "đ": "d", 17 | "Đ": "D", 18 | "ð": "d", 19 | "Ð": "D", 20 | "þ": "th", 21 | "Þ": "th", 22 | "ł": "l", 23 | "Ł": "L", 24 | } 25 | 26 | 27 | def remove_symbols_and_diacritics(s: str, keep=""): 28 | """ 29 | Replace any other markers, symbols, and punctuations with a space, 30 | and drop any diacritics (category 'Mn' and some manual mappings) 31 | """ 32 | return "".join( 33 | c 34 | if c in keep 35 | else ADDITIONAL_DIACRITICS[c] 36 | if c in ADDITIONAL_DIACRITICS 37 | else "" 38 | if unicodedata.category(c) == "Mn" 39 | else " " 40 | if unicodedata.category(c)[0] in "MSP" 41 | else c 42 | for c in unicodedata.normalize("NFKD", s) 43 | ) 44 | 45 | 46 | def remove_symbols(s: str): 47 | """ 48 | Replace any other markers, symbols, punctuations with a space, keeping diacritics 49 | """ 50 | return "".join( 51 | " " if unicodedata.category(c)[0] in "MSP" else c 52 | for c in unicodedata.normalize("NFKC", s) 53 | ) 54 | 55 | 56 | class BasicTextNormalizer: 57 | def __init__(self, remove_diacritics: bool = False, split_letters: bool = False): 58 | self.clean = ( 59 | remove_symbols_and_diacritics if remove_diacritics else remove_symbols 60 | ) 61 | self.split_letters = split_letters 62 | 63 | def __call__(self, s: str): 64 | s = s.lower() 65 | s = re.sub(r"[<\[][^>\]]*[>\]]", "", s) # remove words between brackets 66 | s = 
re.sub(r"\(([^)]+?)\)", "", s) # remove words between parenthesis 67 | s = self.clean(s).lower() 68 | 69 | if self.split_letters: 70 | s = " ".join(regex.findall(r"\X", s, regex.U)) 71 | 72 | s = re.sub( 73 | r"\s+", " ", s 74 | ) # replace any successive whitespace characters with a space 75 | 76 | return s 77 | -------------------------------------------------------------------------------- /training_code/collators.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn.functional as F 4 | 5 | class WhisperDataCollatorWithPadding: 6 | def __call__(self, features): 7 | 8 | input_ids, labels, dec_input_ids, labels_classes, unique_ids = [], [], [], [], [] 9 | for f in features: 10 | input_ids.append(f["input_ids"]) 11 | labels.append(f["labels"]) 12 | dec_input_ids.append(f["dec_input_ids"]) 13 | labels_classes.append([int(item==50257) for item in f["labels"]]) 14 | unique_ids.append(f.get("u_id", 0)) 15 | 16 | input_ids = torch.concat([input_id[None, :] for input_id in input_ids]) 17 | 18 | label_lengths = [len(lab) for lab in labels] 19 | dec_input_ids_length = [len(e) for e in dec_input_ids] 20 | max_label_len = max(label_lengths+dec_input_ids_length) 21 | 22 | labels = [np.pad(lab, (0, max_label_len - lab_len), 'constant', constant_values=-100) for lab, lab_len in zip(labels, label_lengths)] 23 | # labels = [np.pad(lab, (0, max_label_len - lab_len), 'constant', constant_values=50257) for lab, lab_len in zip(labels, label_lengths)] 24 | labels_classes = [np.pad(lab, (0, max_label_len - lab_len), 'constant', constant_values=-100) for lab, lab_len in zip(labels_classes, label_lengths)] 25 | dec_input_ids = [np.pad(e, (0, max_label_len - e_len), 'constant', constant_values=50257) for e, e_len in zip(dec_input_ids, dec_input_ids_length)] # 50257 is eot token id 26 | 27 | batch = { 28 | "labels": labels, 29 | "dec_input_ids": dec_input_ids, 30 | "labels_classes": labels_classes, 31 | "unique_id": unique_ids 32 | } 33 | 34 | batch = {k: torch.tensor(np.array(v), requires_grad=False) for k, v in batch.items()} 35 | 36 | batch["input_ids"] = input_ids 37 | 38 | return batch 39 | 40 | 41 | def pad_2d_sequences(arrays: list, dim: int = 0, padding_value: int = 0) -> torch.Tensor: 42 | lens = [array.shape[dim] for array in arrays] 43 | max_len = max(lens) 44 | 45 | padded_arrays = [F.pad(array, (0, int(dim == 1) * (max_len - array.shape[1]), 0, int(dim == 0) * (max_len - array.shape[0])), mode="constant", value=padding_value) for array in arrays] 46 | return torch.cat([padded_array[None] for padded_array in padded_arrays]) # adding batch dim and concatanating 47 | 48 | 49 | class LoRAWhisperDataCollatorWithPadding: 50 | def __call__(self, features): 51 | 52 | input_ids, labels, dec_input_ids, endpoints = [], [], [], [] 53 | 54 | for f in features: 55 | input_ids.append(f["input_ids"]) 56 | labels.append(f["labels"]) 57 | dec_input_ids.append(f["dec_input_ids"]) 58 | endpoints.append(f["endpoints"]) 59 | 60 | # make a batch 61 | input_ids = torch.concat([input_id[None, :] for input_id in input_ids]) 62 | labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=-100) 63 | dec_input_ids = torch.nn.utils.rnn.pad_sequence(dec_input_ids, batch_first=True, padding_value=50257) 64 | endpoints = torch.nn.utils.rnn.pad_sequence(endpoints, batch_first=True, padding_value=-100) 65 | 66 | batch = { 67 | "labels": labels, 68 | "dec_input_ids": dec_input_ids, 69 | "endpoints": endpoints 70 | } 71 | 72 | 
batch = {k: v.detach() for k, v in batch.items()} 73 | 74 | batch["input_ids"] = input_ids.squeeze(1) 75 | 76 | return batch 77 | 78 | -------------------------------------------------------------------------------- /careless_whisper_stream/triton_ops.py: -------------------------------------------------------------------------------- 1 | from functools import lru_cache 2 | 3 | import numpy as np 4 | import torch 5 | 6 | try: 7 | import triton 8 | import triton.language as tl 9 | except ImportError: 10 | raise RuntimeError("triton import failed; try `pip install --pre triton`") 11 | 12 | 13 | @triton.jit 14 | def dtw_kernel( 15 | cost, trace, x, x_stride, cost_stride, trace_stride, N, M, BLOCK_SIZE: tl.constexpr 16 | ): 17 | offsets = tl.arange(0, BLOCK_SIZE) 18 | mask = offsets < M 19 | 20 | for k in range(1, N + M + 1): # k = i + j 21 | tl.debug_barrier() 22 | 23 | p0 = cost + (k - 1) * cost_stride 24 | p1 = cost + k * cost_stride 25 | p2 = cost + k * cost_stride + 1 26 | 27 | c0 = tl.load(p0 + offsets, mask=mask) 28 | c1 = tl.load(p1 + offsets, mask=mask) 29 | c2 = tl.load(p2 + offsets, mask=mask) 30 | 31 | x_row = tl.load(x + (k - 1) * x_stride + offsets, mask=mask, other=0) 32 | cost_row = x_row + tl.minimum(tl.minimum(c0, c1), c2) 33 | 34 | cost_ptr = cost + (k + 1) * cost_stride + 1 35 | tl.store(cost_ptr + offsets, cost_row, mask=mask) 36 | 37 | trace_ptr = trace + (k + 1) * trace_stride + 1 38 | tl.store(trace_ptr + offsets, 2, mask=mask & (c2 <= c0) & (c2 <= c1)) 39 | tl.store(trace_ptr + offsets, 1, mask=mask & (c1 <= c0) & (c1 <= c2)) 40 | tl.store(trace_ptr + offsets, 0, mask=mask & (c0 <= c1) & (c0 <= c2)) 41 | 42 | 43 | @lru_cache(maxsize=None) 44 | def median_kernel(filter_width: int): 45 | @triton.jit 46 | def kernel( 47 | y, x, x_stride, y_stride, BLOCK_SIZE: tl.constexpr 48 | ): # x.shape[-1] == filter_width 49 | row_idx = tl.program_id(0) 50 | offsets = tl.arange(0, BLOCK_SIZE) 51 | mask = offsets < y_stride 52 | 53 | x_ptr = x + row_idx * x_stride # noqa: F841 54 | y_ptr = y + row_idx * y_stride 55 | 56 | LOAD_ALL_ROWS_HERE # noqa: F821 57 | 58 | BUBBLESORT_HERE # noqa: F821 59 | 60 | tl.store(y_ptr + offsets, MIDDLE_ROW_HERE, mask=mask) # noqa: F821 61 | 62 | kernel = triton.JITFunction(kernel.fn) 63 | kernel.src = kernel.src.replace( 64 | " LOAD_ALL_ROWS_HERE", 65 | "\n".join( 66 | [ 67 | f" row{i} = tl.load(x_ptr + offsets + {i}, mask=mask)" 68 | for i in range(filter_width) 69 | ] 70 | ), 71 | ) 72 | kernel.src = kernel.src.replace( 73 | " BUBBLESORT_HERE", 74 | "\n\n".join( 75 | [ 76 | "\n\n".join( 77 | [ 78 | "\n".join( 79 | [ 80 | f" smaller = tl.where(row{j} < row{j + 1}, row{j}, row{j + 1})", 81 | f" larger = tl.where(row{j} > row{j + 1}, row{j}, row{j + 1})", 82 | f" row{j} = smaller", 83 | f" row{j + 1} = larger", 84 | ] 85 | ) 86 | for j in range(filter_width - i - 1) 87 | ] 88 | ) 89 | for i in range(filter_width // 2 + 1) 90 | ] 91 | ), 92 | ) 93 | kernel.src = kernel.src.replace("MIDDLE_ROW_HERE", f"row{filter_width // 2}") 94 | 95 | return kernel 96 | 97 | 98 | def median_filter_cuda(x: torch.Tensor, filter_width: int): 99 | """Apply a median filter of given width along the last dimension of x""" 100 | slices = x.contiguous().unfold(-1, filter_width, 1) 101 | grid = np.prod(slices.shape[:-2]) 102 | 103 | kernel = median_kernel(filter_width) 104 | y = torch.empty_like(slices[..., 0]) 105 | 106 | BLOCK_SIZE = 1 << (y.stride(-2) - 1).bit_length() 107 | kernel[(grid,)](y, x, x.stride(-2), y.stride(-2), BLOCK_SIZE=BLOCK_SIZE) 108 | 109 | return y 110 | 
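# A minimal usage sketch for `median_filter_cuda`, assuming a CUDA device and a
# working Triton install: the input tensor's last dimension must be at least
# `filter_width`, and the filtered output is shorter by (filter_width - 1) along
# that axis, e.g.:
#
#   x = torch.randn(8, 1500, device="cuda")
#   y = median_filter_cuda(x, filter_width=7)  # y.shape == (8, 1494)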
-------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: careless_whisper 2 | channels: 3 | - defaults 4 | - conda-forge 5 | - https://repo.anaconda.com/pkgs/main 6 | - https://repo.anaconda.com/pkgs/r 7 | dependencies: 8 | - _libgcc_mutex=0.1 9 | - _openmp_mutex=5.1 10 | - annotated-types=0.6.0 11 | - appdirs=1.4.4 12 | - asttokens=3.0.0 13 | - brotlicffi=1.0.9.2 14 | - ca-certificates=2025.7.14 15 | - certifi=2025.7.14 16 | - cffi=1.17.1 17 | - click=8.1.8 18 | - comm=0.2.2 19 | - debugpy=1.8.14 20 | - decorator=5.2.1 21 | - docker-pycreds=0.4.0 22 | - eval_type_backport=0.2.2 23 | - exceptiongroup=1.3.0 24 | - executing=2.2.0 25 | - gitdb=4.0.7 26 | - gitpython=3.1.43 27 | - importlib-metadata=8.7.0 28 | - ipykernel=6.29.5 29 | - ipython=8.18.1 30 | - jedi=0.19.2 31 | - jupyter_client=8.6.3 32 | - jupyter_core=5.8.1 33 | - krb5=1.21.3 34 | - ld_impl_linux-64=2.40 35 | - libabseil=20250127.0 36 | - libedit=3.1.20230828 37 | - libffi=3.4.4 38 | - libgcc=15.1.0 39 | - libgcc-ng=15.1.0 40 | - libgomp=15.1.0 41 | - libsodium=1.0.20 42 | - libstdcxx=15.1.0 43 | - libstdcxx-ng=11.2.0 44 | - libxcb=1.17.0 45 | - matplotlib-inline=0.1.7 46 | - ncurses=6.4 47 | - nest-asyncio=1.6.0 48 | - openssl=3.5.1 49 | - packaging=25.0 50 | - parso=0.8.4 51 | - pexpect=4.9.0 52 | - pickleshare=0.7.5 53 | - pip=25.1 54 | - platformdirs=4.3.8 55 | - prompt-toolkit=3.0.51 56 | - protobuf=5.29.3 57 | - psutil=7.0.0 58 | - pthread-stubs=0.3 59 | - ptyprocess=0.7.0 60 | - pure_eval=0.2.3 61 | - pydantic=2.11.7 62 | - pydantic-core=2.33.2 63 | - pygments=2.19.2 64 | - pysocks=1.7.1 65 | - python=3.9.16 66 | - python-dateutil=2.9.0.post0 67 | - python_abi=3.9 68 | - pyyaml=6.0.2 69 | - pyzmq=27.0.0 70 | - readline=8.2 71 | - requests=2.32.4 72 | - sentry-sdk=2.18.0 73 | - setproctitle=1.2.2 74 | - setuptools=78.1.1 75 | - six=1.17.0 76 | - smmap=5.0.2 77 | - sqlite=3.45.3 78 | - stack_data=0.6.3 79 | - tk=8.6.14 80 | - tornado=6.5.1 81 | - traitlets=5.14.3 82 | - typing=3.10.0.0 83 | - typing-inspection=0.4.0 84 | - typing_extensions=4.14.1 85 | - urllib3=2.5.0 86 | - wandb=0.21.0 87 | - wcwidth=0.2.13 88 | - wheel=0.45.1 89 | - xorg-libx11=1.8.12 90 | - xorg-libxau=1.0.12 91 | - xorg-libxdmcp=1.1.5 92 | - xorg-xorgproto=2024.1 93 | - xz=5.6.4 94 | - yaml=0.2.5 95 | - zeromq=4.3.5 96 | - zipp=3.23.0 97 | - zlib=1.2.13 98 | - pip: 99 | - aiohappyeyeballs==2.6.1 100 | - aiohttp==3.12.14 101 | - aiosignal==1.4.0 102 | - async-timeout==5.0.1 103 | - attrs==25.3.0 104 | - audioread==3.0.1 105 | - charset-normalizer==3.4.2 106 | - contourpy==1.3.0 107 | - cycler==0.12.1 108 | - datasets==4.0.0 109 | - dill==0.3.8 110 | - evaluate==0.4.3 111 | - filelock==3.13.1 112 | - fonttools==4.59.0 113 | - frozenlist==1.7.0 114 | - fsspec==2024.6.1 115 | - hf-xet==1.1.5 116 | - huggingface-hub==0.33.4 117 | - idna==3.10 118 | - importlib-resources==6.5.2 119 | - jinja2==3.1.4 120 | - jiwer==4.0.0 121 | - joblib==1.5.1 122 | - kiwisolver==1.4.7 123 | - lazy-loader==0.4 124 | - librosa==0.10.2.post1 125 | - lightning-utilities==0.14.3 126 | - llvmlite==0.43.0 127 | - markupsafe==2.1.5 128 | - matplotlib==3.9.4 129 | - more-itertools==10.7.0 130 | - mpmath==1.3.0 131 | - msgpack==1.1.1 132 | - multidict==6.6.3 133 | - multiprocess==0.70.16 134 | - networkx==3.2.1 135 | - numba==0.60.0 136 | - numpy==1.26.0 137 | - openai-whisper==20240930 138 | - pandas==2.3.1 139 | - pillow==11.0.0 140 | - pooch==1.8.2 
141 | - praatio==6.2.0 142 | - propcache==0.3.2 143 | - pyarrow==20.0.0 144 | - pyaudio==0.2.11 145 | - pycparser==2.22 146 | - pyparsing==3.2.3 147 | - pytorch-lightning==2.5.0.post0 148 | - pytz==2025.2 149 | - rapidfuzz==3.13.0 150 | - regex==2024.11.6 151 | - scikit-learn==1.6.1 152 | - scipy==1.13.1 153 | - soundfile==0.12.1 154 | - soxr==0.5.0.post1 155 | - sympy==1.13.1 156 | - threadpoolctl==3.6.0 157 | - tiktoken==0.8.0 158 | - tqdm==4.66.5 159 | - triton==3.2.0 160 | - typing-extensions==4.12.2 161 | - tzdata==2025.2 162 | - websockets==15.0.1 163 | - xxhash==3.5.0 164 | - yarl==1.20.1 165 | -------------------------------------------------------------------------------- /training_code/utils.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from dataclasses import dataclass 3 | 4 | 5 | @dataclass(frozen=False) 6 | class Config: 7 | # DL Args 8 | learning_rate: float = 0.0005 9 | weight_decay: float = 0.01 10 | adam_epsilon: float = 1e-6 11 | warmup_steps: int = 100 12 | batch_size: int = 16 13 | num_worker: int = 16 14 | num_train_epochs: int = 10 15 | gradient_accumulation_steps: int = 1 16 | sample_rate: int = 16000 17 | ckpt: str = None 18 | no_logger: bool = False 19 | dataset: str = None 20 | name: str = None 21 | top_k: int = -1 22 | early_stop: bool = False 23 | fast_dev_run: int = None 24 | strategy: str = "ddp" 25 | seed: int = 3407 26 | custom_len: int = 0 27 | 28 | # Whisper args 29 | lang: str = "en" 30 | size: str = "tiny" 31 | 32 | # LoRA + streaming args 33 | lora: bool = False 34 | lora_ckpt: str = None 35 | rank: int = 16 36 | gran: int = 15 37 | extra_gran_blocks: int = 0 38 | sim_stream: bool = False 39 | uniform_sampling: bool = False 40 | streaming_train: bool = False 41 | streaming_fraction: float = 1 42 | streaming_random: bool = False 43 | multilingual: bool = False 44 | 45 | def parse_cmdl(): 46 | # parser 47 | parser = argparse.ArgumentParser(description="Training whisper models, using different configurations") 48 | 49 | # switches for train.py 50 | parser.add_argument('--lora', action="store_true", help="run LoRA training") 51 | parser.add_argument('--no_logger', action="store_true", help="set logger to False") 52 | 53 | # variables for trainer 54 | parser.add_argument('--name', type=str, help="Trained model name", default="model") 55 | parser.add_argument('--size', type=str, help="Whisper size - can use only [tiny, base, small, medium, large, large-v2]", default="tiny") 56 | parser.add_argument('--ckpt', type=str, help="ckpt loading to resume training from", default=None) 57 | parser.add_argument('--fast_dev_run', type=int, help="run few dev runs for sanity checks on lightning trainer", default=None) 58 | parser.add_argument('--top_k', type=int, help="Top K checkpoints to save, use 1 for the best, -1 for last", default=1) 59 | parser.add_argument('--early_stop', action="store_true", help="Use early stopping callback") 60 | parser.add_argument('--custom_len', type=int, help="Number of samples to train on", default=0) 61 | 62 | # DL Hyper Parameters 63 | parser.add_argument('--epochs', type=int, help="Number of training epochs", default=10) 64 | parser.add_argument('--batch_size', type=int, help="Batch size for training and evaluation. 
Better be 2^n, where n is a positive integer.", default=16) 65 | parser.add_argument('--dataset', type=str, help="Name of dataset to load", default='TIMIT-WORD') 66 | parser.add_argument('--learning_rate', type=float, help="Custom learning rate", default=0.0001) 67 | parser.add_argument('--gacc', type=int, help="Number of gradient accumulation steps", default=1) 68 | parser.add_argument('--weight_decay', type=float, help="Weight decay factor", default=0.01) 69 | parser.add_argument('--adam_epsilon', type=float, help="Adam epsilon", default=1e-6) 70 | parser.add_argument('--warmup_steps', type=int, help="Scheduler warmup steps", default=100) 71 | parser.add_argument('--num_worker', type=int, help="Data loader workers", default=16) 72 | parser.add_argument('--strategy', type=str, help="Trainer strategy [ddp, fsdp, ddp_find_unused_parameters_true]", default="ddp") 73 | 74 | # LoRA 75 | parser.add_argument('--lora_ckpt', type=str, help="ckpt loading (for LoRA training mode only)", default=None) 76 | parser.add_argument('--rank', type=int, help="LoRA rank", default=16) 77 | parser.add_argument('--gran', type=int, help="Granularity in encoder frames to calc attention on", default=15) 78 | parser.add_argument('--extra_gran_blocks', type=int, help="How many extra granularity blocks we add on encoder causal block matrix", default=0) 79 | parser.add_argument('--streaming_fraction', type=float, help="Fraction of the available streaming sample points to train on.", default=1) 80 | parser.add_argument('--simulate_stream', action="store_true", help="Spectrogram input is simulated to a stream scenario") 81 | parser.add_argument('--streaming_train', action="store_true", help="Train sequentially on a stream of data.") 82 | parser.add_argument('--streaming_random', action="store_true", help="Train using random sample points, not sequentially!") 83 | parser.add_argument('--multilingual', action="store_true", help="Train using multilingual dataset, assuming lang field is available.") 84 | 85 | return parser.parse_args() 86 | -------------------------------------------------------------------------------- /training_code/train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import os 3 | import torch 4 | from pathlib import Path 5 | from ds_dict import ds_paths 6 | from whisper_module import LoRAStreamedWhisper 7 | from training_code.utils import Config, parse_cmdl 8 | from pytorch_lightning.loggers import WandbLogger 9 | from pytorch_lightning import Trainer, seed_everything 10 | from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint, EarlyStopping 11 | 12 | 13 | SEED = 3407 14 | DEVICE = "cuda" if torch.cuda.is_available() else "cpu" 15 | # DEVICE = "cpu" 16 | seed_everything(SEED, workers=True) 17 | torch.set_float32_matmul_precision("high") 18 | 19 | logs_root = "/mlspeech/data/tomer/streaming_whisper/models/logs" 20 | ckpt_root = "/mlspeech/data/tomer/streaming_whisper/models/ckpts" 21 | 22 | whisper_lrs: dict[str, float] = {'tiny': 1.5e-3, 'base': 1e-3, 'small': 5e-4, 'medium': 2.5e-4, 'large': 1.75e-4, 'large-v2': 2e-4} 23 | 24 | project_names = { 25 | "lora": "LoRA_whisper_stream", 26 | } 27 | 28 | def train_model(log_output_dir, check_output_dir, model_name, train_set, val_set, train_name, project_name, cfg: Config) -> None: 29 | 30 | Path(log_output_dir).mkdir(exist_ok=True) 31 | Path(check_output_dir).mkdir(exist_ok=True) 32 | 33 | wandblogger = WandbLogger( 34 | save_dir=log_output_dir, 35 | name=train_name, 36 
| project=project_names[project_name] 37 | ) 38 | 39 | checkpoint_callback = ModelCheckpoint( 40 | dirpath=f"{check_output_dir}/checkpoint", 41 | filename="checkpoint-{epoch:04d}", 42 | save_top_k=cfg.top_k, # Best model save, 43 | monitor="val/wer" 44 | ) 45 | 46 | callback_list = [checkpoint_callback, LearningRateMonitor(logging_interval="epoch")] 47 | 48 | if cfg.early_stop: 49 | early_stop_callback = EarlyStopping( 50 | monitor="val/wer", 51 | min_delta=0.00, 52 | patience=2, 53 | mode="min", 54 | ) 55 | callback_list.append(early_stop_callback) 56 | 57 | # Model mux 58 | if cfg.lora and cfg.streaming_train: 59 | model = LoRAStreamedWhisper(cfg, model_name, cfg.lang, train_set, val_set, rank=cfg.rank, enc_emb_gran=cfg.gran, enc_context=cfg.extra_gran_blocks, sim_stream=cfg.sim_stream) 60 | 61 | trainer = Trainer( 62 | accelerator=DEVICE, 63 | max_epochs=cfg.num_train_epochs, 64 | callbacks=callback_list, 65 | logger=wandblogger if not cfg.no_logger else False, 66 | deterministic=True, 67 | num_sanity_val_steps=1, 68 | strategy=cfg.strategy, 69 | fast_dev_run=cfg.fast_dev_run, 70 | # precision="16" 71 | # accumulate_grad_batches=cfg.gradient_accumulation_steps, 72 | ) 73 | 74 | if cfg.ckpt is None: trainer.fit(model) 75 | else: trainer.fit(model, ckpt_path=cfg.ckpt) 76 | 77 | 78 | if __name__ == "__main__": 79 | 80 | project_name = None 81 | 82 | args = parse_cmdl() 83 | 84 | # Training config 85 | cfg = Config( 86 | learning_rate=args.learning_rate, 87 | weight_decay=args.weight_decay, 88 | adam_epsilon=args.adam_epsilon, 89 | warmup_steps=args.warmup_steps, 90 | batch_size=args.batch_size, 91 | num_worker=args.num_worker, 92 | num_train_epochs=args.epochs, 93 | gradient_accumulation_steps=args.gacc, 94 | no_logger=args.no_logger, 95 | dataset=args.dataset, 96 | name=args.name, 97 | top_k=args.top_k, 98 | sample_rate=16_000, 99 | ckpt=args.ckpt, 100 | size=args.size, 101 | lora=args.lora, 102 | lora_ckpt=args.lora_ckpt, 103 | rank=args.rank, 104 | gran=args.gran, 105 | extra_gran_blocks=args.extra_gran_blocks, 106 | sim_stream=args.simulate_stream, 107 | fast_dev_run=args.fast_dev_run, 108 | early_stop=args.early_stop, 109 | strategy=args.strategy, 110 | streaming_train=args.streaming_train, 111 | streaming_random=args.streaming_random, 112 | streaming_fraction=args.streaming_fraction, 113 | seed=SEED, 114 | multilingual=args.multilingual, 115 | custom_len=args.custom_len 116 | ) 117 | 118 | if cfg.streaming_train: 119 | assert cfg.sim_stream == cfg.streaming_train, "When running in full stream mode you must simulate streaming!" 
120 | cfg.sim_stream = True 121 | 122 | lr_addition = f"_LR-{cfg.learning_rate}" 123 | effective_bsize = cfg.batch_size * cfg.gradient_accumulation_steps 124 | 125 | if cfg.lora and cfg.streaming_train: 126 | dir_name = f"LoRA_streamed_whisper_{cfg.size}_{cfg.dataset}_{effective_bsize}_{cfg.name}{lr_addition}_r{cfg.rank}_g{cfg.gran}_eg{cfg.extra_gran_blocks}_top{cfg.top_k}_full-stream{cfg.streaming_train}_random-order{cfg.streaming_random}_fraction{cfg.streaming_fraction}" 127 | project_name = "lora" 128 | 129 | # Run trainer 130 | train_model( 131 | log_output_dir=os.path.join(logs_root, dir_name), 132 | check_output_dir=os.path.join(ckpt_root, dir_name), 133 | model_name=args.size, 134 | train_set=ds_paths[args.dataset]['train'], 135 | val_set=ds_paths[args.dataset]['val'], 136 | train_name=dir_name, 137 | project_name=project_name, 138 | cfg=cfg 139 | ) 140 | 141 | -------------------------------------------------------------------------------- /training_code/datasets_classes.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pickle 3 | import pandas as pd 4 | import careless_whisper_stream 5 | import careless_whisper_stream.tokenizer 6 | from praatio import textgrid 7 | from dataclasses import dataclass 8 | from careless_whisper_stream.tokenizer import Tokenizer 9 | from careless_whisper_stream.audio import SpectrogramStream 10 | 11 | class WAVsDataset(torch.utils.data.Dataset): 12 | def __init__(self, ds_path: str, 13 | sep="\t", 14 | tokenizer: careless_whisper_stream.tokenizer = None, 15 | no_labels: bool = False, 16 | custom_len: int = 0, 17 | get_streamed_mel: bool = False) -> None: 18 | super().__init__() 19 | 20 | if not no_labels: 21 | self.tokenizer = tokenizer if tokenizer else careless_whisper_stream.tokenizer.get_tokenizer(True, language="en", task="transcribe") 22 | self.ds_df = pd.read_csv(ds_path, sep=sep) 23 | self.sr = 16_000 24 | self.no_labels = no_labels 25 | self.custom_len = custom_len 26 | self.get_streamed_mel = get_streamed_mel 27 | 28 | def __len__(self): 29 | return int(self.custom_len) if 0 < self.custom_len < len(self.ds_df) else int(len(self.ds_df)) 30 | 31 | def _calc_mel(self, audio): 32 | if self.get_streamed_mel: 33 | spec_streamer = SpectrogramStream() 34 | return spec_streamer._simulate_streaming_log_spec(torch.tensor(audio)).squeeze(0) 35 | 36 | return careless_whisper_stream.log_mel_spectrogram(audio) 37 | 38 | def __getitem__(self, idx): 39 | item = self.ds_df.iloc[idx] 40 | 41 | audio = careless_whisper_stream.load_audio(item["wav_path"], sr=self.sr) 42 | audio = careless_whisper_stream.pad_or_trim(audio.flatten()) 43 | mel = self._calc_mel(audio) 44 | 45 | if self.no_labels: return dict(input_ids=mel) 46 | 47 | text = item["raw_text"] 48 | text = [*self.tokenizer.sot_sequence_including_notimestamps] + self.tokenizer.encode(text) 49 | labels = text[1:] + [self.tokenizer.eot] 50 | 51 | return dict(input_ids=mel, labels=labels, dec_input_ids=text) 52 | 53 | 54 | @dataclass 55 | class Interval: 56 | label: str = None 57 | start: float = 0.0 58 | end: float = 0.0 59 | 60 | class AlignedTextGridDataset(torch.utils.data.Dataset): 61 | def __init__(self, 62 | ds_path: str, 63 | tokenizer: Tokenizer = None, 64 | sample_rate: int = 16_000, 65 | custom_len: int = 0, 66 | get_streamed_mel: bool = False, 67 | gran: int = 15, 68 | extra_gran_blocks: int = 0, 69 | n_mels: int = 80, 70 | multilingual: bool = False): # most of the times we train just on english librispeech 71 | super().__init__() 72 | 
73 | self.tokenizer = tokenizer if tokenizer else careless_whisper_stream.tokenizer.get_tokenizer(True, language="en", task="transcribe") 74 | self.ds_df = pd.read_csv(ds_path) 75 | self.sr = sample_rate 76 | self.custom_len = custom_len 77 | self.get_streamed_mel = get_streamed_mel 78 | self.gran = gran 79 | self.extra_gran_blocks = extra_gran_blocks 80 | self.n_mels = n_mels 81 | self.multilingual = multilingual 82 | 83 | def __len__(self): 84 | return int(self.custom_len) if 0 < self.custom_len < len(self.ds_df) else len(self.ds_df) 85 | 86 | def _calc_mel(self, audio): 87 | if self.get_streamed_mel: 88 | spec_streamer = SpectrogramStream(n_mels=self.n_mels) 89 | return spec_streamer._simulate_streaming_log_spec(torch.tensor(audio)) 90 | 91 | return careless_whisper_stream.log_mel_spectrogram(audio) 92 | 93 | def _get_intervals_from_wrd_file(self, path: str): 94 | with open(path, "r") as file: 95 | lines = file.readlines() 96 | 97 | intervals = [] 98 | for line in lines: 99 | start, end, label = line.strip().split() 100 | intervals.append(Interval(label, int(start) / self.sr, int(end) / self.sr)) 101 | 102 | return intervals 103 | 104 | def __getitem__(self, index): 105 | item = self.ds_df.iloc[index] 106 | 107 | audio = careless_whisper_stream.pad_or_trim(careless_whisper_stream.load_audio(item["wav_path"], sr=self.sr)) 108 | mel = self._calc_mel(audio) 109 | 110 | if ".wrd" in item["tg_path"]: 111 | text_intervals = self._get_intervals_from_wrd_file(item["tg_path"]) 112 | else: 113 | tg = textgrid.openTextgrid(item["tg_path"], includeEmptyIntervals=False) 114 | text_intervals = tg.getTier("words") 115 | 116 | tokenizer = self.tokenizer if not self.multilingual else careless_whisper_stream.tokenizer.get_tokenizer(True, language=item["lang"], task="transcribe") 117 | 118 | endpoints = [0, 0, 0] 119 | tokens = [] 120 | for i, interval in enumerate(text_intervals): 121 | curr_tokens = self.tokenizer.encode(interval.label if i == 0 else " " + interval.label) 122 | n_diff = (interval.end - interval.start) / len(curr_tokens) 123 | endpoints.extend([interval.start + (i + 1) * n_diff for i in range(len(curr_tokens))]) 124 | tokens.extend(curr_tokens) 125 | 126 | text = [*tokenizer.sot_sequence_including_notimestamps] + tokens 127 | labels = text[1:] + [self.tokenizer.eot] 128 | endpoints.append(endpoints[-1] + 0.5) 129 | 130 | assert len(endpoints) == len(labels) == len(text) 131 | 132 | return dict(input_ids=mel, 133 | dec_input_ids=torch.tensor(text), 134 | labels=torch.tensor(labels), 135 | endpoints=torch.tensor(endpoints)) 136 | 137 | 138 | class TIMIT(torch.utils.data.Dataset): 139 | def __init__(self, ds_path: str, tokenizer: Tokenizer = None, n_state: int = 384) -> None: 140 | 141 | self.tokenizer = tokenizer if tokenizer else careless_whisper_stream.tokenizer.get_tokenizer(True, language="en", task="transcribe") 142 | 143 | with open(ds_path, 'rb') as file: 144 | self.dataset = pickle.load(file) 145 | 146 | self.n_state = n_state 147 | 148 | def __len__(self): 149 | return len(self.dataset) 150 | 151 | def __getitem__(self, index): 152 | audio, sr, text, _, _ = self.dataset[index] 153 | audio_len = audio.shape[-1] 154 | assert sr == 16000 155 | audio = careless_whisper_stream.pad_or_trim(torch.Tensor(audio).flatten()) 156 | mel = careless_whisper_stream.log_mel_spectrogram(audio) 157 | 158 | text = [*self.tokenizer.sot_sequence_including_notimestamps] + self.tokenizer.encode(text) 159 | labels = text[1:] + [self.tokenizer.eot] 160 | 161 | num_frames = ((audio_len // 16000) * 50) + 2 
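        # ~50 encoder frames per second of 16 kHz audio (Whisper's encoder emits 1500 frames for a 30 s window); the +2 appears to be a small safety margin. The mask below zeroes out encoder positions beyond the utterance.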
162 | mask = torch.ones(1, 1500, self.n_state) 163 | mask[0, num_frames:, :] = 0 164 | 165 | return dict( 166 | input_ids=mel, 167 | labels=labels, 168 | dec_input_ids=text, 169 | mask=mask, 170 | ) 171 | -------------------------------------------------------------------------------- /careless_whisper_stream/streaming_transcribe.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('.') 3 | 4 | import argparse 5 | from typing import TYPE_CHECKING, List 6 | 7 | import numpy as np 8 | import torch 9 | 10 | from .audio import ( 11 | SAMPLE_RATE, 12 | SpectrogramStream, 13 | MyStream 14 | ) 15 | from .streaming_decoding import DecodingOptions 16 | 17 | if TYPE_CHECKING: 18 | from .streaming_model import StreamingWhisper 19 | 20 | 21 | def transcribe( 22 | model: "StreamingWhisper" = None, 23 | output_filename: str = None, 24 | channels: int = 2, 25 | language: str = "en", 26 | simulate_stream: bool = False, 27 | wav_file: str = None, 28 | single_frame_mel: bool = True, 29 | temperature: float = 0, 30 | beam_size: int = 5, 31 | stream_decode: bool = True, 32 | ca_kv_cache: bool = False, 33 | sa_kv_cache: bool = False, 34 | use_latency: bool = False, 35 | pad_trim: bool = False, 36 | max_sec_context: int = 30, 37 | streaming_timestamps: bool = False, 38 | force_first_tokens_timestamps: bool = False, 39 | **kwargs 40 | ) -> List[str]: 41 | """ 42 | Open a stream and transcribe it using streaming whisper model 43 | 44 | A very thin implementation of the transcribe function, compared to Whisper implementation. 45 | 46 | Parameters 47 | ---------- 48 | model: Whisper 49 | The Whisper model instance 50 | 51 | Returns - 52 | ------- 53 | A dict with a text, tokens field with all of the text that was transcribed till the stream stopped. 54 | """ 55 | model.reset(use_stream=True) # we first reset the model before starting a stream, cleaning any cache. 56 | model.eval() 57 | 58 | # Instantiate streaming instance and open a stream 59 | ms_gran = model.encoder.gran * 20 60 | stream_instance = MyStream(ms_gran, 61 | channels=channels, 62 | filename=output_filename, 63 | simulate_stream=simulate_stream, 64 | wav_file=wav_file, 65 | use_latency=use_latency, 66 | pad_trim=pad_trim,) 67 | 68 | stream_instance.open_stream() 69 | 70 | # frames - used only when filename is given, in order to save a long wav at the end of the conversation. 
71 | frames = [] 72 | 73 | # first we'll use 74 | decoding_options = DecodingOptions( 75 | language=language, 76 | gran=model.encoder.gran, 77 | single_frame_mel=single_frame_mel, 78 | without_timestamps=True, 79 | beam_size=beam_size if temperature == 0 else None, 80 | temperature=temperature, 81 | length_penalty=None, 82 | look_ahead_blocks=model.encoder.extra_gran_blocks, 83 | patience=None, 84 | stream_decode=stream_decode, 85 | use_kv_cache=sa_kv_cache, 86 | use_ca_kv_cache=ca_kv_cache, 87 | streaming_timestamps=streaming_timestamps, 88 | force_first_tokens_timestamps=force_first_tokens_timestamps, 89 | **kwargs 90 | ) 91 | 92 | streamed_spectrogram = SpectrogramStream(n_mels=model.dims.n_mels) # default values are whisper default values 93 | 94 | texts = [] 95 | reset_len = (max_sec_context * SAMPLE_RATE) + 360 # 360 is for the mel padding 96 | try: 97 | for frame in stream_instance.read(): 98 | # save frames for optional save 99 | frames.extend(frame) 100 | 101 | if len(frames) > reset_len: # When we surpass the max_sec_context - reset model (positional embeddings constrain us) 102 | frame = np.concatenate((frames[-360:], frame)) 103 | frames = [] 104 | frames.extend(frame.tolist()) 105 | model.reset(use_stream=True) 106 | streamed_spectrogram.reset() 107 | 108 | frame_tensor = torch.from_numpy(frame).pin_memory() 109 | mel_frame = streamed_spectrogram.calc_mel_with_new_frame(frame_tensor.to(model.device, non_blocking=True)) 110 | 111 | # decode given the new mel frame and print results 112 | result = model.decode(mel_frame.squeeze(0), decoding_options) 113 | 114 | print(result.text) 115 | 116 | texts.append(result) 117 | 118 | except KeyboardInterrupt: 119 | stream_instance.close_stream(frames) 120 | 121 | print("Finished capturing audio.") 122 | 123 | return texts 124 | 125 | 126 | def cli(): 127 | parser = argparse.ArgumentParser(description="Transcribe streaming audio with customizable options") 128 | 129 | # Model choices 130 | parser.add_argument("--model", type=str, default="small", help="Model size to transcribe with") 131 | parser.add_argument("--device", type=str, default="cpu", help="Device to run model inference on.") 132 | parser.add_argument("--chunk_size", type=int, default=300, help="Chunk size for streaming") 133 | parser.add_argument("--multilingual", action="store_true", help="Use a multilingual checkpoint if exists.", default=False) 134 | 135 | # Local streaming args 136 | parser.add_argument("--output_filename", type=str, help="Path to the output audio file when using local streaming") 137 | parser.add_argument("--channels", type=int, default=2, help="Number of audio channels - relevant for local streaming") 138 | 139 | # Streaming simulation wav file 140 | parser.add_argument("--wav_file", type=str, help="Optional WAV file path to stream, using a stream simulation") 141 | 142 | # Streaming behavior 143 | parser.add_argument("--simulate_stream", action="store_true", help="Simulate a stream from a file") 144 | parser.add_argument("--single_frame_mel", action="store_true", default=True, help="Use single frame MELs") 145 | parser.add_argument("--stream_decode", action="store_true", default=True, help="Use streaming decode") 146 | parser.add_argument("--ca_kv_cache", action="store_true", help="Use cross-attention key-value cache") 147 | parser.add_argument("--sa_kv_cache", action="store_true", help="Use self-attention key-value cache") 148 | parser.add_argument("--wait_for_all", action="store_true", help="Wait for all results before outputting") 149 | 
parser.add_argument("--use_latency", action="store_true", help="Add latency for streaming simulation") 150 | parser.add_argument("--pad_trim", action="store_true", default=False, help="Enable padding and trimming") 151 | parser.add_argument("--streaming_timestamps", action="store_true", help="Use timestamps in streaming") 152 | parser.add_argument("--force_first_tokens_timestamps", action="store_true", help="Force timestamps on first tokens") 153 | 154 | # Model behavior 155 | parser.add_argument("--temperature", type=float, default=0.0, help="Sampling temperature") 156 | parser.add_argument("--beam_size", type=int, default=5, help="Beam size for beam search decoding") 157 | parser.add_argument("--language", type=str, default="en", help="Language of transcription") 158 | parser.add_argument("--max_sec_context", type=int, default=30, help="Max context window size in seconds") 159 | 160 | args = parser.parse_args().__dict__ 161 | 162 | from . import load_streaming_model 163 | 164 | model_size: str = args.pop("model") 165 | chunk_size: int = args.pop("chunk_size") 166 | multilingual: bool = args.pop("multilingual") 167 | device: str = args.pop("device") 168 | 169 | model = load_streaming_model(model_size, chunk_size, multilingual, device) 170 | 171 | texts = transcribe(model, **args) 172 | return texts 173 | 174 | if __name__ == "__main__": 175 | cli() 176 | 177 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CarelessWhisper - Causal Whisper Streaming Model 2 | Causal Whisper Streaming is a fine tuned version of OpenAI Whisper, which can handle causal data and perform real-time transcription. 3 | 4 | [![arXiv](https://img.shields.io/badge/arXiv-2508.12301-b31b1b.svg)](https://arxiv.org/abs/2508.12301) [![Demo on Hugging Face](https://img.shields.io/badge/🤗%20Demo-Hugging%20Face-blueviolet?logo=huggingface&logoColor=white)](https://huggingface.co/spaces/MLSpeech/CarelessWhisper-causal-streaming) 5 | 6 | ## 📄 Paper 7 | 8 | For more details, see our [paper](https://arxiv.org/abs/2508.12301). 9 | 10 | ## 🔧 Setup 11 | We used Python 3.9.16, PyTorch 2.6.0, and PyTorch-Lightning 2.5.0 to train and test our models. 12 | Portions of this code are adapted from [OpenAI's Whisper](https://github.com/openai/whisper). 13 | 14 | To set up the project environment using `conda`, follow these steps: 15 | 16 | 1. **Clone the repository** 17 | ```bash 18 | git clone https://github.com/tomer9080/CarelessWhisper-streaming 19 | cd CarelessWhisper-streaming 20 | ``` 21 | 22 | > 💡 Make sure you have [Miniconda](https://docs.conda.io/en/latest/miniconda.html) or [Anaconda](https://www.anaconda.com/products/distribution) installed before proceeding. 23 | 24 | 2. **Create the conda environment** 25 | ```bash 26 | conda env create -f environment.yml 27 | ``` 28 | 29 | 3. **Activate The environment** 30 | ```bash 31 | conda activate careless_whisper 32 | ``` 33 | 34 | 4. **Install the appropriate PyTorch version** 35 | Depending on your hardware and CUDA version, install PyTorch by following the instructions at [https://pytorch.org/get-started/locally](https://pytorch.org/get-started/locally). 36 | This project was tested with CUDA 12.4, but it should also work with compatible earlier or later versions. 37 | 38 | After installing all of the dependencies, you can try to run inference. 
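For a quick sanity check of the setup, you can mirror `tests/streaming_transcription.py`, which simulates a stream over the bundled `tests/jfk.wav` clip (a minimal sketch; it assumes you have already logged in to the Hugging Face Hub as described in the inference section below):

```python
import torch
import careless_whisper_stream

device = "cuda" if torch.cuda.is_available() else "cpu"
model = careless_whisper_stream.load_streaming_model("small", 300, False, device)
texts = model.transcribe(simulate_stream=True, wav_file="tests/jfk.wav", beam_size=5, ca_kv_cache=True)
print(texts)
```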
39 | 40 | ## 🤖 Available Models 41 | We fine-tuned three different sizes of Whisper, all of which support English-only transcription. 42 | A `large-v2` model fine-tuned on multilingual data is also available; it supports English, French, Spanish, German, and Portuguese with a chunk size of 300 milliseconds. 43 | 44 | | Size | Chunk Size [msec] | Multilingual | 45 | |:----:|:-----------------:|:------------:| 46 | | base | 40, 100, 200, 300 | N/A | 47 | | small | 40, 100, 200, 300, 1000 | N/A | 48 | | large-v2 | 40, 100, 200, 300, 1000 | 300 | 49 | 50 | 51 | ## 🎤 Running Inference 52 | To run inference, download the repository content and run from the repository root according to the following sections. 53 | 54 | > **Note:** The models are hosted on the [Hugging Face Hub](https://huggingface.co/), which requires an access token. 55 | > Make sure you are logged in with your token to access the models. 56 | 57 | ### How to Apply Your Hugging Face 🤗 Access Token 58 | 59 | 1. **Create a Hugging Face account** (if you don’t have one) at [https://huggingface.co/join](https://huggingface.co/join). 60 | 61 | 2. **Generate an access token:** 62 | - Go to your Hugging Face account settings: [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens) 63 | - Click on **"New token"**, give it a name, select the appropriate scopes (usually `read` is enough), and create it. 64 | 65 | 3. **Login using the Hugging Face CLI:** 66 | Install the CLI if you don’t have it: 67 | ```bash 68 | pip install huggingface_hub 69 | ``` 70 | Then login: 71 | ```bash 72 | huggingface-cli login 73 | ``` 74 | Paste your token when prompted. 75 | 76 | 77 | ### 🖥️ CLI Usage 78 | The transcription model can be launched with the following command: 79 | ```bash 80 | # Using a local microphone for streaming transcription, dumping the recording to out.wav 81 | python transcribe.py \ 82 | --output_filename out.wav \ 83 | --channels 2 \ 84 | --model small \ 85 | --chunk_size 300 \ 86 | --device cuda \ 87 | --beam_size 5 \ 88 | --ca_kv_cache 89 | ``` 90 | 91 | A simulation of a stream on a WAV file is also available: 92 | ```bash 93 | # Simulating a stream on a wav file 94 | python transcribe.py \ 95 | --model small \ 96 | --chunk_size 300 \ 97 | --device cuda \ 98 | --beam_size 5 \ 99 | --ca_kv_cache \ 100 | --wav_file /path/to/audio.wav \ 101 | --simulate_stream \ 102 | --use_latency 103 | ``` 104 | 105 | ### 🐍 Python Usage 106 | If you prefer using Python, a code snippet utilizing a microphone or a WAV file is provided below: 107 | 108 | ```python 109 | import torch 110 | import careless_whisper_stream 111 | 112 | model_size = "small" # model size 113 | chunk_size = 300 # chunk size in milliseconds 114 | multilingual = False # currently only large-v2_300msec supports languages other than English. 115 | device = "cuda" if torch.cuda.is_available() else "cpu" 116 | 117 | model = careless_whisper_stream.load_streaming_model(name=model_size, 118 | gran=chunk_size, 119 | multilingual=multilingual, 120 | device=device) 121 | 122 | # using a local microphone recording 123 | texts_microphone = model.transcribe(output_filename="/path/to/dump/file.wav", 124 | channels=2, 125 | beam_size=5, 126 | ca_kv_cache=True) 127 | 128 | # Simulating on a wav file 129 | texts_wav_simulation = model.transcribe(simulate_stream=True, 130 | wav_file="/path/to/file/you/want/to/transcribe.wav", 131 | beam_size=5, 132 | ca_kv_cache=True) 133 | ``` 134 | 135 | ## 🦾 Training 136 | In order to train using LoRA, you can use our existing code. 
Make sure all the requirements are installed. 137 | 138 | ### 📂 Dataset Structure 139 | 140 | Before starting model training using the command-line interface provided below, you must first configure your dataset dictionary file located at `training_code/ds_dict.py`. 141 | 142 | This file defines a Python dictionary named `ds_paths`, where you should specify paths to the `train`, `val`, and `test` partitions of your dataset. Each partition should be a CSV file with the following three columns: 143 | 144 | 1. `wav_path` — Path to the WAV audio file. 145 | 2. `tg_path` — Path to the corresponding `.TextGrid` file containing forced alignment. 146 | 3. `raw_text` — Ground truth transcription. 147 | 148 | > **Note:** The dictionary key (i.e., the name of the dataset) will be used by the training script to identify and load the dataset correctly. 149 | 150 | You can find an example entry in `training_code/ds_dict.py`. 151 | 152 | > **Note:** We used [Montreal Forced Aligner (MFA)](https://montreal-forced-aligner.readthedocs.io/en/latest/index.html) to force-align our dataset. 153 | 154 | To run the same force-alignment process as described in the paper, use: 155 | 156 | ```bash 157 | mfa align --clean /dataset/root/path english_us_arpa english_us_arpa /aligned_dataset/root/path 158 | ``` 159 | 160 | For more details on how to run using `mfa` command, visit [MFA site](https://montreal-forced-aligner.readthedocs.io/en/latest/index.html). 161 | 162 | ### 🖥️ CLI Interface 163 | ```bash 164 | python training_code/train.py \ 165 | --lora \ 166 | --streaming_train \ 167 | --simulate_stream \ 168 | --dataset LIBRI-960-ALIGNED \ 169 | --name example_training_base_model \ 170 | --size base \ 171 | --batch_size 32 \ 172 | --epochs 10 \ 173 | --learning_rate 1e-5 \ 174 | --rank 32 \ 175 | --gran 15 \ 176 | --extra_gran_blocks 1 \ 177 | --streaming_fraction 0.25 \ 178 | --top_k 5 \ 179 | ``` 180 | 181 | For more options and training configurations, run: 182 | ```bash 183 | python training_code/train.py --help 184 | ``` 185 | 186 | ## 📜 License 187 | 188 | This repository uses a dual license: 189 | 190 | [![MIT License](https://img.shields.io/badge/License-MIT-green.svg)](https://opensource.org/licenses/MIT) 191 | Portions derived from [OpenAI Whisper](https://github.com/openai/whisper) are licensed under the **MIT License**. 192 | 193 | [![CC BY-NC 4.0 License](https://img.shields.io/badge/License-CC--BY--NC%204.0-blue.svg)](https://creativecommons.org/licenses/by-nc/4.0/) 194 | All other original code in this repository is licensed under the **Creative Commons Attribution-NonCommercial 4.0 International License (CC BY-NC 4.0)**. 195 | 196 | See the [LICENSE](./LICENSE) file for full details. 197 | -------------------------------------------------------------------------------- /careless_whisper_stream/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import re 4 | import sys 5 | import zlib 6 | from typing import Callable, List, Optional, TextIO 7 | 8 | system_encoding = sys.getdefaultencoding() 9 | 10 | if system_encoding != "utf-8": 11 | 12 | def make_safe(string): 13 | # replaces any character not representable using the system default encoding with an '?', 14 | # avoiding UnicodeEncodeError (https://github.com/openai/whisper/discussions/729). 
15 | return string.encode(system_encoding, errors="replace").decode(system_encoding) 16 | 17 | else: 18 | 19 | def make_safe(string): 20 | # utf-8 can encode any Unicode code point, so no need to do the round-trip encoding 21 | return string 22 | 23 | 24 | def exact_div(x, y): 25 | assert x % y == 0 26 | return x // y 27 | 28 | 29 | def str2bool(string): 30 | str2val = {"True": True, "False": False} 31 | if string in str2val: 32 | return str2val[string] 33 | else: 34 | raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}") 35 | 36 | 37 | def optional_int(string): 38 | return None if string == "None" else int(string) 39 | 40 | 41 | def optional_float(string): 42 | return None if string == "None" else float(string) 43 | 44 | 45 | def compression_ratio(text) -> float: 46 | text_bytes = text.encode("utf-8") 47 | return len(text_bytes) / len(zlib.compress(text_bytes)) 48 | 49 | 50 | def format_timestamp( 51 | seconds: float, always_include_hours: bool = False, decimal_marker: str = "." 52 | ): 53 | assert seconds >= 0, "non-negative timestamp expected" 54 | milliseconds = round(seconds * 1000.0) 55 | 56 | hours = milliseconds // 3_600_000 57 | milliseconds -= hours * 3_600_000 58 | 59 | minutes = milliseconds // 60_000 60 | milliseconds -= minutes * 60_000 61 | 62 | seconds = milliseconds // 1_000 63 | milliseconds -= seconds * 1_000 64 | 65 | hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else "" 66 | return ( 67 | f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}" 68 | ) 69 | 70 | 71 | def get_start(segments: List[dict]) -> Optional[float]: 72 | return next( 73 | (w["start"] for s in segments for w in s["words"]), 74 | segments[0]["start"] if segments else None, 75 | ) 76 | 77 | 78 | def get_end(segments: List[dict]) -> Optional[float]: 79 | return next( 80 | (w["end"] for s in reversed(segments) for w in reversed(s["words"])), 81 | segments[-1]["end"] if segments else None, 82 | ) 83 | 84 | 85 | class ResultWriter: 86 | extension: str 87 | 88 | def __init__(self, output_dir: str): 89 | self.output_dir = output_dir 90 | 91 | def __call__( 92 | self, result: dict, audio_path: str, options: Optional[dict] = None, **kwargs 93 | ): 94 | audio_basename = os.path.basename(audio_path) 95 | audio_basename = os.path.splitext(audio_basename)[0] 96 | output_path = os.path.join( 97 | self.output_dir, audio_basename + "." 
+ self.extension 98 | ) 99 | 100 | with open(output_path, "w", encoding="utf-8") as f: 101 | self.write_result(result, file=f, options=options, **kwargs) 102 | 103 | def write_result( 104 | self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs 105 | ): 106 | raise NotImplementedError 107 | 108 | 109 | class WriteTXT(ResultWriter): 110 | extension: str = "txt" 111 | 112 | def write_result( 113 | self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs 114 | ): 115 | for segment in result["segments"]: 116 | print(segment["text"].strip(), file=file, flush=True) 117 | 118 | 119 | class SubtitlesWriter(ResultWriter): 120 | always_include_hours: bool 121 | decimal_marker: str 122 | 123 | def iterate_result( 124 | self, 125 | result: dict, 126 | options: Optional[dict] = None, 127 | *, 128 | max_line_width: Optional[int] = None, 129 | max_line_count: Optional[int] = None, 130 | highlight_words: bool = False, 131 | max_words_per_line: Optional[int] = None, 132 | ): 133 | options = options or {} 134 | max_line_width = max_line_width or options.get("max_line_width") 135 | max_line_count = max_line_count or options.get("max_line_count") 136 | highlight_words = highlight_words or options.get("highlight_words", False) 137 | max_words_per_line = max_words_per_line or options.get("max_words_per_line") 138 | preserve_segments = max_line_count is None or max_line_width is None 139 | max_line_width = max_line_width or 1000 140 | max_words_per_line = max_words_per_line or 1000 141 | 142 | def iterate_subtitles(): 143 | line_len = 0 144 | line_count = 1 145 | # the next subtitle to yield (a list of word timings with whitespace) 146 | subtitle: List[dict] = [] 147 | last: float = get_start(result["segments"]) or 0.0 148 | for segment in result["segments"]: 149 | chunk_index = 0 150 | words_count = max_words_per_line 151 | while chunk_index < len(segment["words"]): 152 | remaining_words = len(segment["words"]) - chunk_index 153 | if max_words_per_line > len(segment["words"]) - chunk_index: 154 | words_count = remaining_words 155 | for i, original_timing in enumerate( 156 | segment["words"][chunk_index : chunk_index + words_count] 157 | ): 158 | timing = original_timing.copy() 159 | long_pause = ( 160 | not preserve_segments and timing["start"] - last > 3.0 161 | ) 162 | has_room = line_len + len(timing["word"]) <= max_line_width 163 | seg_break = i == 0 and len(subtitle) > 0 and preserve_segments 164 | if ( 165 | line_len > 0 166 | and has_room 167 | and not long_pause 168 | and not seg_break 169 | ): 170 | # line continuation 171 | line_len += len(timing["word"]) 172 | else: 173 | # new line 174 | timing["word"] = timing["word"].strip() 175 | if ( 176 | len(subtitle) > 0 177 | and max_line_count is not None 178 | and (long_pause or line_count >= max_line_count) 179 | or seg_break 180 | ): 181 | # subtitle break 182 | yield subtitle 183 | subtitle = [] 184 | line_count = 1 185 | elif line_len > 0: 186 | # line break 187 | line_count += 1 188 | timing["word"] = "\n" + timing["word"] 189 | line_len = len(timing["word"].strip()) 190 | subtitle.append(timing) 191 | last = timing["start"] 192 | chunk_index += max_words_per_line 193 | if len(subtitle) > 0: 194 | yield subtitle 195 | 196 | if len(result["segments"]) > 0 and "words" in result["segments"][0]: 197 | for subtitle in iterate_subtitles(): 198 | subtitle_start = self.format_timestamp(subtitle[0]["start"]) 199 | subtitle_end = self.format_timestamp(subtitle[-1]["end"]) 200 | subtitle_text = "".join([word["word"] for word 
in subtitle]) 201 | if highlight_words: 202 | last = subtitle_start 203 | all_words = [timing["word"] for timing in subtitle] 204 | for i, this_word in enumerate(subtitle): 205 | start = self.format_timestamp(this_word["start"]) 206 | end = self.format_timestamp(this_word["end"]) 207 | if last != start: 208 | yield last, start, subtitle_text 209 | 210 | yield start, end, "".join( 211 | [ 212 | re.sub(r"^(\s*)(.*)$", r"\1\2", word) 213 | if j == i 214 | else word 215 | for j, word in enumerate(all_words) 216 | ] 217 | ) 218 | last = end 219 | else: 220 | yield subtitle_start, subtitle_end, subtitle_text 221 | else: 222 | for segment in result["segments"]: 223 | segment_start = self.format_timestamp(segment["start"]) 224 | segment_end = self.format_timestamp(segment["end"]) 225 | segment_text = segment["text"].strip().replace("-->", "->") 226 | yield segment_start, segment_end, segment_text 227 | 228 | def format_timestamp(self, seconds: float): 229 | return format_timestamp( 230 | seconds=seconds, 231 | always_include_hours=self.always_include_hours, 232 | decimal_marker=self.decimal_marker, 233 | ) 234 | 235 | 236 | class WriteVTT(SubtitlesWriter): 237 | extension: str = "vtt" 238 | always_include_hours: bool = False 239 | decimal_marker: str = "." 240 | 241 | def write_result( 242 | self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs 243 | ): 244 | print("WEBVTT\n", file=file) 245 | for start, end, text in self.iterate_result(result, options, **kwargs): 246 | print(f"{start} --> {end}\n{text}\n", file=file, flush=True) 247 | 248 | 249 | class WriteSRT(SubtitlesWriter): 250 | extension: str = "srt" 251 | always_include_hours: bool = True 252 | decimal_marker: str = "," 253 | 254 | def write_result( 255 | self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs 256 | ): 257 | for i, (start, end, text) in enumerate( 258 | self.iterate_result(result, options, **kwargs), start=1 259 | ): 260 | print(f"{i}\n{start} --> {end}\n{text}\n", file=file, flush=True) 261 | 262 | 263 | class WriteTSV(ResultWriter): 264 | """ 265 | Write a transcript to a file in TSV (tab-separated values) format containing lines like: 266 | \t\t 267 | 268 | Using integer milliseconds as start and end times means there's no chance of interference from 269 | an environment setting a language encoding that causes the decimal in a floating point number 270 | to appear as a comma; also is faster and more efficient to parse & store, e.g., in C++. 
271 | """ 272 | 273 | extension: str = "tsv" 274 | 275 | def write_result( 276 | self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs 277 | ): 278 | print("start", "end", "text", sep="\t", file=file) 279 | for segment in result["segments"]: 280 | print(round(1000 * segment["start"]), file=file, end="\t") 281 | print(round(1000 * segment["end"]), file=file, end="\t") 282 | print(segment["text"].strip().replace("\t", " "), file=file, flush=True) 283 | 284 | 285 | class WriteJSON(ResultWriter): 286 | extension: str = "json" 287 | 288 | def write_result( 289 | self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs 290 | ): 291 | json.dump(result, file) 292 | 293 | 294 | def get_writer( 295 | output_format: str, output_dir: str 296 | ) -> Callable[[dict, TextIO, dict], None]: 297 | writers = { 298 | "txt": WriteTXT, 299 | "vtt": WriteVTT, 300 | "srt": WriteSRT, 301 | "tsv": WriteTSV, 302 | "json": WriteJSON, 303 | } 304 | 305 | if output_format == "all": 306 | all_writers = [writer(output_dir) for writer in writers.values()] 307 | 308 | def write_all( 309 | result: dict, file: TextIO, options: Optional[dict] = None, **kwargs 310 | ): 311 | for writer in all_writers: 312 | writer(result, file, options, **kwargs) 313 | 314 | return write_all 315 | 316 | return writers[output_format](output_dir) 317 | -------------------------------------------------------------------------------- /careless_whisper_stream/model.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import gzip 3 | from dataclasses import dataclass 4 | from typing import Dict, Iterable, Optional, Tuple 5 | 6 | import numpy as np 7 | import torch 8 | import torch.nn.functional as F 9 | from torch import Tensor, nn 10 | 11 | from .decoding import decode as decode_function 12 | from .decoding import detect_language as detect_language_function 13 | from .transcribe import transcribe as transcribe_function 14 | 15 | 16 | @dataclass 17 | class ModelDimensions: 18 | n_mels: int 19 | n_audio_ctx: int 20 | n_audio_state: int 21 | n_audio_head: int 22 | n_audio_layer: int 23 | n_vocab: int 24 | n_text_ctx: int 25 | n_text_state: int 26 | n_text_head: int 27 | n_text_layer: int 28 | 29 | 30 | class LayerNorm(nn.LayerNorm): 31 | def forward(self, x: Tensor) -> Tensor: 32 | return super().forward(x.float()).type(x.dtype) 33 | 34 | 35 | class Linear(nn.Linear): 36 | def forward(self, x: Tensor) -> Tensor: 37 | return F.linear( 38 | x, 39 | self.weight.to(x.dtype), 40 | None if self.bias is None else self.bias.to(x.dtype), 41 | ) 42 | 43 | 44 | class Conv1d(nn.Conv1d): 45 | def _conv_forward( 46 | self, x: Tensor, weight: Tensor, bias: Optional[Tensor] 47 | ) -> Tensor: 48 | return super()._conv_forward( 49 | x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype) 50 | ) 51 | 52 | 53 | def sinusoids(length, channels, max_timescale=10000): 54 | """Returns sinusoids for positional embedding""" 55 | assert channels % 2 == 0 56 | log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1) 57 | inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2)) 58 | scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :] 59 | return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1) 60 | 61 | 62 | class MultiHeadAttention(nn.Module): 63 | def __init__(self, n_state: int, n_head: int): 64 | super().__init__() 65 | self.n_head = n_head 66 | self.query = Linear(n_state, n_state) 67 | 
self.key = Linear(n_state, n_state, bias=False) 68 | self.value = Linear(n_state, n_state) 69 | self.out = Linear(n_state, n_state) 70 | 71 | def forward( 72 | self, 73 | x: Tensor, 74 | xa: Optional[Tensor] = None, 75 | mask: Optional[Tensor] = None, 76 | kv_cache: Optional[dict[any, Tensor]] = None, 77 | ): 78 | q = self.query(x) 79 | 80 | if kv_cache is None or xa is None or self.key not in kv_cache: 81 | # hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors; 82 | # otherwise, perform key/value projections for self- or cross-attention as usual. 83 | k = self.key(x if xa is None else xa) 84 | v = self.value(x if xa is None else xa) 85 | else: 86 | # for cross-attention, calculate keys and values once and reuse in subsequent calls. 87 | k = kv_cache[self.key] 88 | v = kv_cache[self.value] 89 | 90 | wv, qk = self.qkv_attention(q, k, v, mask) 91 | 92 | return self.out(wv), qk 93 | 94 | def qkv_attention( 95 | self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None 96 | ): 97 | # print(f"q shape: {q.shape}") 98 | n_batch, n_ctx, n_state = q.shape 99 | scale = (n_state // self.n_head) ** -0.25 100 | q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale 101 | k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale 102 | v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) 103 | 104 | qk = q @ k 105 | if mask is not None: 106 | qk = qk + mask[:n_ctx, :n_ctx] 107 | qk = qk.float() 108 | 109 | w = F.softmax(qk, dim=-1).to(q.dtype) 110 | return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach() 111 | 112 | 113 | class ResidualAttentionBlock(nn.Module): 114 | def __init__(self, n_state: int, n_head: int, cross_attention: bool = False): 115 | super().__init__() 116 | 117 | self.attn = MultiHeadAttention(n_state, n_head) 118 | self.attn_ln = LayerNorm(n_state) 119 | 120 | self.cross_attn = ( 121 | MultiHeadAttention(n_state, n_head) if cross_attention else None 122 | ) 123 | self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None 124 | 125 | n_mlp = n_state * 4 126 | self.mlp = nn.Sequential( 127 | Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state) 128 | ) 129 | self.mlp_ln = LayerNorm(n_state) 130 | 131 | def forward( 132 | self, 133 | x: Tensor, 134 | xa: Optional[Tensor] = None, 135 | mask: Optional[Tensor] = None, 136 | kv_cache: Optional[dict] = None, 137 | ): 138 | # SA 139 | x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0] 140 | 141 | # CA 142 | if self.cross_attn: 143 | x = x + self.cross_attn(self.cross_attn_ln(x), xa)[0] 144 | 145 | # MLP 146 | x = x + self.mlp(self.mlp_ln(x)) 147 | 148 | return x 149 | 150 | 151 | class AudioEncoder(nn.Module): 152 | def __init__( 153 | self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int 154 | ): 155 | super().__init__() 156 | self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1) 157 | self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1) 158 | self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state)) 159 | 160 | self.n_head = n_head 161 | self.n_layer = n_layer 162 | self.n_state = n_state 163 | 164 | self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList( 165 | [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)] 166 | ) 167 | self.ln_post = LayerNorm(n_state) 168 | 169 | def forward(self, x: Tensor): 170 | """ 171 | x : torch.Tensor, shape = (batch_size, n_mels, n_ctx) 172 | the mel spectrogram of the audio 173 | """ 174 | x = 
F.gelu(self.conv1(x)) 175 | x = F.gelu(self.conv2(x)) 176 | x = x.permute(0, 2, 1) 177 | 178 | assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape" 179 | x = (x + self.positional_embedding).to(x.dtype) 180 | 181 | for block in self.blocks: 182 | x = block(x) 183 | 184 | x = self.ln_post(x) 185 | return x 186 | 187 | 188 | class TextDecoder(nn.Module): 189 | def __init__( 190 | self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int 191 | ): 192 | super().__init__() 193 | 194 | self.token_embedding = nn.Embedding(n_vocab, n_state) 195 | self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state)) 196 | 197 | self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList( 198 | [ 199 | ResidualAttentionBlock(n_state, n_head, cross_attention=True) 200 | for _ in range(n_layer) 201 | ] 202 | ) 203 | self.ln = LayerNorm(n_state) 204 | 205 | mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1) 206 | self.register_buffer("mask", mask, persistent=False) 207 | 208 | def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None): 209 | """ 210 | x : torch.LongTensor, shape = (batch_size, <= n_ctx) 211 | the text tokens 212 | xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state) 213 | the encoded audio features to be attended on 214 | dump_type: str - specifies which dump to return (MLP, pre_MLP, ATT) 215 | """ 216 | offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0 217 | x = ( 218 | self.token_embedding(x) 219 | + self.positional_embedding[offset : offset + x.shape[-1]] 220 | ) 221 | x = x.to(xa.dtype) 222 | 223 | for block in self.blocks: 224 | x = block(x, xa, mask=self.mask, kv_cache=kv_cache) 225 | 226 | x = self.ln(x) 227 | logits = ( 228 | x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1) 229 | ).float() 230 | 231 | return logits 232 | 233 | 234 | class Whisper(nn.Module): 235 | def __init__(self, dims: ModelDimensions): 236 | super().__init__() 237 | self.dims = dims 238 | self.encoder = AudioEncoder( 239 | self.dims.n_mels, 240 | self.dims.n_audio_ctx, 241 | self.dims.n_audio_state, 242 | self.dims.n_audio_head, 243 | self.dims.n_audio_layer, 244 | ) 245 | self.decoder = TextDecoder( 246 | self.dims.n_vocab, 247 | self.dims.n_text_ctx, 248 | self.dims.n_text_state, 249 | self.dims.n_text_head, 250 | self.dims.n_text_layer, 251 | ) 252 | # use the last half among the decoder layers for time alignment by default; 253 | # to use a specific set of heads, see `set_alignment_heads()` below. 
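        # (i.e., every attention head in decoder layers n_text_layer // 2 and above is marked as an alignment head)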
254 | all_heads = torch.zeros( 255 | self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool 256 | ) 257 | all_heads[self.dims.n_text_layer // 2 :] = True 258 | # self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False) 259 | self.register_buffer("alignment_heads", all_heads, persistent=False) # To use lightning can't use sparse weights 260 | 261 | def set_alignment_heads(self, dump: bytes): 262 | array = np.frombuffer( 263 | gzip.decompress(base64.b85decode(dump)), dtype=bool 264 | ).copy() 265 | mask = torch.from_numpy(array).reshape( 266 | self.dims.n_text_layer, self.dims.n_text_head 267 | ) 268 | # self.register_buffer("alignment_heads", mask.to_sparse(), persistent=False) 269 | self.register_buffer("alignment_heads", mask, persistent=False) # To use lightning can't use sparse weights 270 | 271 | def embed_audio(self, mel: torch.Tensor): 272 | return self.encoder(mel) 273 | 274 | def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor): 275 | return self.decoder(tokens, audio_features) 276 | 277 | def forward( 278 | self, mel: torch.Tensor, tokens: torch.Tensor 279 | ) -> Dict[str, torch.Tensor]: 280 | return self.decoder(tokens, self.encoder(mel)) 281 | 282 | @property 283 | def device(self): 284 | return next(self.parameters()).device 285 | 286 | @property 287 | def is_multilingual(self): 288 | return self.dims.n_vocab >= 51865 289 | 290 | @property 291 | def num_languages(self): 292 | return self.dims.n_vocab - 51765 - int(self.is_multilingual) 293 | 294 | def install_kv_cache_hooks(self, cache: Optional[dict] = None): 295 | """ 296 | The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value 297 | tensors calculated for the previous positions. This method returns a dictionary that stores 298 | all caches, and the necessary hooks for the key and value projection modules that save the 299 | intermediate tensors to be reused during later calculations. 
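        In a typical decoding loop the hooks are installed once per utterance, e.g. `cache, hooks = model.install_kv_cache_hooks()`; the decoder is then called with `kv_cache=cache` while feeding one new token at a time, so only the newest position is projected, and each handle in `hooks` is removed once decoding finishes.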
300 | 301 | Returns 302 | ------- 303 | cache : Dict[nn.Module, torch.Tensor] 304 | A dictionary object mapping the key/value projection modules to its cache 305 | hooks : List[RemovableHandle] 306 | List of PyTorch RemovableHandle objects to stop the hooks to be called 307 | """ 308 | cache = {**cache} if cache is not None else {} 309 | hooks = [] 310 | 311 | def save_to_cache(module, _, output): 312 | if module not in cache or output.shape[1] > self.dims.n_text_ctx: 313 | # save as-is, for the first token or cross attention 314 | cache[module] = output 315 | else: 316 | cache[module] = torch.cat([cache[module], output], dim=1).detach() 317 | return cache[module] 318 | 319 | def install_hooks(layer: nn.Module): 320 | if isinstance(layer, MultiHeadAttention): 321 | hooks.append(layer.key.register_forward_hook(save_to_cache)) 322 | hooks.append(layer.value.register_forward_hook(save_to_cache)) 323 | 324 | self.decoder.apply(install_hooks) 325 | return cache, hooks 326 | 327 | detect_language = detect_language_function 328 | transcribe = transcribe_function 329 | decode = decode_function 330 | -------------------------------------------------------------------------------- /careless_whisper_stream/__init__.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import io 3 | import os 4 | import urllib 5 | import warnings 6 | from typing import List, Optional, Union 7 | 8 | import torch 9 | from tqdm import tqdm 10 | 11 | from .audio import load_audio, pad_or_trim, log_mel_spectrogram 12 | from .model import ModelDimensions, Whisper 13 | from .streaming_model import StreamingWhisper 14 | from .version import __version__ 15 | 16 | _MODELS = { 17 | "tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt", 18 | "tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt", 19 | "base.en": "https://openaipublic.azureedge.net/main/whisper/models/25a8566e1d0c1e2231d1c762132cd20e0f96a85d16145c3a00adf5d1ac670ead/base.en.pt", 20 | "base": "https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt", 21 | "small.en": "https://openaipublic.azureedge.net/main/whisper/models/f953ad0fd29cacd07d5a9eda5624af0f6bcf2258be67c92b79389873d91e0872/small.en.pt", 22 | "small": "https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt", 23 | "medium.en": "https://openaipublic.azureedge.net/main/whisper/models/d7440d1dc186f76616474e0ff0b3b6b879abc9d1a4926b7adfa41db2d497ab4f/medium.en.pt", 24 | "medium": "https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt", 25 | "large-v1": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large-v1.pt", 26 | "large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt", 27 | "large-v3": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt", 28 | "large": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt", 29 | "large-v3-turbo": 
"https://openaipublic.azureedge.net/main/whisper/models/aff26ae408abcba5fbf8813c21e62b0941638c5f6eebfb145be0c9839262a19a/large-v3-turbo.pt", 30 | "turbo": "https://openaipublic.azureedge.net/main/whisper/models/aff26ae408abcba5fbf8813c21e62b0941638c5f6eebfb145be0c9839262a19a/large-v3-turbo.pt", 31 | } 32 | 33 | _STREAMING_MODELS_HF = { 34 | "base": { 35 | "300": "base_300.pt", 36 | "200": "base_200.pt", 37 | "100": "base_100.pt", 38 | "40": "base_40.pt", 39 | }, 40 | "small": { 41 | "1000": "small_1000.pt", 42 | "300": "small_300.pt", 43 | "200": "small_200.pt", 44 | "100": "small_100.pt", 45 | "40": "small_40.pt", 46 | }, 47 | "large-v2": { 48 | "1000": "large-v2_1000.pt", 49 | "300": "large-v2_300.pt", 50 | "200": "large-v2_200.pt", 51 | "100": "large-v2_100.pt", 52 | "40": "large-v2_40.pt", 53 | "300-multi": "large-v2_300_multi.pt", 54 | } 55 | } 56 | 57 | # base85-encoded (n_layers, n_heads) boolean arrays indicating the cross-attention heads that are 58 | # highly correlated to the word-level timing, i.e. the alignment between audio and text tokens. 59 | _ALIGNMENT_HEADS = { 60 | "tiny.en": b"ABzY8J1N>@0{>%R00Bk>$p{7v037`oCl~+#00", 61 | "tiny": b"ABzY8bu8Lr0{>%RKn9Fp%m@SkK7Kt=7ytkO", 62 | "base.en": b"ABzY8;40c<0{>%RzzG;p*o+Vo09|#PsxSZm00", 63 | "base": b"ABzY8KQ!870{>%RzyTQH3`Q^yNP!>##QT-?_)10{>%RpeA61k&I|OI3I$65C{;;pbCHh0B{qLQ;+}v00", 65 | "small": b"ABzY8DmU6=0{>%Rpa?J`kvJ6qF(V^F86#Xh7JUGMK}P%R7%R7}kK1fFL7w6%<-Pf*t^=N)Qr&0RR9", 68 | "large-v1": b"ABzY8r9j$a0{>%R7#4sLmoOs{s)o3~84-RPdcFk!JR%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj", 70 | "large-v3": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00", 71 | "large": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00", 72 | "large-v3-turbo": b"ABzY8j^C+e0{>%RARaKHP%t(lGR*)0g!tONPyhe`", 73 | "turbo": b"ABzY8j^C+e0{>%RARaKHP%t(lGR*)0g!tONPyhe`", 74 | } 75 | 76 | 77 | def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]: 78 | os.makedirs(root, exist_ok=True) 79 | 80 | expected_sha256 = url.split("/")[-2] 81 | download_target = os.path.join(root, os.path.basename(url)) 82 | 83 | if os.path.exists(download_target) and not os.path.isfile(download_target): 84 | raise RuntimeError(f"{download_target} exists and is not a regular file") 85 | 86 | if os.path.isfile(download_target): 87 | with open(download_target, "rb") as f: 88 | model_bytes = f.read() 89 | if hashlib.sha256(model_bytes).hexdigest() == expected_sha256: 90 | return model_bytes if in_memory else download_target 91 | else: 92 | warnings.warn( 93 | f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file" 94 | ) 95 | 96 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 97 | with tqdm( 98 | total=int(source.info().get("Content-Length")), 99 | ncols=80, 100 | unit="iB", 101 | unit_scale=True, 102 | unit_divisor=1024, 103 | ) as loop: 104 | while True: 105 | buffer = source.read(8192) 106 | if not buffer: 107 | break 108 | 109 | output.write(buffer) 110 | loop.update(len(buffer)) 111 | 112 | model_bytes = open(download_target, "rb").read() 113 | if hashlib.sha256(model_bytes).hexdigest() != expected_sha256: 114 | raise RuntimeError( 115 | "Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model." 
116 | ) 117 | 118 | return model_bytes if in_memory else download_target 119 | 120 | 121 | def available_models() -> List[str]: 122 | """Returns the names of available models""" 123 | return list(_MODELS.keys()) 124 | 125 | 126 | def load_model( 127 | name: str, 128 | device: Optional[Union[str, torch.device]] = None, 129 | download_root: str = None, 130 | in_memory: bool = False, 131 | ) -> Whisper: 132 | """ 133 | Load a Whisper ASR model 134 | 135 | Parameters 136 | ---------- 137 | name : str 138 | one of the official model names listed by `whisper.available_models()`, or 139 | path to a model checkpoint containing the model dimensions and the model state_dict. 140 | device : Union[str, torch.device] 141 | the PyTorch device to put the model into 142 | download_root: str 143 | path to download the model files; by default, it uses "~/.cache/whisper" 144 | in_memory: bool 145 | whether to preload the model weights into host memory 146 | 147 | Returns 148 | ------- 149 | model : Whisper 150 | The Whisper ASR model instance 151 | """ 152 | 153 | if device is None: 154 | device = "cuda" if torch.cuda.is_available() else "cpu" 155 | if download_root is None: 156 | default = os.path.join(os.path.expanduser("~"), ".cache") 157 | download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper") 158 | 159 | if name in _MODELS: 160 | checkpoint_file = _download(_MODELS[name], download_root, in_memory) 161 | alignment_heads = _ALIGNMENT_HEADS[name] 162 | elif os.path.isfile(name): 163 | checkpoint_file = open(name, "rb").read() if in_memory else name 164 | alignment_heads = None 165 | else: 166 | raise RuntimeError( 167 | f"Model {name} not found; available models = {available_models()}" 168 | ) 169 | 170 | with ( 171 | io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb") 172 | ) as fp: 173 | checkpoint = torch.load(fp, map_location=device) 174 | del checkpoint_file 175 | 176 | dims = ModelDimensions(**checkpoint["dims"]) 177 | model = Whisper(dims) 178 | model.load_state_dict(checkpoint["model_state_dict"]) 179 | 180 | if alignment_heads is not None: 181 | model.set_alignment_heads(alignment_heads) 182 | 183 | return model.to(device) 184 | 185 | 186 | def load_streaming_model_for_train( 187 | name: str, 188 | advisor_ckpt_path: str = None, 189 | ft_model_ckpt_path: str = None, 190 | device: Optional[Union[str, torch.device]] = None, 191 | download_root: str = None, 192 | in_memory: bool = False, 193 | cache_gran: bool = True, 194 | gran: int = 15, 195 | rank: int = 8, 196 | extra_gran_blocks: int = 0, 197 | n_advisor_class: int = 4, 198 | **kwargs: any 199 | ) -> StreamingWhisper: 200 | """ 201 | Load a StreamingWhisper ASR model 202 | 203 | Parameters 204 | ---------- 205 | name : str 206 | one of the official model names listed by `whisper.available_models()`, or 207 | path to a model checkpoint containing the model dimensions and the model state_dict. 
208 | device : Union[str, torch.device] 209 | the PyTorch device to put the model into 210 | download_root: str 211 | path to download the model files; by default, it uses "~/.cache/whisper" 212 | in_memory: bool 213 | whether to preload the model weights into host memory 214 | 215 | Returns 216 | ------- 217 | model : Whisper 218 | The Whisper ASR model instance 219 | """ 220 | if ft_model_ckpt_path is None: 221 | if device is None: 222 | device = "cuda" if torch.cuda.is_available() else "cpu" 223 | if download_root is None: 224 | default = os.path.join(os.path.expanduser("~"), ".cache") 225 | download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper") 226 | 227 | if name in _MODELS: 228 | checkpoint_file = _download(_MODELS[name], download_root, in_memory) 229 | alignment_heads = _ALIGNMENT_HEADS[name] 230 | elif os.path.isfile(name): 231 | checkpoint_file = open(name, "rb").read() if in_memory else name 232 | alignment_heads = None 233 | else: 234 | raise RuntimeError( 235 | f"Model {name} not found; available models = {available_models()}" 236 | ) 237 | 238 | with ( 239 | io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb") 240 | ) as fp: 241 | checkpoint = torch.load(fp, map_location=device) 242 | del checkpoint_file 243 | else: 244 | checkpoint = torch.load(ft_model_ckpt_path, weights_only=False) 245 | 246 | decoder_advisor_chkpt = torch.load(advisor_ckpt_path, weights_only=False) if advisor_ckpt_path is not None else {"state_dict": {}} 247 | advisor_state_dict = {k: v for k, v in decoder_advisor_chkpt["state_dict"].items() if "decoder_advisor" in k} 248 | 249 | whisper_dict = checkpoint["model_state_dict"] if "model_state_dict" in checkpoint.keys() else checkpoint["state_dict"] 250 | 251 | whisper_dict = {k.replace("weight", "base_layer.weight") if "attn." in k and "weight" in k 252 | else k.replace("bias", "base_layer.bias") if "attn." 
in k and "bias" in k 253 | else k: v for k, v in whisper_dict.items()} 254 | 255 | streaming_whisper_state_dict = {**advisor_state_dict, **whisper_dict} 256 | 257 | dims = ModelDimensions(**checkpoint["dims"]) 258 | 259 | model = StreamingWhisper(dims, 260 | cache_gran=cache_gran, 261 | gran=gran, 262 | rank=rank, 263 | extra_gran_blocks=extra_gran_blocks) 264 | 265 | model.load_state_dict(streaming_whisper_state_dict, strict=False) 266 | 267 | # for n, p in model.named_parameters(): 268 | # print(n, p) 269 | 270 | if ft_model_ckpt_path is None and alignment_heads is not None: 271 | model.set_alignment_heads(alignment_heads) 272 | 273 | return model.to(device) 274 | 275 | 276 | def load_streaming_model( 277 | name: str, 278 | gran: int = 300, 279 | multilingual: bool = False, 280 | device: Optional[Union[str, torch.device]] = None, 281 | download_root: str = None, 282 | in_memory: bool = False, 283 | ) -> StreamingWhisper: 284 | 285 | subname = (str(gran) + '-multi') if multilingual else str(gran) 286 | 287 | from huggingface_hub import hf_hub_download 288 | 289 | try: 290 | ckpt_path = hf_hub_download(repo_id="MLSpeech/CarelessWhisper-Streaming", filename=_STREAMING_MODELS_HF[name][subname], repo_type="model", token=True) 291 | except KeyError as e: 292 | print(f"Streaming model with the next configs: size {name}, multilingual: {multilingual} and chunk size: {gran} is not available.") 293 | 294 | checkpoint = torch.load(ckpt_path, weights_only=False) 295 | 296 | dims = ModelDimensions(**checkpoint["dims"]) 297 | 298 | model = StreamingWhisper(dims, 299 | gran=checkpoint['cfg']['gran'], 300 | rank=checkpoint['cfg']['rank'], 301 | extra_gran_blocks=checkpoint['cfg']['extra_gran_blocks']) 302 | 303 | model.load_state_dict(checkpoint['state_dict'], strict=False) 304 | 305 | return model.to(device) 306 | -------------------------------------------------------------------------------- /careless_whisper_stream/tokenizer.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import os 3 | import string 4 | from dataclasses import dataclass, field 5 | from functools import cached_property, lru_cache 6 | from typing import Dict, List, Optional, Tuple 7 | 8 | import tiktoken 9 | 10 | LANGUAGES = { 11 | "en": "english", 12 | "zh": "chinese", 13 | "de": "german", 14 | "es": "spanish", 15 | "ru": "russian", 16 | "ko": "korean", 17 | "fr": "french", 18 | "ja": "japanese", 19 | "pt": "portuguese", 20 | "tr": "turkish", 21 | "pl": "polish", 22 | "ca": "catalan", 23 | "nl": "dutch", 24 | "ar": "arabic", 25 | "sv": "swedish", 26 | "it": "italian", 27 | "id": "indonesian", 28 | "hi": "hindi", 29 | "fi": "finnish", 30 | "vi": "vietnamese", 31 | "he": "hebrew", 32 | "uk": "ukrainian", 33 | "el": "greek", 34 | "ms": "malay", 35 | "cs": "czech", 36 | "ro": "romanian", 37 | "da": "danish", 38 | "hu": "hungarian", 39 | "ta": "tamil", 40 | "no": "norwegian", 41 | "th": "thai", 42 | "ur": "urdu", 43 | "hr": "croatian", 44 | "bg": "bulgarian", 45 | "lt": "lithuanian", 46 | "la": "latin", 47 | "mi": "maori", 48 | "ml": "malayalam", 49 | "cy": "welsh", 50 | "sk": "slovak", 51 | "te": "telugu", 52 | "fa": "persian", 53 | "lv": "latvian", 54 | "bn": "bengali", 55 | "sr": "serbian", 56 | "az": "azerbaijani", 57 | "sl": "slovenian", 58 | "kn": "kannada", 59 | "et": "estonian", 60 | "mk": "macedonian", 61 | "br": "breton", 62 | "eu": "basque", 63 | "is": "icelandic", 64 | "hy": "armenian", 65 | "ne": "nepali", 66 | "mn": "mongolian", 67 | "bs": "bosnian", 68 | "kk": 
"kazakh", 69 | "sq": "albanian", 70 | "sw": "swahili", 71 | "gl": "galician", 72 | "mr": "marathi", 73 | "pa": "punjabi", 74 | "si": "sinhala", 75 | "km": "khmer", 76 | "sn": "shona", 77 | "yo": "yoruba", 78 | "so": "somali", 79 | "af": "afrikaans", 80 | "oc": "occitan", 81 | "ka": "georgian", 82 | "be": "belarusian", 83 | "tg": "tajik", 84 | "sd": "sindhi", 85 | "gu": "gujarati", 86 | "am": "amharic", 87 | "yi": "yiddish", 88 | "lo": "lao", 89 | "uz": "uzbek", 90 | "fo": "faroese", 91 | "ht": "haitian creole", 92 | "ps": "pashto", 93 | "tk": "turkmen", 94 | "nn": "nynorsk", 95 | "mt": "maltese", 96 | "sa": "sanskrit", 97 | "lb": "luxembourgish", 98 | "my": "myanmar", 99 | "bo": "tibetan", 100 | "tl": "tagalog", 101 | "mg": "malagasy", 102 | "as": "assamese", 103 | "tt": "tatar", 104 | "haw": "hawaiian", 105 | "ln": "lingala", 106 | "ha": "hausa", 107 | "ba": "bashkir", 108 | "jw": "javanese", 109 | "su": "sundanese", 110 | "yue": "cantonese", 111 | } 112 | 113 | # language code lookup by name, with a few language aliases 114 | TO_LANGUAGE_CODE = { 115 | **{language: code for code, language in LANGUAGES.items()}, 116 | "burmese": "my", 117 | "valencian": "ca", 118 | "flemish": "nl", 119 | "haitian": "ht", 120 | "letzeburgesch": "lb", 121 | "pushto": "ps", 122 | "panjabi": "pa", 123 | "moldavian": "ro", 124 | "moldovan": "ro", 125 | "sinhalese": "si", 126 | "castilian": "es", 127 | "mandarin": "zh", 128 | } 129 | 130 | 131 | @dataclass 132 | class Tokenizer: 133 | """A thin wrapper around `tiktoken` providing quick access to special tokens""" 134 | 135 | encoding: tiktoken.Encoding 136 | num_languages: int 137 | language: Optional[str] = None 138 | task: Optional[str] = None 139 | sot_sequence: Tuple[int] = () 140 | special_tokens: Dict[str, int] = field(default_factory=dict) 141 | 142 | def __post_init__(self): 143 | for special in self.encoding.special_tokens_set: 144 | special_token = self.encoding.encode_single_token(special) 145 | self.special_tokens[special] = special_token 146 | 147 | sot: int = self.special_tokens["<|startoftranscript|>"] 148 | translate: int = self.special_tokens["<|translate|>"] 149 | transcribe: int = self.special_tokens["<|transcribe|>"] 150 | 151 | langs = tuple(LANGUAGES.keys())[: self.num_languages] 152 | sot_sequence = [sot] 153 | if self.language is not None: 154 | sot_sequence.append(sot + 1 + langs.index(self.language)) 155 | if self.task is not None: 156 | task_token: int = transcribe if self.task == "transcribe" else translate 157 | sot_sequence.append(task_token) 158 | 159 | self.sot_sequence = tuple(sot_sequence) 160 | 161 | def encode(self, text, **kwargs): 162 | return self.encoding.encode(text, **kwargs) 163 | 164 | def decode(self, token_ids: List[int], **kwargs) -> str: 165 | token_ids = [t for t in token_ids if t < self.timestamp_begin] 166 | return self.encoding.decode(token_ids, **kwargs) 167 | 168 | def decode_with_timestamps(self, token_ids: List[int], **kwargs) -> str: 169 | """ 170 | Timestamp tokens are above other special tokens' id range and are ignored by `decode()`. 171 | This method decodes given tokens with timestamps tokens annotated, e.g. "<|1.08|>". 
172 | """ 173 | return self.encoding.decode(token_ids, **kwargs) 174 | 175 | @cached_property 176 | def eot(self) -> int: 177 | return self.encoding.eot_token 178 | 179 | @cached_property 180 | def transcribe(self) -> int: 181 | return self.special_tokens["<|transcribe|>"] 182 | 183 | @cached_property 184 | def translate(self) -> int: 185 | return self.special_tokens["<|translate|>"] 186 | 187 | @cached_property 188 | def sot(self) -> int: 189 | return self.special_tokens["<|startoftranscript|>"] 190 | 191 | @cached_property 192 | def sot_lm(self) -> int: 193 | return self.special_tokens["<|startoflm|>"] 194 | 195 | @cached_property 196 | def sot_prev(self) -> int: 197 | return self.special_tokens["<|startofprev|>"] 198 | 199 | @cached_property 200 | def no_speech(self) -> int: 201 | return self.special_tokens["<|nospeech|>"] 202 | 203 | @cached_property 204 | def no_timestamps(self) -> int: 205 | return self.special_tokens["<|notimestamps|>"] 206 | 207 | @cached_property 208 | def timestamp_begin(self) -> int: 209 | return self.special_tokens["<|0.00|>"] 210 | 211 | @cached_property 212 | def language_token(self) -> int: 213 | """Returns the token id corresponding to the value of the `language` field""" 214 | if self.language is None: 215 | raise ValueError("This tokenizer does not have language token configured") 216 | 217 | return self.to_language_token(self.language) 218 | 219 | def to_language_token(self, language): 220 | if token := self.special_tokens.get(f"<|{language}|>", None): 221 | return token 222 | 223 | raise KeyError(f"Language {language} not found in tokenizer.") 224 | 225 | @cached_property 226 | def all_language_tokens(self) -> Tuple[int]: 227 | result = [] 228 | for token, token_id in self.special_tokens.items(): 229 | if token.strip("<|>") in LANGUAGES: 230 | result.append(token_id) 231 | return tuple(result)[: self.num_languages] 232 | 233 | @cached_property 234 | def all_language_codes(self) -> Tuple[str]: 235 | return tuple(self.decode([_l]).strip("<|>") for _l in self.all_language_tokens) 236 | 237 | @cached_property 238 | def sot_sequence_including_notimestamps(self) -> Tuple[int]: 239 | return tuple(list(self.sot_sequence) + [self.no_timestamps]) 240 | 241 | @cached_property 242 | def non_speech_tokens(self) -> Tuple[int]: 243 | """ 244 | Returns the list of tokens to suppress in order to avoid any speaker tags or non-speech 245 | annotations, to prevent sampling texts that are not actually spoken in the audio, e.g. 246 | 247 | - ♪♪♪ 248 | - ( SPEAKING FOREIGN LANGUAGE ) 249 | - [DAVID] Hey there, 250 | 251 | keeping basic punctuations like commas, periods, question marks, exclamation points, etc. 252 | """ 253 | symbols = list('"#()*+/:;<=>@[\\]^_`{|}~「」『』') 254 | symbols += ( 255 | "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split() 256 | ) 257 | 258 | # symbols that may be a single token or multiple tokens depending on the tokenizer. 259 | # In case they're multiple tokens, suppress the first token, which is safe because: 260 | # These are between U+2640 and U+267F miscellaneous symbols that are okay to suppress 261 | # in generations, and in the 3-byte UTF-8 representation they share the first two bytes. 
262 | miscellaneous = set("♩♪♫♬♭♮♯") 263 | assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous) 264 | 265 | # allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word 266 | result = {self.encoding.encode(" -")[0], self.encoding.encode(" '")[0]} 267 | for symbol in symbols + list(miscellaneous): 268 | for tokens in [ 269 | self.encoding.encode(symbol), 270 | self.encoding.encode(" " + symbol), 271 | ]: 272 | if len(tokens) == 1 or symbol in miscellaneous: 273 | result.add(tokens[0]) 274 | 275 | return tuple(sorted(result)) 276 | 277 | def split_to_word_tokens(self, tokens: List[int]): 278 | if self.language in {"zh", "ja", "th", "lo", "my", "yue"}: 279 | # These languages don't typically use spaces, so it is difficult to split words 280 | # without morpheme analysis. Here, we instead split words at any 281 | # position where the tokens are decoded as valid unicode points 282 | return self.split_tokens_on_unicode(tokens) 283 | 284 | return self.split_tokens_on_spaces(tokens) 285 | 286 | def split_tokens_on_unicode(self, tokens: List[int]): 287 | decoded_full = self.decode_with_timestamps(tokens) 288 | replacement_char = "\ufffd" 289 | 290 | words = [] 291 | word_tokens = [] 292 | current_tokens = [] 293 | unicode_offset = 0 294 | 295 | for token in tokens: 296 | current_tokens.append(token) 297 | decoded = self.decode_with_timestamps(current_tokens) 298 | 299 | if ( 300 | replacement_char not in decoded 301 | or decoded_full[unicode_offset + decoded.index(replacement_char)] 302 | == replacement_char 303 | ): 304 | words.append(decoded) 305 | word_tokens.append(current_tokens) 306 | current_tokens = [] 307 | unicode_offset += len(decoded) 308 | 309 | return words, word_tokens 310 | 311 | def split_tokens_on_spaces(self, tokens: List[int]): 312 | subwords, subword_tokens_list = self.split_tokens_on_unicode(tokens) 313 | words = [] 314 | word_tokens = [] 315 | 316 | for subword, subword_tokens in zip(subwords, subword_tokens_list): 317 | special = subword_tokens[0] >= self.eot 318 | with_space = subword.startswith(" ") 319 | punctuation = subword.strip() in string.punctuation 320 | if special or with_space or punctuation or len(words) == 0: 321 | words.append(subword) 322 | word_tokens.append(subword_tokens) 323 | else: 324 | words[-1] = words[-1] + subword 325 | word_tokens[-1].extend(subword_tokens) 326 | 327 | return words, word_tokens 328 | 329 | 330 | @lru_cache(maxsize=None) 331 | def get_encoding(name: str = "gpt2", num_languages: int = 99): 332 | vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{name}.tiktoken") 333 | ranks = { 334 | base64.b64decode(token): int(rank) 335 | for token, rank in (line.split() for line in open(vocab_path) if line) 336 | } 337 | n_vocab = len(ranks) 338 | special_tokens = {} 339 | 340 | specials = [ 341 | "<|endoftext|>", 342 | "<|startoftranscript|>", 343 | *[f"<|{lang}|>" for lang in list(LANGUAGES.keys())[:num_languages]], 344 | "<|translate|>", 345 | "<|transcribe|>", 346 | "<|startoflm|>", 347 | "<|startofprev|>", 348 | "<|nospeech|>", 349 | "<|notimestamps|>", 350 | *[f"<|{i * 0.02:.2f}|>" for i in range(1501)], 351 | ] 352 | 353 | for token in specials: 354 | special_tokens[token] = n_vocab 355 | n_vocab += 1 356 | 357 | return tiktoken.Encoding( 358 | name=os.path.basename(vocab_path), 359 | explicit_n_vocab=n_vocab, 360 | pat_str=r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""", 361 | mergeable_ranks=ranks, 362 | special_tokens=special_tokens, 363 | ) 364 | 
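# Example (illustrative): a multilingual tokenizer for English transcription can be built with
#     tokenizer = get_tokenizer(True, language="en", task="transcribe")
# after which tokenizer.sot_sequence holds the <|startoftranscript|>, <|en|> and <|transcribe|>
# token ids, and tokenizer.decode(tokenizer.encode("hello world")) round-trips the plain text.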
365 | 366 | @lru_cache(maxsize=None) 367 | def get_tokenizer( 368 | multilingual: bool, 369 | *, 370 | num_languages: int = 99, 371 | language: Optional[str] = None, 372 | task: Optional[str] = None, # Literal["transcribe", "translate", None] 373 | ) -> Tokenizer: 374 | if language is not None: 375 | language = language.lower() 376 | if language not in LANGUAGES: 377 | if language in TO_LANGUAGE_CODE: 378 | language = TO_LANGUAGE_CODE[language] 379 | else: 380 | raise ValueError(f"Unsupported language: {language}") 381 | 382 | if multilingual: 383 | encoding_name = "multilingual" 384 | language = language or "en" 385 | task = task or "transcribe" 386 | else: 387 | encoding_name = "gpt2" 388 | language = None 389 | task = None 390 | 391 | encoding = get_encoding(name=encoding_name, num_languages=num_languages) 392 | 393 | return Tokenizer( 394 | encoding=encoding, num_languages=num_languages, language=language, task=task 395 | ) 396 | -------------------------------------------------------------------------------- /careless_whisper_stream/audio.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | from functools import lru_cache 4 | from subprocess import CalledProcessError, run 5 | from typing import Optional, Union 6 | 7 | import wave 8 | import torch 9 | import pyaudio 10 | import numpy as np 11 | import soundfile as sf 12 | import torch.nn.functional as F 13 | 14 | from .utils import exact_div 15 | 16 | # hard-coded audio hyperparameters 17 | SAMPLE_RATE = 16000 18 | N_FFT = 400 19 | HOP_LENGTH = 160 20 | CHUNK_LENGTH = 30 21 | N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000 samples in a 30-second chunk 22 | N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH) # 3000 frames in a mel spectrogram input 23 | 24 | N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2 # the initial convolutions has stride 2 25 | FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH) # 10ms per audio frame 26 | TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN) # 20ms per audio token 27 | 28 | 29 | class MyStream: 30 | def __init__(self, 31 | ms_gran: int = 200, 32 | sample_rate: int = 16000, 33 | channels: int = 2, 34 | filename: str = None, 35 | inp_dtype: any = pyaudio.paInt16, 36 | simulate_stream: bool = False, 37 | wav_file: str = None, 38 | relay: bool = False, 39 | use_latency: bool = False, 40 | pad_trim: bool = True, 41 | use_remote_machine: bool = False): 42 | 43 | assert ms_gran % 20 == 0, "ms_gran must be a multiple of 20" 44 | 45 | self.ms_gran = ms_gran 46 | self.sample_rate = sample_rate 47 | self.channels = channels 48 | self.inp_dtype = inp_dtype 49 | self.relay = relay 50 | self.use_latency = use_latency 51 | self.use_remote_machine = use_remote_machine 52 | 53 | rate_fraction = ms_gran / 1000 54 | self.chunk_size = int(rate_fraction * sample_rate) 55 | self.filename = filename 56 | self.streamed_wav_file = wav_file 57 | 58 | self.simulate_stream = simulate_stream 59 | if self.simulate_stream: 60 | assert wav_file is not None, "when simulating stream a wav file must be provided." 
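            # Pre-load the whole waveform for the simulated stream; it is padded or trimmed below
            # (keeping a small 180-sample tail beyond the target length) so that fixed-size chunks
            # can later be sliced out in _simulate_stream_using_wav().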
61 | if pad_trim: 62 | self.wav_array = pad_or_trim(load_audio(wav_file, sample_rate), length=N_SAMPLES+180) # wav array 63 | else: 64 | audio = load_audio(wav_file, sample_rate) 65 | self.wav_array = pad_or_trim(audio, length=audio.shape[-1]+180) 66 | print(f"{self.wav_array.shape=}") 67 | 68 | def _simulate_stream_using_wav(self): 69 | print("Streaming simulation of a wav started...") 70 | 71 | for i in range(self.wav_array.shape[-1] // self.chunk_size): 72 | if i == 0: 73 | yield self.wav_array[..., :(((i + 1) * self.chunk_size) + 40 + 320)] # 320 is extra 20 msec buffer we need! 74 | else: 75 | yield self.wav_array[..., ((i * self.chunk_size) + 40 + 320):(((i + 1) * self.chunk_size) + 40 + 320)] 76 | 77 | if self.use_latency: time.sleep(self.ms_gran / 1000) # simulating the latency between audio chunks 78 | 79 | def open_stream(self): 80 | if self.simulate_stream or self.relay or self.use_remote_machine: return 81 | 82 | self.audio = pyaudio.PyAudio() 83 | self.stream = self.audio.open(input=True, format=self.inp_dtype, channels=self.channels, rate=self.sample_rate, frames_per_buffer=self.chunk_size) 84 | 85 | def _read_from_stream(self): 86 | print("Streaming instance recording started...") 87 | 88 | while True: 89 | yield self.stream.read(self.chunk_size) 90 | 91 | def _follow_growing_wav(self): 92 | while not os.path.exists(self.streamed_wav_file): 93 | time.sleep(0.1) 94 | 95 | with sf.SoundFile(self.streamed_wav_file, mode='r') as f: 96 | while True: 97 | block = f.read(self.chunk_size) 98 | if len(block) == 0: 99 | time.sleep(self.ms_gran / 1000) # Wait for more data 100 | continue 101 | yield block 102 | 103 | def _read_raw_pcm(self): 104 | samples_per_chunk = int(self.sample_rate * (self.ms_gran / 1000)) 105 | bytes_per_sample = 2 # s16le = 16 bits = 2 bytes 106 | chunk_size = samples_per_chunk * bytes_per_sample 107 | 108 | while not os.path.exists(self.streamed_wav_file): 109 | time.sleep(0.1) 110 | 111 | with open(self.streamed_wav_file, 'rb') as f: 112 | while True: 113 | chunk = f.read(chunk_size) 114 | if not chunk: 115 | time.sleep((self.ms_gran / 1000)) 116 | continue 117 | yield np.frombuffer(chunk, dtype=np.int16).astype(np.float32) / 32768.0 118 | 119 | def read(self): 120 | if self.simulate_stream: 121 | return self._simulate_stream_using_wav() 122 | 123 | if self.use_remote_machine: 124 | return self._read_raw_pcm() 125 | 126 | return self._read_from_stream() 127 | 128 | def _save_recording_file(self, frames: list): 129 | print(f"Saving recorded audio file on path {self.filename}") 130 | 131 | waveFile = wave.open(self.filename, 'wb') 132 | waveFile.setnchannels(self.channels) 133 | waveFile.setsampwidth(self.audio.get_sample_size(self.inp_dtype)) 134 | waveFile.setframerate(self.sample_rate) 135 | waveFile.writeframes(b''.join(frames)) 136 | waveFile.close() 137 | 138 | def close_stream(self, frames: list): 139 | if self.simulate_stream: return 140 | 141 | # Stop Recording 142 | self.stream.stop_stream() 143 | self.stream.close() 144 | self.audio.terminate() 145 | 146 | print("Finished recording, stream and audio terminated.") 147 | 148 | if self.filename: self._save_recording_file(frames) 149 | 150 | 151 | def load_audio(file: str, sr: int = SAMPLE_RATE): 152 | """ 153 | Open an audio file and read as mono waveform, resampling as necessary 154 | 155 | Parameters 156 | ---------- 157 | file: str 158 | The audio file to open 159 | 160 | sr: int 161 | The sample rate to resample the audio if necessary 162 | 163 | Returns 164 | ------- 165 | A NumPy array containing 
the audio waveform, in float32 dtype. 166 | """ 167 | 168 | # This launches a subprocess to decode audio while down-mixing 169 | # and resampling as necessary. Requires the ffmpeg CLI in PATH. 170 | # fmt: off 171 | cmd = [ 172 | "ffmpeg", 173 | "-nostdin", 174 | "-threads", "0", 175 | "-i", file, 176 | "-f", "s16le", 177 | "-ac", "1", 178 | "-acodec", "pcm_s16le", 179 | "-ar", str(sr), 180 | "-" 181 | ] 182 | # fmt: on 183 | try: 184 | out = run(cmd, capture_output=True, check=True).stdout 185 | except CalledProcessError as e: 186 | raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e 187 | 188 | return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0 189 | 190 | 191 | def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1): 192 | """ 193 | Pad or trim the audio array to N_SAMPLES, as expected by the encoder. 194 | """ 195 | if torch.is_tensor(array): 196 | if array.shape[axis] > length: 197 | array = array.index_select( 198 | dim=axis, index=torch.arange(length, device=array.device) 199 | ) 200 | 201 | if array.shape[axis] < length: 202 | pad_widths = [(0, 0)] * array.ndim 203 | pad_widths[axis] = (0, length - array.shape[axis]) 204 | array = F.pad(array, [pad for sizes in pad_widths[::-1] for pad in sizes]) 205 | else: 206 | if array.shape[axis] > length: 207 | array = array.take(indices=range(length), axis=axis) 208 | 209 | if array.shape[axis] < length: 210 | pad_widths = [(0, 0)] * array.ndim 211 | pad_widths[axis] = (0, length - array.shape[axis]) 212 | array = np.pad(array, pad_widths) 213 | 214 | return array 215 | 216 | 217 | @lru_cache(maxsize=None) 218 | def mel_filters(device, n_mels: int) -> torch.Tensor: 219 | """ 220 | load the mel filterbank matrix for projecting STFT into a Mel spectrogram. 
221 | Allows decoupling librosa dependency; saved using: 222 | 223 | np.savez_compressed( 224 | "mel_filters.npz", 225 | mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80), 226 | mel_128=librosa.filters.mel(sr=16000, n_fft=400, n_mels=128), 227 | ) 228 | """ 229 | assert n_mels in {80, 128}, f"Unsupported n_mels: {n_mels}" 230 | 231 | filters_path = os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz") 232 | with np.load(filters_path, allow_pickle=False) as f: 233 | return torch.from_numpy(f[f"mel_{n_mels}"]).to(device) 234 | 235 | 236 | def log_mel_spectrogram( 237 | audio: Union[str, np.ndarray, torch.Tensor], 238 | n_mels: int = 80, 239 | padding: int = 0, 240 | device: Optional[Union[str, torch.device]] = None, 241 | ): 242 | """ 243 | Compute the log-Mel spectrogram of 244 | 245 | Parameters 246 | ---------- 247 | audio: Union[str, np.ndarray, torch.Tensor], shape = (*) 248 | The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz 249 | 250 | n_mels: int 251 | The number of Mel-frequency filters, only 80 is supported 252 | 253 | padding: int 254 | Number of zero samples to pad to the right 255 | 256 | device: Optional[Union[str, torch.device]] 257 | If given, the audio tensor is moved to this device before STFT 258 | 259 | Returns 260 | ------- 261 | torch.Tensor, shape = (80, n_frames) 262 | A Tensor that contains the Mel spectrogram 263 | """ 264 | if not torch.is_tensor(audio): 265 | if isinstance(audio, str): 266 | audio = load_audio(audio) 267 | audio = torch.from_numpy(audio) 268 | 269 | if device is not None: 270 | audio = audio.to(device) 271 | if padding > 0: 272 | audio = F.pad(audio, (0, padding)) 273 | 274 | window = torch.hann_window(N_FFT).to(audio.device) 275 | stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True) 276 | magnitudes = stft[..., :-1].abs() ** 2 277 | 278 | filters = mel_filters(audio.device, n_mels) 279 | mel_spec = filters @ magnitudes 280 | 281 | log_spec = torch.clamp(mel_spec, min=1e-10).log10() 282 | log_spec = torch.maximum(log_spec, log_spec.max() - 8.0) 283 | log_spec = (log_spec + 4.0) / 4.0 284 | return log_spec 285 | 286 | 287 | class SpectrogramStream: 288 | def __init__(self, n_fft: int = N_FFT, hop_length: int = HOP_LENGTH, n_mels: int = 80, window: Optional[str] = "hann", pad_mode: str = "reflect"): 289 | 290 | self.n_fft = n_fft 291 | self.hop_length = hop_length 292 | self.pad_mode = pad_mode 293 | self.n_mels = n_mels 294 | 295 | self.window = torch.hann_window(n_fft) 296 | self.window_type = window 297 | 298 | self.ctx_samples = self.n_fft - self.hop_length 299 | 300 | self.reset() 301 | 302 | def reset(self): 303 | self.is_first = True 304 | self.audio_ctx = torch.tensor([]) 305 | self.log_spec_max = -torch.inf 306 | 307 | def calc_mel_with_new_frame(self, audio_frame: torch.Tensor, is_last: bool = False): 308 | 309 | self.window = self.window.to(audio_frame.device) 310 | 311 | if len(audio_frame.shape) == 1: 312 | audio_frame = audio_frame.unsqueeze(0) 313 | 314 | n_batch = audio_frame.shape[0] 315 | 316 | if isinstance(self.log_spec_max, float): 317 | self.log_spec_max = torch.ones((n_batch)).to(audio_frame.device) * -torch.inf 318 | 319 | # check if we are on first frame, if so, pad using reflection 320 | if self.is_first: 321 | pad = int(self.n_fft // 2) + 1 322 | audio_input = F.pad(audio_frame, [pad, 0], self.pad_mode) 323 | self.is_first = False 324 | else: # pad with previous context 325 | audio_input = torch.cat([self.audio_ctx[..., 
-self.ctx_samples:], audio_frame], dim=-1) 326 | 327 | if is_last: # pad reflect last frame 328 | pad = int(self.n_fft // 4) + 1 329 | audio_input = F.pad(audio_input, [pad, 0], self.pad_mode) 330 | 331 | self.audio_ctx = audio_frame # now audio ctx is the last frame 332 | 333 | stft = torch.stft(audio_input, self.n_fft, self.hop_length, window=self.window, return_complex=True, center=False) 334 | magnitudes = stft.abs() ** 2 335 | filters = mel_filters(audio_frame.device, self.n_mels) 336 | mel_spec = filters @ magnitudes 337 | 338 | log_spec = torch.clamp(mel_spec, min=1e-10).log10() # from shape (b, n_mels, audio_frames) 339 | self.log_spec_max = torch.maximum(log_spec.view(n_batch, -1).max(dim=-1).values, self.log_spec_max).to(log_spec.device) 340 | 341 | log_spec = torch.maximum(log_spec.view(n_batch, -1).permute(1, 0), self.log_spec_max - 8.0).permute(1, 0).view(n_batch, self.n_mels, -1) 342 | log_spec = (log_spec + 4.0) / 4.0 343 | return log_spec 344 | 345 | def _simulate_streaming_log_spec(self, audio: torch.Tensor, ms_gran: int = 300, total_frames: int = 3000, get_gt: bool = False): 346 | self.reset() 347 | 348 | samples_gran = HOP_LENGTH * (ms_gran // 10) 349 | sub_mel_frames = int(total_frames / ms_gran) * 10 350 | # print(samples_gran, sub_mel_frames) 351 | pred_mel = torch.cat([self.calc_mel_with_new_frame(audio[..., (i * samples_gran) + (40 * int(i != 0)): ((i + 1) * samples_gran) + 40], is_last=(i == sub_mel_frames - 1)) for i in range(sub_mel_frames)], dim=-1) 352 | 353 | if get_gt: 354 | gt_mel = log_mel_spectrogram(audio) 355 | return pred_mel, gt_mel 356 | 357 | return pred_mel 358 | -------------------------------------------------------------------------------- /careless_whisper_stream/timing.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import subprocess 3 | import warnings 4 | from dataclasses import dataclass 5 | from typing import TYPE_CHECKING, List 6 | 7 | import numba 8 | import numpy as np 9 | import torch 10 | import torch.nn.functional as F 11 | 12 | from .audio import HOP_LENGTH, SAMPLE_RATE, TOKENS_PER_SECOND 13 | from .tokenizer import Tokenizer 14 | 15 | if TYPE_CHECKING: 16 | from .model import Whisper 17 | 18 | 19 | def median_filter(x: torch.Tensor, filter_width: int): 20 | """Apply a median filter of width `filter_width` along the last dimension of `x`""" 21 | pad_width = filter_width // 2 22 | if x.shape[-1] <= pad_width: 23 | # F.pad requires the padding width to be smaller than the input dimension 24 | return x 25 | 26 | if (ndim := x.ndim) <= 2: 27 | # `F.pad` does not support 1D or 2D inputs for reflect padding but supports 3D and 4D 28 | x = x[None, None, :] 29 | 30 | assert ( 31 | filter_width > 0 and filter_width % 2 == 1 32 | ), "`filter_width` should be an odd number" 33 | 34 | result = None 35 | x = F.pad(x, (filter_width // 2, filter_width // 2, 0, 0), mode="reflect") 36 | if x.is_cuda: 37 | try: 38 | from .triton_ops import median_filter_cuda 39 | 40 | result = median_filter_cuda(x, filter_width) 41 | except (RuntimeError, subprocess.CalledProcessError): 42 | warnings.warn( 43 | "Failed to launch Triton kernels, likely due to missing CUDA toolkit; " 44 | "falling back to a slower median kernel implementation..." 
45 | ) 46 | 47 | if result is None: 48 | # sort() is faster than torch.median (https://github.com/pytorch/pytorch/issues/51450) 49 | result = x.unfold(-1, filter_width, 1).sort()[0][..., filter_width // 2] 50 | 51 | if ndim <= 2: 52 | result = result[0, 0] 53 | 54 | return result 55 | 56 | 57 | @numba.jit(nopython=True) 58 | def backtrace(trace: np.ndarray): 59 | i = trace.shape[0] - 1 60 | j = trace.shape[1] - 1 61 | trace[0, :] = 2 62 | trace[:, 0] = 1 63 | 64 | result = [] 65 | while i > 0 or j > 0: 66 | result.append((i - 1, j - 1)) 67 | 68 | if trace[i, j] == 0: 69 | i -= 1 70 | j -= 1 71 | elif trace[i, j] == 1: 72 | i -= 1 73 | elif trace[i, j] == 2: 74 | j -= 1 75 | else: 76 | raise ValueError("Unexpected trace[i, j]") 77 | 78 | result = np.array(result) 79 | return result[::-1, :].T 80 | 81 | 82 | @numba.jit(nopython=True, parallel=True) 83 | def dtw_cpu(x: np.ndarray): 84 | N, M = x.shape 85 | cost = np.ones((N + 1, M + 1), dtype=np.float32) * np.inf 86 | trace = -np.ones((N + 1, M + 1), dtype=np.float32) 87 | 88 | cost[0, 0] = 0 89 | for j in range(1, M + 1): 90 | for i in range(1, N + 1): 91 | c0 = cost[i - 1, j - 1] 92 | c1 = cost[i - 1, j] 93 | c2 = cost[i, j - 1] 94 | 95 | if c0 < c1 and c0 < c2: 96 | c, t = c0, 0 97 | elif c1 < c0 and c1 < c2: 98 | c, t = c1, 1 99 | else: 100 | c, t = c2, 2 101 | 102 | cost[i, j] = x[i - 1, j - 1] + c 103 | trace[i, j] = t 104 | 105 | return backtrace(trace) 106 | 107 | 108 | def dtw_cuda(x, BLOCK_SIZE=1024): 109 | from .triton_ops import dtw_kernel 110 | 111 | M, N = x.shape 112 | assert M < BLOCK_SIZE, f"M should be smaller than {BLOCK_SIZE=}" 113 | 114 | x_skew = ( 115 | F.pad(x, (0, M + 1), value=np.inf).flatten()[: M * (N + M)].reshape(M, N + M) 116 | ) 117 | x_skew = x_skew.T.contiguous() 118 | cost = torch.ones(N + M + 2, M + 2) * np.inf 119 | cost[0, 0] = 0 120 | cost = cost.cuda() 121 | trace = torch.zeros_like(cost, dtype=torch.int32) 122 | 123 | dtw_kernel[(1,)]( 124 | cost, 125 | trace, 126 | x_skew, 127 | x_skew.stride(0), 128 | cost.stride(0), 129 | trace.stride(0), 130 | N, 131 | M, 132 | BLOCK_SIZE=BLOCK_SIZE, 133 | ) 134 | 135 | trace = trace.T.flatten()[: (M + 1) * (M + N + 3)].reshape(M + 1, M + N + 3)[ 136 | :, : N + 1 137 | ] 138 | return backtrace(trace.cpu().numpy()) 139 | 140 | 141 | def dtw(x: torch.Tensor) -> np.ndarray: 142 | if x.is_cuda: 143 | try: 144 | return dtw_cuda(x) 145 | except (RuntimeError, subprocess.CalledProcessError): 146 | warnings.warn( 147 | "Failed to launch Triton kernels, likely due to missing CUDA toolkit; " 148 | "falling back to a slower DTW implementation..." 
149 | ) 150 | 151 | return dtw_cpu(x.double().cpu().numpy()) 152 | 153 | 154 | @dataclass 155 | class WordTiming: 156 | word: str 157 | tokens: List[int] 158 | start: float 159 | end: float 160 | probability: float 161 | 162 | 163 | def find_alignment( 164 | model: "Whisper", 165 | tokenizer: Tokenizer, 166 | text_tokens: List[int], 167 | mel: torch.Tensor, 168 | num_frames: int, 169 | *, 170 | medfilt_width: int = 7, 171 | qk_scale: float = 1.0, 172 | ) -> List[WordTiming]: 173 | if len(text_tokens) == 0: 174 | return [] 175 | 176 | tokens = torch.tensor( 177 | [ 178 | *tokenizer.sot_sequence, 179 | tokenizer.no_timestamps, 180 | *text_tokens, 181 | tokenizer.eot, 182 | ] 183 | ).to(model.device) 184 | 185 | # install hooks on the cross attention layers to retrieve the attention weights 186 | QKs = [None] * model.dims.n_text_layer 187 | hooks = [ 188 | block.cross_attn.register_forward_hook( 189 | lambda _, ins, outs, index=i: QKs.__setitem__(index, outs[-1][0]) 190 | ) 191 | for i, block in enumerate(model.decoder.blocks) 192 | ] 193 | 194 | with torch.no_grad(): 195 | logits = model(mel.unsqueeze(0), tokens.unsqueeze(0))[0] 196 | sampled_logits = logits[len(tokenizer.sot_sequence) :, : tokenizer.eot] 197 | token_probs = sampled_logits.softmax(dim=-1) 198 | text_token_probs = token_probs[np.arange(len(text_tokens)), text_tokens] 199 | text_token_probs = text_token_probs.tolist() 200 | 201 | for hook in hooks: 202 | hook.remove() 203 | 204 | # heads * tokens * frames 205 | weights = torch.stack([QKs[_l][_h] for _l, _h in model.alignment_heads.indices().T]) 206 | weights = weights[:, :, : num_frames // 2] 207 | weights = (weights * qk_scale).softmax(dim=-1) 208 | std, mean = torch.std_mean(weights, dim=-2, keepdim=True, unbiased=False) 209 | weights = (weights - mean) / std 210 | weights = median_filter(weights, medfilt_width) 211 | 212 | matrix = weights.mean(axis=0) 213 | matrix = matrix[len(tokenizer.sot_sequence) : -1] 214 | text_indices, time_indices = dtw(-matrix) 215 | 216 | words, word_tokens = tokenizer.split_to_word_tokens(text_tokens + [tokenizer.eot]) 217 | if len(word_tokens) <= 1: 218 | # return on eot only 219 | # >>> np.pad([], (1, 0)) 220 | # array([0.]) 221 | # This results in crashes when we lookup jump_times with float, like 222 | # IndexError: arrays used as indices must be of integer (or boolean) type 223 | return [] 224 | word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens[:-1]]), (1, 0)) 225 | 226 | jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool) 227 | jump_times = time_indices[jumps] / TOKENS_PER_SECOND 228 | start_times = jump_times[word_boundaries[:-1]] 229 | end_times = jump_times[word_boundaries[1:]] 230 | word_probabilities = [ 231 | np.mean(text_token_probs[i:j]) 232 | for i, j in zip(word_boundaries[:-1], word_boundaries[1:]) 233 | ] 234 | 235 | return [ 236 | WordTiming(word, tokens, start, end, probability) 237 | for word, tokens, start, end, probability in zip( 238 | words, word_tokens, start_times, end_times, word_probabilities 239 | ) 240 | ] 241 | 242 | 243 | def merge_punctuations(alignment: List[WordTiming], prepended: str, appended: str): 244 | # merge prepended punctuations 245 | i = len(alignment) - 2 246 | j = len(alignment) - 1 247 | while i >= 0: 248 | previous = alignment[i] 249 | following = alignment[j] 250 | if previous.word.startswith(" ") and previous.word.strip() in prepended: 251 | # prepend it to the following word 252 | following.word = previous.word + following.word 253 | following.tokens = 
previous.tokens + following.tokens 254 | previous.word = "" 255 | previous.tokens = [] 256 | else: 257 | j = i 258 | i -= 1 259 | 260 | # merge appended punctuations 261 | i = 0 262 | j = 1 263 | while j < len(alignment): 264 | previous = alignment[i] 265 | following = alignment[j] 266 | if not previous.word.endswith(" ") and following.word in appended: 267 | # append it to the previous word 268 | previous.word = previous.word + following.word 269 | previous.tokens = previous.tokens + following.tokens 270 | following.word = "" 271 | following.tokens = [] 272 | else: 273 | i = j 274 | j += 1 275 | 276 | 277 | def add_word_timestamps( 278 | *, 279 | segments: List[dict], 280 | model: "Whisper", 281 | tokenizer: Tokenizer, 282 | mel: torch.Tensor, 283 | num_frames: int, 284 | prepend_punctuations: str = "\"'“¿([{-", 285 | append_punctuations: str = "\"'.。,,!!??::”)]}、", 286 | last_speech_timestamp: float, 287 | **kwargs, 288 | ): 289 | if len(segments) == 0: 290 | return 291 | 292 | text_tokens_per_segment = [ 293 | [token for token in segment["tokens"] if token < tokenizer.eot] 294 | for segment in segments 295 | ] 296 | 297 | text_tokens = list(itertools.chain.from_iterable(text_tokens_per_segment)) 298 | alignment = find_alignment(model, tokenizer, text_tokens, mel, num_frames, **kwargs) 299 | word_durations = np.array([t.end - t.start for t in alignment]) 300 | word_durations = word_durations[word_durations.nonzero()] 301 | median_duration = np.median(word_durations) if len(word_durations) > 0 else 0.0 302 | median_duration = min(0.7, float(median_duration)) 303 | max_duration = median_duration * 2 304 | 305 | # hack: truncate long words at sentence boundaries. 306 | # a better segmentation algorithm based on VAD should be able to replace this. 307 | if len(word_durations) > 0: 308 | sentence_end_marks = ".。!!??" 309 | # ensure words at sentence boundaries are not longer than twice the median word duration. 310 | for i in range(1, len(alignment)): 311 | if alignment[i].end - alignment[i].start > max_duration: 312 | if alignment[i].word in sentence_end_marks: 313 | alignment[i].end = alignment[i].start + max_duration 314 | elif alignment[i - 1].word in sentence_end_marks: 315 | alignment[i].start = alignment[i].end - max_duration 316 | 317 | merge_punctuations(alignment, prepend_punctuations, append_punctuations) 318 | 319 | time_offset = segments[0]["seek"] * HOP_LENGTH / SAMPLE_RATE 320 | word_index = 0 321 | 322 | for segment, text_tokens in zip(segments, text_tokens_per_segment): 323 | saved_tokens = 0 324 | words = [] 325 | 326 | while word_index < len(alignment) and saved_tokens < len(text_tokens): 327 | timing = alignment[word_index] 328 | 329 | if timing.word: 330 | words.append( 331 | dict( 332 | word=timing.word, 333 | start=round(time_offset + timing.start, 2), 334 | end=round(time_offset + timing.end, 2), 335 | probability=timing.probability, 336 | ) 337 | ) 338 | 339 | saved_tokens += len(timing.tokens) 340 | word_index += 1 341 | 342 | # hack: truncate long words at segment boundaries. 343 | # a better segmentation algorithm based on VAD should be able to replace this. 344 | if len(words) > 0: 345 | # ensure the first and second word after a pause is not longer than 346 | # twice the median word duration. 
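        # Illustrative example of this heuristic (numbers are hypothetical, not taken from the
        # code or data): with median_duration = 0.5 s we get max_duration = 1.0 s, so a first
        # word whose end falls more than 2.0 s after last_speech_timestamp and which either
        # lasts longer than 1.0 s itself, or spans more than 2.0 s together with the second
        # word, is treated as a misalignment and has its start/end clamped by the branch below.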
347 | if words[0]["end"] - last_speech_timestamp > median_duration * 4 and ( 348 | words[0]["end"] - words[0]["start"] > max_duration 349 | or ( 350 | len(words) > 1 351 | and words[1]["end"] - words[0]["start"] > max_duration * 2 352 | ) 353 | ): 354 | if ( 355 | len(words) > 1 356 | and words[1]["end"] - words[1]["start"] > max_duration 357 | ): 358 | boundary = max(words[1]["end"] / 2, words[1]["end"] - max_duration) 359 | words[0]["end"] = words[1]["start"] = boundary 360 | words[0]["start"] = max(0, words[0]["end"] - max_duration) 361 | 362 | # prefer the segment-level start timestamp if the first word is too long. 363 | if ( 364 | segment["start"] < words[0]["end"] 365 | and segment["start"] - 0.5 > words[0]["start"] 366 | ): 367 | words[0]["start"] = max( 368 | 0, min(words[0]["end"] - median_duration, segment["start"]) 369 | ) 370 | else: 371 | segment["start"] = words[0]["start"] 372 | 373 | # prefer the segment-level end timestamp if the last word is too long. 374 | if ( 375 | segment["end"] > words[-1]["start"] 376 | and segment["end"] + 0.5 < words[-1]["end"] 377 | ): 378 | words[-1]["end"] = max( 379 | words[-1]["start"] + median_duration, segment["end"] 380 | ) 381 | else: 382 | segment["end"] = words[-1]["end"] 383 | 384 | last_speech_timestamp = segment["end"] 385 | 386 | segment["words"] = words 387 | -------------------------------------------------------------------------------- /training_code/whisper_module.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append("./") 3 | import torch 4 | import random 5 | import evaluate 6 | import careless_whisper_stream 7 | import careless_whisper_stream.tokenizer as whisper_tokenizer 8 | 9 | from torch import nn, Tensor 10 | from torch.optim.adamw import AdamW 11 | from training_code.utils import Config 12 | from careless_whisper_stream import StreamingWhisper 13 | from careless_whisper_stream.audio import HOP_LENGTH 14 | from pytorch_lightning import LightningModule 15 | from torch.optim.lr_scheduler import LinearLR, ReduceLROnPlateau 16 | from training_code.datasets_classes import TIMIT, WAVsDataset, AlignedTextGridDataset 17 | from training_code.collators import WhisperDataCollatorWithPadding, LoRAWhisperDataCollatorWithPadding 18 | 19 | class WhisperCustomModel(LightningModule): 20 | def __init__(self, cfg:Config, model_name="tiny", lang="en", train_dataset: str = None, eval_dataset: str = None, task="transcribe") -> None: 21 | super().__init__() 22 | self.save_hyperparameters() 23 | self.task = task 24 | self.lang = lang 25 | self.model = careless_whisper_stream.load_model(model_name) 26 | self.tokenizer: whisper_tokenizer = whisper_tokenizer.get_tokenizer(True, language=lang) 27 | 28 | self.loss_fn = nn.CrossEntropyLoss(ignore_index=-100) 29 | self.metrics_wer = evaluate.load("wer") 30 | self.metrics_cer = evaluate.load("cer") 31 | 32 | self.params = self.model 33 | 34 | self.cfg = cfg 35 | self.__train_dataset = train_dataset 36 | self.__eval_dataset = eval_dataset 37 | 38 | def forward(self, x): 39 | return self.model(x) 40 | 41 | def calc_wer_val(self, out: Tensor, labels: Tensor): 42 | out[out == -100] = self.tokenizer.eot 43 | labels[labels == -100] = self.tokenizer.eot 44 | 45 | o_list, l_list = [], [] 46 | for o, l in zip(out, labels): 47 | o = torch.argmax(o, dim=1) 48 | o_list.append(self.tokenizer.decode(o)) 49 | l_list.append(self.tokenizer.decode(l)) 50 | 51 | wer = self.metrics_wer.compute(references=l_list, predictions=o_list) 52 | return wer 53 | 54 
| def training_step(self, batch, batch_id): 55 | input_ids = batch["input_ids"] 56 | labels = batch["labels"].long() 57 | dec_input_ids = batch["dec_input_ids"].long() 58 | 59 | with torch.no_grad(): 60 | audio_features = self.model.encoder(input_ids) 61 | 62 | out, _, _ = self.model.decoder(dec_input_ids, audio_features) 63 | 64 | loss = self.loss_fn(out.view(-1, out.size(-1)), labels.view(-1)) 65 | self.log("train/loss", loss, on_step=True, prog_bar=True, logger=True) 66 | return loss 67 | 68 | def validation_step(self, batch, batch_id): 69 | input_ids = batch["input_ids"] 70 | labels = batch["labels"].long() 71 | dec_input_ids = batch["dec_input_ids"].long() 72 | 73 | audio_features = self.model.encoder(input_ids) 74 | out, _, _ = self.model.decoder(dec_input_ids, audio_features) 75 | 76 | loss = self.loss_fn(out.view(-1, out.size(-1)), labels.view(-1)) 77 | wer = self.calc_wer_val(out, labels) 78 | 79 | self.log("val/loss", loss, on_step=True, prog_bar=True, logger=True, on_epoch=True) 80 | self.log("val/wer", wer, on_step=True, prog_bar=True, logger=True, on_epoch=True) 81 | 82 | return { 83 | "wer": wer, 84 | "loss": loss 85 | } 86 | 87 | def configure_optimizers(self): 88 | model = self.params 89 | no_decay = ["bias", "LayerNorm.weight"] 90 | optimizer_grouped_parameters = [ 91 | { 92 | "params": [p for n, p in model.named_parameters() 93 | if not any(nd in n for nd in no_decay)], 94 | "weight_decay": self.cfg.weight_decay, 95 | }, 96 | { 97 | "params": [p for n, p in model.named_parameters() 98 | if any(nd in n for nd in no_decay)], 99 | "weight_decay": 0.0, 100 | }, 101 | ] 102 | optimizer = AdamW(optimizer_grouped_parameters, 103 | lr=self.cfg.learning_rate, 104 | eps=self.cfg.adam_epsilon) 105 | self.optimizer = optimizer 106 | 107 | scheduler = LinearLR( 108 | self.optimizer, start_factor=0.5, end_factor=0.8, total_iters=self.t_total 109 | ) 110 | 111 | self.scheduler = scheduler 112 | 113 | return [optimizer], [{"scheduler": scheduler, "interval": "step", "frequency": 1}] 114 | 115 | def setup(self, stage=None): 116 | if stage == 'fit' or stage is None: 117 | self.t_total = ( 118 | (len(self.__train_dataset) // (self.cfg.batch_size)) 119 | // self.cfg.gradient_accumulation_steps 120 | * float(self.cfg.num_train_epochs) 121 | ) 122 | 123 | def get_dataset(self, ds_path: str, split: str): 124 | return TIMIT(ds_path, self.tokenizer) 125 | 126 | def train_dataloader(self): 127 | dataset = self.get_dataset(self.__train_dataset, "train") 128 | return torch.utils.data.DataLoader(dataset, 129 | batch_size=self.cfg.batch_size, 130 | drop_last=True, shuffle=True, num_workers=self.cfg.num_worker, 131 | collate_fn=WhisperDataCollatorWithPadding() 132 | ) 133 | 134 | def val_dataloader(self): 135 | dataset = self.get_dataset(self.__eval_dataset, "val") 136 | return torch.utils.data.DataLoader(dataset, 137 | batch_size=self.cfg.batch_size, 138 | num_workers=self.cfg.num_worker, 139 | collate_fn=WhisperDataCollatorWithPadding() 140 | ) 141 | 142 | 143 | class LoRAStreamedWhisper(WhisperCustomModel): 144 | def __init__(self, cfg: Config, model_name="tiny", lang="en", train_dataset: str = None, eval_dataset: str = None, task="transcribe", rank=8, enc_emb_gran=15, enc_context=1, sim_stream=False, beam_size=None, use_kv_cache=False, use_ca_kv_cache=False, get_times=False, eval_script=False) -> None: 145 | super().__init__(cfg, model_name, lang, train_dataset, eval_dataset, task) 146 | 147 | self.automatic_optimization = not cfg.streaming_train 148 | 149 | # if model_name != "large-v2" and not 
eval_script: 150 | print(f"enc_emb_gran: {enc_emb_gran}") 151 | self.model: StreamingWhisper = careless_whisper_stream.load_streaming_model_for_train(model_name, 152 | advisor_ckpt_path=None, 153 | advisor_type=None, 154 | rank=rank, 155 | gran=enc_emb_gran, 156 | extra_gran_blocks=enc_context, 157 | ) 158 | 159 | for n, p in self.model.named_parameters(): 160 | if "lora" not in n: 161 | p.requires_grad = False 162 | 163 | self.model_name = model_name 164 | self.rank = rank 165 | self.enc_emb_gran = enc_emb_gran 166 | self.enc_context = enc_context 167 | self.simulate_stream = sim_stream 168 | self.full_stream = cfg.streaming_train 169 | self.beam_size = beam_size 170 | self.language = lang 171 | self.use_kv_cache = use_kv_cache 172 | self.use_ca_kv_cache = use_ca_kv_cache 173 | self.get_times = get_times 174 | self.eval_script = eval_script 175 | 176 | # for stream mode train 177 | self.num_frames = self.model.dims.n_audio_ctx // self.enc_emb_gran # 1500 // enc_emb_gran 178 | self.mel_samples = self.enc_emb_gran * 2 * HOP_LENGTH 179 | 180 | self.params = None 181 | 182 | self.last_out = None 183 | self.__train_dataset = train_dataset 184 | self.__eval_dataset = eval_dataset 185 | 186 | def _calc_labels(self, labels: Tensor, endpoints: Tensor, index: int): 187 | t_seconds = (index + 1) * self.enc_emb_gran * 0.02 188 | 189 | # take only relevant labels into account 190 | mask = (endpoints <= t_seconds) & (endpoints != -100) 191 | eot_indices = mask.int().argmin(dim=-1) 192 | clone_labels = labels.clone() 193 | 194 | # ignore irrelevant labels 195 | clone_labels[~mask] = -100 196 | 197 | # mark eot labels 198 | clone_labels[range(labels.shape[0]), eot_indices] = self.tokenizer.eot 199 | return clone_labels 200 | 201 | def _get_sample_points(self, endpoints: Tensor): 202 | # base case 203 | if self.cfg.streaming_fraction == 1: 204 | return range(self.enc_context, self.num_frames) 205 | 206 | biggest_endpoint = endpoints.max().item() 207 | 208 | # determine last index. 209 | num_frames = min(int(((biggest_endpoint / 0.02) // self.enc_emb_gran) + int(1 / (self.enc_emb_gran * 0.02)) + 1), self.num_frames) # adding 1 sec of silence 210 | new_range = range(self.enc_context, num_frames) 211 | 212 | sample_points = random.sample(new_range, k=int(len(new_range) * self.cfg.streaming_fraction) + 1) 213 | 214 | if self.cfg.streaming_random: 215 | return sample_points 216 | 217 | return sorted(sample_points) 218 | 219 | def _forward_step_stream(self, batch, step): 220 | input_ids = batch["input_ids"] 221 | labels = batch["labels"].long() 222 | dec_input_ids = batch["dec_input_ids"].long() 223 | endpoints = batch["endpoints"] 224 | 225 | if step == "train": 226 | optimizer = self.optimizers() 227 | 228 | # forward 229 | sample_points = self._get_sample_points(endpoints) 230 | for i in sample_points: 231 | audio_features = self.model.encoder(input_ids[..., :(i + 1) * (self.enc_emb_gran * 2)], index=[0, (i + 1) * self.enc_emb_gran], mask=True) 232 | out = self.model.decoder(dec_input_ids, audio_features, dump_type="None") 233 | 234 | if step == "train": 235 | optimizer.zero_grad() 236 | 237 | # loss calc 238 | frame_labels = self._calc_labels(labels, endpoints, i) 239 | loss = self.loss_fn(out.view(-1, out.size(-1)), frame_labels.view(-1)) 240 | 241 | # optimizer step if relevant. 
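                # Editor's note on this manual-optimization loop (automatic_optimization is
                # disabled when cfg.streaming_train is set): each sampled encoder frame index
                # gets its own zero_grad -> backward -> step, so gradients are not accumulated
                # across audio prefixes, and only the loss of the last sampled prefix is
                # returned for logging.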
242 | if step == "train": 243 | loss.backward() 244 | optimizer.step() # might move optimizer step to out of the loop for faster training 245 | 246 | if step == "train": 247 | return loss 248 | 249 | return out, loss 250 | 251 | def _forward_step(self, batch, step): 252 | input_ids = batch["input_ids"] 253 | labels = batch["labels"].long() 254 | dec_input_ids = batch["dec_input_ids"].long() 255 | 256 | # self.model.encoder.reset() 257 | audio_features = self.model.encoder(input_ids, mask=True) # use mask for fast learning, simulates stream mode 258 | out = self.model.decoder(dec_input_ids, audio_features, dump_type="None") 259 | 260 | # loss calc 261 | loss = self.loss_fn(out.view(-1, out.size(-1)), labels.view(-1)) 262 | 263 | if step == "train": 264 | return loss 265 | 266 | return out, loss 267 | 268 | def training_step(self, batch, batch_id): 269 | if self.full_stream: 270 | loss = self._forward_step_stream(batch, "train") 271 | else: 272 | loss = self._forward_step(batch, "train") 273 | 274 | self.log("train/loss", loss, on_step=True, prog_bar=True, logger=True, sync_dist=True) 275 | 276 | return loss 277 | 278 | def validation_step(self, batch, batch_id): 279 | out, loss = self._forward_step(batch, "val") 280 | wer = self.calc_wer_val(out, batch["labels"]) 281 | 282 | self.log("val/loss", loss, on_step=True, prog_bar=True, logger=True, sync_dist=True) 283 | self.log("val/wer", wer, on_step=True, prog_bar=True, logger=True, sync_dist=True) 284 | 285 | return { 286 | "wer": wer, 287 | "loss": loss 288 | } 289 | 290 | def predict_step(self, batch, batch_id): 291 | wavs = batch["wav_path"] 292 | text = batch["text"] 293 | 294 | results, times = self.model.transcribe(wav_file=wavs[0], 295 | simulate_stream=True, 296 | beam_size=self.beam_size, 297 | language=self.language, 298 | use_ca_kv_cache=self.use_ca_kv_cache, 299 | use_sa_kv_cache=self.use_kv_cache, 300 | get_times=self.get_times) 301 | 302 | return [res.text for res in results], text, wavs[0], times 303 | 304 | def on_train_epoch_start(self): 305 | random.seed(self.current_epoch + self.cfg.seed) 306 | 307 | def on_train_epoch_end(self): 308 | model_sched = self.lr_schedulers() 309 | model_sched.step(self.trainer.callback_metrics["val/loss"]) 310 | 311 | def get_dataset(self, ds_path, split): 312 | print(f"Stream simulation mode: {self.simulate_stream}") 313 | 314 | if self.full_stream: 315 | return AlignedTextGridDataset(ds_path=ds_path, get_streamed_mel=True, gran=self.enc_emb_gran, extra_gran_blocks=self.enc_context, n_mels=self.model.dims.n_mels, multilingual=self.cfg.multilingual) 316 | 317 | return WAVsDataset(ds_path=ds_path, get_streamed_mel=self.simulate_stream) 318 | 319 | def configure_optimizers(self): 320 | model = self.model 321 | optimizer_grouped_parameters = [ 322 | { 323 | "params": [p for n, p in model.named_parameters() if p.requires_grad], 324 | "weight_decay": self.cfg.weight_decay, 325 | "lr": self.cfg.learning_rate 326 | }, 327 | { 328 | "params": [p for n, p in model.named_parameters() if not p.requires_grad], 329 | "weight_decay": 0.0, 330 | "lr": 0 331 | }, 332 | ] 333 | optimizer = AdamW(optimizer_grouped_parameters, 334 | eps=self.cfg.adam_epsilon) 335 | self.optimizer = optimizer 336 | 337 | scheduler = ReduceLROnPlateau( 338 | self.optimizer, 'min', patience=2, factor=0.5 339 | ) 340 | 341 | self.scheduler = scheduler 342 | 343 | return [optimizer], [{"scheduler": scheduler, "monitor": "val/loss"}] 344 | 345 | def train_dataloader(self): 346 | dataset = self.get_dataset(self.__train_dataset, "train") 
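        # Note (assumption; training_code/collators.py is not reproduced in this listing):
        # LoRAWhisperDataCollatorWithPadding is expected to pad the mel and label tensors and,
        # for the aligned streaming dataset, to emit the per-token "endpoints" tensor that
        # _forward_step_stream reads from the batch.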
347 | return torch.utils.data.DataLoader(dataset, 348 | batch_size=self.cfg.batch_size, 349 | drop_last=True, shuffle=True, num_workers=self.cfg.num_worker, pin_memory=True, 350 | collate_fn=LoRAWhisperDataCollatorWithPadding()) 351 | 352 | def val_dataloader(self): 353 | dataset = self.get_dataset(self.__eval_dataset, "val") 354 | return torch.utils.data.DataLoader(dataset, 355 | batch_size=self.cfg.batch_size, 356 | num_workers=self.cfg.num_worker, pin_memory=True, 357 | collate_fn=LoRAWhisperDataCollatorWithPadding()) 358 | 359 | def on_save_checkpoint(self, checkpoint): 360 | checkpoint["dims"] = self.model.dims.__dict__ 361 | -------------------------------------------------------------------------------- /careless_whisper_stream/streaming_model.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from torch import Tensor 7 | from typing import Optional 8 | from .model import AudioEncoder, TextDecoder, Whisper 9 | 10 | from .streaming_decoding import DecodingTask, DecodingOptions, DecodingResult 11 | from .streaming_transcribe import transcribe as transcribe_function 12 | from .decoding import decode as non_causal_decode_function 13 | from .audio import SpectrogramStream 14 | 15 | from dataclasses import replace 16 | 17 | from pytorch_lightning import LightningModule 18 | 19 | 20 | class LoraLayer(LightningModule): 21 | def __init__(self, input_dim, output_dim, rank=8, alpha=None, *args, **kwargs): 22 | super().__init__(*args, **kwargs) 23 | 24 | self.lora_A = nn.Parameter(torch.zeros(input_dim, rank)) 25 | self.lora_B = nn.Parameter(torch.zeros(rank, output_dim)) 26 | 27 | self.alpha = rank if alpha is None else alpha 28 | self.rank = rank 29 | self.scale = self.alpha / self.rank 30 | 31 | self._init_weights() 32 | 33 | def _init_weights(self): 34 | nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5)) 35 | nn.init.zeros_(self.lora_B) 36 | 37 | def forward(self, x): 38 | return x @ (self.lora_A @ self.lora_B) * self.scale 39 | 40 | 41 | class LoraLinearLayer(LightningModule): 42 | def __init__(self, base_layer: nn.Linear, rank: int = 8, bias: int = True, *args, **kwargs): 43 | super().__init__(*args, **kwargs) 44 | 45 | self.base_layer = base_layer 46 | 47 | self.lora_layer = LoraLayer(base_layer.in_features, base_layer.out_features, rank=rank) 48 | self.aggregate_lora = True 49 | 50 | def turn_on_lora(self): 51 | self.aggregate_lora = True 52 | 53 | def turn_off_lora(self): 54 | self.aggregate_lora = False 55 | 56 | def forward(self, x: Tensor): 57 | if not self.aggregate_lora: 58 | return self.base_layer(x) 59 | 60 | return self.base_layer(x) + self.lora_layer(x) 61 | 62 | 63 | class LoRAMultiHeadAttention(LightningModule): 64 | def __init__(self, n_head, query, key, value, out, rank, *args, **kwargs): 65 | super().__init__(*args, **kwargs) 66 | 67 | self.n_head = n_head 68 | self.query = LoraLinearLayer(query, rank) 69 | self.key = LoraLinearLayer(key, rank, bias=False) 70 | self.value = LoraLinearLayer(value, rank) 71 | self.out = LoraLinearLayer(out, rank) 72 | 73 | def forward( 74 | self, 75 | x: Tensor, 76 | xa: Tensor = None, 77 | mask: Tensor = None, 78 | kv_cache: dict = None, 79 | *args, **kwargs 80 | ): 81 | q = self.query(x) 82 | 83 | if kv_cache is None or xa is None or self.key not in kv_cache: 84 | k = self.key(x if xa is None else xa) 85 | v = self.value(x if xa is None else xa) 86 | 87 | else: 88 | k = kv_cache[self.key] 89 | v = 
kv_cache[self.value] 90 | 91 | wv, qk = self.qkv_attention(q, k, v, mask, kv_cache) 92 | 93 | return self.out(wv), qk 94 | 95 | def qkv_attention( 96 | self, q: Tensor, k: Tensor, v: Tensor, mask: Tensor = None, kv_cache: dict = None 97 | ): 98 | n_batch, n_ctx, n_state = q.shape 99 | _, k_ctx, _ = k.shape 100 | scale = (n_state // self.n_head) ** -0.25 101 | q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale 102 | k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale 103 | v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) 104 | qk = q @ k 105 | 106 | # apply causal mask 107 | if mask is not None: 108 | 109 | # kv_cache for beam search decoding case 110 | if kv_cache is not None and "beam_indices" in kv_cache.keys(): 111 | for i in range(n_batch): 112 | qk[i] = qk[i] + mask[kv_cache["beam_indices"][1][i * n_ctx]:kv_cache["beam_indices"][1][i * n_ctx] + n_ctx, :k_ctx] 113 | 114 | # For training, encoder/decoder causal masks 115 | elif k_ctx == n_ctx: 116 | qk = qk + mask[:n_ctx, :n_ctx] 117 | 118 | # kv_cache in the greedy decoding case 119 | else: 120 | qk = qk + mask[k_ctx - n_ctx:k_ctx, :k_ctx] 121 | 122 | qk = qk.float() 123 | 124 | w = F.softmax(qk, dim=-1).to(q.dtype) 125 | 126 | return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach() 127 | 128 | 129 | class StreamingAudioEncoder(AudioEncoder): 130 | 131 | def __init__(self, n_mels, n_ctx, n_state, n_head, n_layer, cache_gran, gran, rank, extra_gran_blocks): 132 | super().__init__(n_mels, n_ctx, n_state, n_head, n_layer) 133 | 134 | self.gran = gran 135 | self.extra_gran_blocks = extra_gran_blocks 136 | 137 | for block in self.blocks: 138 | block.attn = LoRAMultiHeadAttention(self.n_head, 139 | block.attn.query, 140 | block.attn.key, 141 | block.attn.value, 142 | block.attn.out, 143 | rank) 144 | 145 | self.use_stream = False 146 | 147 | # mask for training 148 | matrix_size = n_ctx 149 | block_size = gran 150 | extra_blocks = extra_gran_blocks 151 | mask = torch.full((matrix_size, matrix_size), float('-inf')) 152 | 153 | for i in range(0, matrix_size, block_size): 154 | if (i // block_size) <= extra_blocks: 155 | zero_cols = (block_size * (extra_blocks + 1)) 156 | else: 157 | zero_cols = (block_size * (extra_blocks + 1)) + ((i // block_size) - extra_blocks) * block_size 158 | 159 | mask[i:i + block_size, :zero_cols] = 1 160 | 161 | self.register_buffer("mask", mask, persistent=False) 162 | 163 | def _use_stream(self, use_stream: bool): 164 | self.use_stream = use_stream 165 | 166 | def forward(self, x: Tensor, index: list = [0, 1500], kv_cache = None, mask = True): 167 | """ 168 | simulate streaming forward using qk cache self attn. 
169 | """ 170 | x = F.gelu(self.conv1(x)) 171 | x = F.gelu(self.conv2(x)) 172 | x = x.permute(0, 2, 1) 173 | 174 | if self.use_stream: 175 | offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0 176 | x = x[:, offset:offset + self.gran + (int(offset == 0) * (self.extra_gran_blocks * self.gran))] # offset 177 | x = (x + self.positional_embedding[offset:offset + self.gran + (int(offset == 0) * (self.extra_gran_blocks * self.gran))]).to(x.dtype) 178 | else: # use during training 179 | x = x[:, index[0]:index[1]] # offset 180 | x = (x + self.positional_embedding[index[0]:index[1]]).to(x.dtype) 181 | 182 | for block in self.blocks: 183 | chosen_mask = mask[..., :index[1], :index[1]] if isinstance(mask, Tensor) else self.mask if (mask is not None) and (self.use_stream) else None 184 | x = block(x, mask=chosen_mask, kv_cache=kv_cache) 185 | 186 | x = self.ln_post(x) 187 | 188 | return x 189 | 190 | 191 | class StreamingTextDecoder(TextDecoder): 192 | def __init__(self, n_vocab, n_ctx, n_state, n_head, n_layer, rank): 193 | super().__init__(n_vocab, n_ctx, n_state, n_head, n_layer) 194 | 195 | self.n_ctx = n_ctx 196 | self.n_state = n_state 197 | 198 | for block in self.blocks: 199 | block.attn = LoRAMultiHeadAttention(n_head, 200 | block.attn.query, 201 | block.attn.key, 202 | block.attn.value, 203 | block.attn.out, 204 | rank) 205 | 206 | block.cross_attn = LoRAMultiHeadAttention(n_head, 207 | block.cross_attn.query, 208 | block.cross_attn.key, 209 | block.cross_attn.value, 210 | block.cross_attn.out, 211 | rank) 212 | 213 | def forward(self, x: Tensor, xa: Tensor, kv_cache: dict = None, dump_type: str = None, **kwargs): 214 | """ 215 | x : torch.LongTensor, shape = (batch_size, <= n_ctx) 216 | the text tokens 217 | xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state) 218 | the encoded audio features to be attended on 219 | """ 220 | if kv_cache is not None and "beam_indices" in kv_cache.keys(): 221 | x = self.token_embedding(x) + self.positional_embedding.unsqueeze(0).expand(x.shape[0], self.positional_embedding.shape[0], self.positional_embedding.shape[1])[kv_cache["beam_indices"][0], kv_cache["beam_indices"][1]].view(x.shape[0], -1, self.n_state) 222 | else: 223 | offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0 224 | x = self.token_embedding(x) + self.positional_embedding[offset : offset + x.shape[-1]] 225 | 226 | x = x.to(xa.dtype) 227 | 228 | for block in self.blocks: 229 | x = block(x, xa, mask=self.mask, kv_cache=kv_cache) 230 | 231 | x = self.ln(x) 232 | logits = ( 233 | x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1) 234 | ).float() 235 | 236 | return logits 237 | 238 | 239 | class StreamingWhisper(Whisper): 240 | def __init__(self, dims, cache_gran: bool = True, gran: int = 16, rank: int = 0, extra_gran_blocks: int = 0): 241 | super().__init__(dims) 242 | 243 | self.cache_gran = cache_gran 244 | self.gran = gran 245 | self.rank = rank 246 | self.extra_gran_blocks = extra_gran_blocks 247 | 248 | print(f"Running a streaming whisper model, using chunk size: {gran * 20}[msec] and {extra_gran_blocks} extra chunks for initialization.") 249 | 250 | # The only difference is a streaming encoder 251 | self.encoder = StreamingAudioEncoder( 252 | self.dims.n_mels, 253 | self.dims.n_audio_ctx, 254 | self.dims.n_audio_state, 255 | self.dims.n_audio_head, 256 | self.dims.n_audio_layer, 257 | cache_gran=cache_gran, 258 | gran=gran, 259 | rank=rank, 260 | extra_gran_blocks=extra_gran_blocks 261 | ) 262 | 263 | self.decoder = 
StreamingTextDecoder( 264 | self.dims.n_vocab, 265 | self.dims.n_text_ctx, 266 | self.dims.n_text_state, 267 | self.dims.n_text_head, 268 | self.dims.n_text_layer, 269 | rank=rank 270 | ) 271 | 272 | # Advisor params - Dropped. 273 | self.advisor_type = None 274 | self.n_advisor_class = 0 275 | self.decoder_advisor = None 276 | 277 | self.decoding_task = None 278 | self.spec_streamer = SpectrogramStream() 279 | 280 | self.use_ca_cache_hook = True # relevant only when ca_kv_cache is installed 281 | 282 | def reset(self, use_stream: bool = True, clean_task: bool = True): 283 | self.encoder._use_stream(use_stream) 284 | del self.decoding_task # trigger clean encoder kv caching 285 | self.decoding_task = None 286 | self.spec_streamer.reset() 287 | 288 | @torch.no_grad() 289 | def decode(self, mel: Tensor, options: DecodingOptions = DecodingOptions(), use_frames: bool = False, **kwargs) -> DecodingResult: 290 | if kwargs: options = replace(options, **kwargs) 291 | 292 | if use_frames: # mel is frames of audio, need to calc mel 293 | mel = self.spec_streamer.calc_mel_with_new_frame(mel).squeeze(0) 294 | 295 | if self.encoder.gran != options.gran: 296 | print(f"Encoder gran & options gran differ. forcing options to be on encoder's gran: {self.encoder.gran}") 297 | options.gran = self.encoder.gran 298 | 299 | if not self.decoding_task: 300 | self.decoding_task = DecodingTask(self, options) 301 | 302 | return self.decoding_task.run(mel.unsqueeze(0)) 303 | 304 | def _turn_off_lora(self): 305 | for _, layer in self.encoder.named_modules(): 306 | if isinstance(layer, LoraLinearLayer): 307 | layer.turn_off_lora() 308 | 309 | def _turn_on_lora(self): 310 | for _, layer in self.encoder.named_modules(): 311 | if isinstance(layer, LoraLinearLayer): 312 | layer.turn_on_lora() 313 | 314 | def _cancel_streaming_mode(self): 315 | self._turn_off_lora() 316 | self.encoder._use_stream(False) 317 | 318 | def _revert_streaming_mode(self): 319 | self._turn_on_lora() 320 | self.encoder._use_stream(True) 321 | 322 | @torch.no_grad() 323 | def non_causal_decode(self, mel: Tensor, options: DecodingOptions = DecodingOptions(), **kwargs) -> DecodingResult: 324 | self._cancel_streaming_mode() 325 | results = non_causal_decode_function(self, mel, options, **kwargs) 326 | self._revert_streaming_mode() 327 | return results 328 | 329 | def remove_encoder_kv_cache_hooks(self): 330 | for hook in self.encoder._forward_hooks.values(): 331 | hook.remove() 332 | 333 | def install_encoder_kv_cache_hooks(self, cache = None): 334 | cache = {**cache} if cache is not None else {} 335 | hooks = [] 336 | 337 | def save_to_cache(module, _, output): 338 | if module not in cache or output.shape[1] > self.dims.n_audio_ctx: 339 | # save as-is, for the first token or cross attention 340 | cache[module] = output 341 | else: 342 | cache[module] = torch.cat([cache[module], output], dim=1).detach() 343 | return cache[module] 344 | 345 | def install_hooks(layer: nn.Module): 346 | if isinstance(layer, LoRAMultiHeadAttention): 347 | hooks.append(layer.key.register_forward_hook(save_to_cache)) 348 | hooks.append(layer.value.register_forward_hook(save_to_cache)) 349 | 350 | self.encoder.apply(install_hooks) 351 | return cache, hooks 352 | 353 | def install_decoder_kv_cache_hooks(self, cache = None): 354 | cache = {**cache} if cache is not None else {} 355 | hooks = [] 356 | 357 | def save_to_cache(module, _, output): 358 | if module not in cache or output.shape[1] > self.dims.n_text_ctx: 359 | cache[module] = output 360 | else: 361 | if "beam_indices" 
not in cache.keys() or all([index == (cache[module].shape[1]) for index in cache["beam_indices"][1]]): 362 | cache[module] = torch.cat([cache[module], output], dim=1).detach() 363 | else: 364 | for _, (beam, index, output_index) in enumerate(zip(*cache["beam_indices"])): 365 | if index < cache[module].shape[1]: 366 | cache[module][beam, index] = output[beam, output_index] 367 | else: 368 | cache[module] = torch.cat([cache[module], output[:, (output_index):]], dim=1).detach() 369 | 370 | return cache[module] 371 | 372 | for name, module in self.decoder.named_modules(): 373 | if isinstance(module, LoRAMultiHeadAttention) and "attn" in name and "cross" not in name: 374 | hooks.append(module.key.register_forward_hook(save_to_cache)) 375 | hooks.append(module.value.register_forward_hook(save_to_cache)) 376 | 377 | return cache, hooks 378 | 379 | def install_cross_attn_kv_cache_hooks(self, cache=None): 380 | cache = {**cache} if cache is not None else {} 381 | hooks = [] 382 | 383 | def save_to_cache(module, _, output): 384 | if not self.use_ca_cache_hook: 385 | return cache[module] 386 | 387 | if module not in cache or output.shape[1] > self.dims.n_audio_ctx: 388 | # save as-is, for the first token or cross attention 389 | cache[module] = output 390 | else: 391 | cache[module] = torch.cat([cache[module], output], dim=1).detach() 392 | 393 | return cache[module] 394 | 395 | def check_if_calculation_is_needed(module, _): 396 | if not self.use_ca_cache_hook: 397 | return cache[module] 398 | 399 | for name, module in self.decoder.named_modules(): 400 | if isinstance(module, LoRAMultiHeadAttention) and "cross_attn" in name: 401 | hooks.append(module.key.register_forward_hook(save_to_cache)) 402 | hooks.append(module.key.register_forward_pre_hook(check_if_calculation_is_needed)) 403 | hooks.append(module.value.register_forward_hook(save_to_cache)) 404 | hooks.append(module.value.register_forward_pre_hook(check_if_calculation_is_needed)) 405 | 406 | return cache, hooks 407 | 408 | # For non-causal decoding compatibility 409 | def install_kv_cache_hooks(self, cache: Optional[dict] = None): 410 | """ 411 | The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value 412 | tensors calculated for the previous positions. This method returns a dictionary that stores 413 | all caches, and the necessary hooks for the key and value projection modules that save the 414 | intermediate tensors to be reused during later calculations. 
415 | 416 | Returns 417 | ------- 418 | cache : Dict[nn.Module, torch.Tensor] 419 | A dictionary object mapping the key/value projection modules to its cache 420 | hooks : List[RemovableHandle] 421 | List of PyTorch RemovableHandle objects to stop the hooks to be called 422 | """ 423 | cache = {**cache} if cache is not None else {} 424 | hooks = [] 425 | 426 | def save_to_cache(module, _, output): 427 | if module not in cache or output.shape[1] > self.dims.n_text_ctx: 428 | # save as-is, for the first token or cross attention 429 | cache[module] = output 430 | else: 431 | cache[module] = torch.cat([cache[module], output], dim=1).detach() 432 | return cache[module] 433 | 434 | def install_hooks(layer: nn.Module): 435 | if isinstance(layer, LoRAMultiHeadAttention): 436 | hooks.append(layer.key.register_forward_hook(save_to_cache)) 437 | hooks.append(layer.value.register_forward_hook(save_to_cache)) 438 | 439 | self.decoder.apply(install_hooks) 440 | return cache, hooks 441 | 442 | # refers to function from streaming_decoding, streaming_transcribe library 443 | transcribe = transcribe_function 444 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | 1. This work is built upon or includes components of "Whisper" code licensed under the MIT License as follows: 3 | 4 | MIT License 5 | 6 | Copyright (c) 2022 OpenAI 7 | 8 | Permission is hereby granted, free of charge, to any person obtaining a copy 9 | of this software and associated documentation files (the "Software"), to deal 10 | in the Software without restriction, including without limitation the rights 11 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | copies of the Software, and to permit persons to whom the Software is 13 | furnished to do so, subject to the following conditions: 14 | 15 | The above copyright notice and this permission notice shall be included in all 16 | copies or substantial portions of the Software. 17 | 18 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 19 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 20 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 21 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 22 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 23 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 24 | SOFTWARE. 25 | 26 | 2. The modifications/work detailed in: 27 | 28 | work: 29 | training_code/collators.py 30 | training_code/datasets_classes.py 31 | training_code/ds_dict.py 32 | training_code/train.py 33 | training_code/utils.py 34 | training_code/whisper_module.py 35 | 36 | modifications/additions: 37 | careless_whisper_stream/__init__.py: function load_streaming_model, function load_streaming_model_for_train 38 | careless_whisper_stream/audio.py: class MyStream, class SpectrogramStream 39 | careless_whisper_stream/streaming_decoding.py: class StreamingDecoder, class BeamStreamingDecoder 40 | careless_whisper_stream/streaming_model.py: class StreamingAudioEncoder (kv_cache additions in forward), class StreamingTextDecoder (kv_cache for beam search in streaming), function install_decoder_kv_cache_hooks. 
41 | careless_whisper_stream/streaming_transcribe.py 42 | 43 | 44 | is licensed under a Creative Commons Attribution-NonCommercial 4.0 International License as follows: 45 | 46 | Attribution-NonCommercial 4.0 International 47 | 48 | ======================================================================= 49 | 50 | Creative Commons Corporation ("Creative Commons") is not a law firm and 51 | does not provide legal services or legal advice. Distribution of 52 | Creative Commons public licenses does not create a lawyer-client or 53 | other relationship. Creative Commons makes its licenses and related 54 | information available on an "as-is" basis. Creative Commons gives no 55 | warranties regarding its licenses, any material licensed under their 56 | terms and conditions, or any related information. Creative Commons 57 | disclaims all liability for damages resulting from their use to the 58 | fullest extent possible. 59 | 60 | Using Creative Commons Public Licenses 61 | 62 | Creative Commons public licenses provide a standard set of terms and 63 | conditions that creators and other rights holders may use to share 64 | original works of authorship and other material subject to copyright 65 | and certain other rights specified in the public license below. The 66 | following considerations are for informational purposes only, are not 67 | exhaustive, and do not form part of our licenses. 68 | 69 | Considerations for licensors: Our public licenses are 70 | intended for use by those authorized to give the public 71 | permission to use material in ways otherwise restricted by 72 | copyright and certain other rights. Our licenses are 73 | irrevocable. Licensors should read and understand the terms 74 | and conditions of the license they choose before applying it. 75 | Licensors should also secure all rights necessary before 76 | applying our licenses so that the public can reuse the 77 | material as expected. Licensors should clearly mark any 78 | material not subject to the license. This includes other CC- 79 | licensed material, or material used under an exception or 80 | limitation to copyright. More considerations for licensors: 81 | wiki.creativecommons.org/Considerations_for_licensors 82 | 83 | Considerations for the public: By using one of our public 84 | licenses, a licensor grants the public permission to use the 85 | licensed material under specified terms and conditions. If 86 | the licensor's permission is not necessary for any reason--for 87 | example, because of any applicable exception or limitation to 88 | copyright--then that use is not regulated by the license. Our 89 | licenses grant only permissions under copyright and certain 90 | other rights that a licensor has authority to grant. Use of 91 | the licensed material may still be restricted for other 92 | reasons, including because others have copyright or other 93 | rights in the material. A licensor may make special requests, 94 | such as asking that all changes be marked or described. 95 | Although not required by our licenses, you are encouraged to 96 | respect those requests where reasonable. 
More considerations 97 | for the public: 98 | wiki.creativecommons.org/Considerations_for_licensees 99 | 100 | ======================================================================= 101 | 102 | Creative Commons Attribution-NonCommercial 4.0 International Public 103 | License 104 | 105 | By exercising the Licensed Rights (defined below), You accept and agree 106 | to be bound by the terms and conditions of this Creative Commons 107 | Attribution-NonCommercial 4.0 International Public License ("Public 108 | License"). To the extent this Public License may be interpreted as a 109 | contract, You are granted the Licensed Rights in consideration of Your 110 | acceptance of these terms and conditions, and the Licensor grants You 111 | such rights in consideration of benefits the Licensor receives from 112 | making the Licensed Material available under these terms and 113 | conditions. 114 | 115 | 116 | Section 1 -- Definitions. 117 | 118 | a. Adapted Material means material subject to Copyright and Similar 119 | Rights that is derived from or based upon the Licensed Material 120 | and in which the Licensed Material is translated, altered, 121 | arranged, transformed, or otherwise modified in a manner requiring 122 | permission under the Copyright and Similar Rights held by the 123 | Licensor. For purposes of this Public License, where the Licensed 124 | Material is a musical work, performance, or sound recording, 125 | Adapted Material is always produced where the Licensed Material is 126 | synched in timed relation with a moving image. 127 | 128 | b. Adapter's License means the license You apply to Your Copyright 129 | and Similar Rights in Your contributions to Adapted Material in 130 | accordance with the terms and conditions of this Public License. 131 | 132 | c. Copyright and Similar Rights means copyright and/or similar rights 133 | closely related to copyright including, without limitation, 134 | performance, broadcast, sound recording, and Sui Generis Database 135 | Rights, without regard to how the rights are labeled or 136 | categorized. For purposes of this Public License, the rights 137 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 138 | Rights. 139 | d. Effective Technological Measures means those measures that, in the 140 | absence of proper authority, may not be circumvented under laws 141 | fulfilling obligations under Article 11 of the WIPO Copyright 142 | Treaty adopted on December 20, 1996, and/or similar international 143 | agreements. 144 | 145 | e. Exceptions and Limitations means fair use, fair dealing, and/or 146 | any other exception or limitation to Copyright and Similar Rights 147 | that applies to Your use of the Licensed Material. 148 | 149 | f. Licensed Material means the artistic or literary work, database, 150 | or other material to which the Licensor applied this Public 151 | License. 152 | 153 | g. Licensed Rights means the rights granted to You subject to the 154 | terms and conditions of this Public License, which are limited to 155 | all Copyright and Similar Rights that apply to Your use of the 156 | Licensed Material and that the Licensor has authority to license. 157 | 158 | h. Licensor means the individual(s) or entity(ies) granting rights 159 | under this Public License. 160 | 161 | i. NonCommercial means not primarily intended for or directed towards 162 | commercial advantage or monetary compensation. 
For purposes of 163 | this Public License, the exchange of the Licensed Material for 164 | other material subject to Copyright and Similar Rights by digital 165 | file-sharing or similar means is NonCommercial provided there is 166 | no payment of monetary compensation in connection with the 167 | exchange. 168 | 169 | j. Share means to provide material to the public by any means or 170 | process that requires permission under the Licensed Rights, such 171 | as reproduction, public display, public performance, distribution, 172 | dissemination, communication, or importation, and to make material 173 | available to the public including in ways that members of the 174 | public may access the material from a place and at a time 175 | individually chosen by them. 176 | 177 | k. Sui Generis Database Rights means rights other than copyright 178 | resulting from Directive 96/9/EC of the European Parliament and of 179 | the Council of 11 March 1996 on the legal protection of databases, 180 | as amended and/or succeeded, as well as other essentially 181 | equivalent rights anywhere in the world. 182 | 183 | l. You means the individual or entity exercising the Licensed Rights 184 | under this Public License. Your has a corresponding meaning. 185 | 186 | 187 | Section 2 -- Scope. 188 | 189 | a. License grant. 190 | 191 | 1. Subject to the terms and conditions of this Public License, 192 | the Licensor hereby grants You a worldwide, royalty-free, 193 | non-sublicensable, non-exclusive, irrevocable license to 194 | exercise the Licensed Rights in the Licensed Material to: 195 | 196 | a. reproduce and Share the Licensed Material, in whole or 197 | in part, for NonCommercial purposes only; and 198 | 199 | b. produce, reproduce, and Share Adapted Material for 200 | NonCommercial purposes only. 201 | 202 | 2. Exceptions and Limitations. For the avoidance of doubt, where 203 | Exceptions and Limitations apply to Your use, this Public 204 | License does not apply, and You do not need to comply with 205 | its terms and conditions. 206 | 207 | 3. Term. The term of this Public License is specified in Section 208 | 6(a). 209 | 210 | 4. Media and formats; technical modifications allowed. The 211 | Licensor authorizes You to exercise the Licensed Rights in 212 | all media and formats whether now known or hereafter created, 213 | and to make technical modifications necessary to do so. The 214 | Licensor waives and/or agrees not to assert any right or 215 | authority to forbid You from making technical modifications 216 | necessary to exercise the Licensed Rights, including 217 | technical modifications necessary to circumvent Effective 218 | Technological Measures. For purposes of this Public License, 219 | simply making modifications authorized by this Section 2(a) 220 | (4) never produces Adapted Material. 221 | 222 | 5. Downstream recipients. 223 | 224 | a. Offer from the Licensor -- Licensed Material. Every 225 | recipient of the Licensed Material automatically 226 | receives an offer from the Licensor to exercise the 227 | Licensed Rights under the terms and conditions of this 228 | Public License. 229 | 230 | b. No downstream restrictions. You may not offer or impose 231 | any additional or different terms or conditions on, or 232 | apply any Effective Technological Measures to, the 233 | Licensed Material if doing so restricts exercise of the 234 | Licensed Rights by any recipient of the Licensed 235 | Material. 236 | 237 | 6. No endorsement. 
Nothing in this Public License constitutes or 238 | may be construed as permission to assert or imply that You 239 | are, or that Your use of the Licensed Material is, connected 240 | with, or sponsored, endorsed, or granted official status by, 241 | the Licensor or others designated to receive attribution as 242 | provided in Section 3(a)(1)(A)(i). 243 | 244 | b. Other rights. 245 | 246 | 1. Moral rights, such as the right of integrity, are not 247 | licensed under this Public License, nor are publicity, 248 | privacy, and/or other similar personality rights; however, to 249 | the extent possible, the Licensor waives and/or agrees not to 250 | assert any such rights held by the Licensor to the limited 251 | extent necessary to allow You to exercise the Licensed 252 | Rights, but not otherwise. 253 | 254 | 2. Patent and trademark rights are not licensed under this 255 | Public License. 256 | 257 | 3. To the extent possible, the Licensor waives any right to 258 | collect royalties from You for the exercise of the Licensed 259 | Rights, whether directly or through a collecting society 260 | under any voluntary or waivable statutory or compulsory 261 | licensing scheme. In all other cases the Licensor expressly 262 | reserves any right to collect such royalties, including when 263 | the Licensed Material is used other than for NonCommercial 264 | purposes. 265 | 266 | 267 | Section 3 -- License Conditions. 268 | 269 | Your exercise of the Licensed Rights is expressly made subject to the 270 | following conditions. 271 | 272 | a. Attribution. 273 | 274 | 1. If You Share the Licensed Material (including in modified 275 | form), You must: 276 | 277 | a. retain the following if it is supplied by the Licensor 278 | with the Licensed Material: 279 | 280 | i. identification of the creator(s) of the Licensed 281 | Material and any others designated to receive 282 | attribution, in any reasonable manner requested by 283 | the Licensor (including by pseudonym if 284 | designated); 285 | 286 | ii. a copyright notice; 287 | 288 | iii. a notice that refers to this Public License; 289 | 290 | iv. a notice that refers to the disclaimer of 291 | warranties; 292 | 293 | v. a URI or hyperlink to the Licensed Material to the 294 | extent reasonably practicable; 295 | 296 | b. indicate if You modified the Licensed Material and 297 | retain an indication of any previous modifications; and 298 | 299 | c. indicate the Licensed Material is licensed under this 300 | Public License, and include the text of, or the URI or 301 | hyperlink to, this Public License. 302 | 303 | 2. You may satisfy the conditions in Section 3(a)(1) in any 304 | reasonable manner based on the medium, means, and context in 305 | which You Share the Licensed Material. For example, it may be 306 | reasonable to satisfy the conditions by providing a URI or 307 | hyperlink to a resource that includes the required 308 | information. 309 | 310 | 3. If requested by the Licensor, You must remove any of the 311 | information required by Section 3(a)(1)(A) to the extent 312 | reasonably practicable. 313 | 314 | 4. If You Share Adapted Material You produce, the Adapter's 315 | License You apply must not prevent recipients of the Adapted 316 | Material from complying with this Public License. 317 | 318 | 319 | Section 4 -- Sui Generis Database Rights. 320 | 321 | Where the Licensed Rights include Sui Generis Database Rights that 322 | apply to Your use of the Licensed Material: 323 | 324 | a. 
for the avoidance of doubt, Section 2(a)(1) grants You the right 325 | to extract, reuse, reproduce, and Share all or a substantial 326 | portion of the contents of the database for NonCommercial purposes 327 | only; 328 | 329 | b. if You include all or a substantial portion of the database 330 | contents in a database in which You have Sui Generis Database 331 | Rights, then the database in which You have Sui Generis Database 332 | Rights (but not its individual contents) is Adapted Material; and 333 | 334 | c. You must comply with the conditions in Section 3(a) if You Share 335 | all or a substantial portion of the contents of the database. 336 | 337 | For the avoidance of doubt, this Section 4 supplements and does not 338 | replace Your obligations under this Public License where the Licensed 339 | Rights include other Copyright and Similar Rights. 340 | 341 | 342 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 343 | 344 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 345 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 346 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 347 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 348 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 349 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 350 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 351 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 352 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 353 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 354 | 355 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 356 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 357 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 358 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 359 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 360 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 361 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 362 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 363 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 364 | 365 | c. The disclaimer of warranties and limitation of liability provided 366 | above shall be interpreted in a manner that, to the extent 367 | possible, most closely approximates an absolute disclaimer and 368 | waiver of all liability. 369 | 370 | 371 | Section 6 -- Term and Termination. 372 | 373 | a. This Public License applies for the term of the Copyright and 374 | Similar Rights licensed here. However, if You fail to comply with 375 | this Public License, then Your rights under this Public License 376 | terminate automatically. 377 | 378 | b. Where Your right to use the Licensed Material has terminated under 379 | Section 6(a), it reinstates: 380 | 381 | 1. automatically as of the date the violation is cured, provided 382 | it is cured within 30 days of Your discovery of the 383 | violation; or 384 | 385 | 2. upon express reinstatement by the Licensor. 386 | 387 | For the avoidance of doubt, this Section 6(b) does not affect any 388 | right the Licensor may have to seek remedies for Your violations 389 | of this Public License. 390 | 391 | c. 
For the avoidance of doubt, the Licensor may also offer the 392 | Licensed Material under separate terms or conditions or stop 393 | distributing the Licensed Material at any time; however, doing so 394 | will not terminate this Public License. 395 | 396 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 397 | License. 398 | 399 | 400 | Section 7 -- Other Terms and Conditions. 401 | 402 | a. The Licensor shall not be bound by any additional or different 403 | terms or conditions communicated by You unless expressly agreed. 404 | 405 | b. Any arrangements, understandings, or agreements regarding the 406 | Licensed Material not stated herein are separate from and 407 | independent of the terms and conditions of this Public License. 408 | 409 | 410 | Section 8 -- Interpretation. 411 | 412 | a. For the avoidance of doubt, this Public License does not, and 413 | shall not be interpreted to, reduce, limit, restrict, or impose 414 | conditions on any use of the Licensed Material that could lawfully 415 | be made without permission under this Public License. 416 | 417 | b. To the extent possible, if any provision of this Public License is 418 | deemed unenforceable, it shall be automatically reformed to the 419 | minimum extent necessary to make it enforceable. If the provision 420 | cannot be reformed, it shall be severed from this Public License 421 | without affecting the enforceability of the remaining terms and 422 | conditions. 423 | 424 | c. No term or condition of this Public License will be waived and no 425 | failure to comply consented to unless expressly agreed to by the 426 | Licensor. 427 | 428 | d. Nothing in this Public License constitutes or may be interpreted 429 | as a limitation upon, or waiver of, any privileges and immunities 430 | that apply to the Licensor or You, including from the legal 431 | processes of any jurisdiction or authority. 432 | 433 | ======================================================================= 434 | 435 | Creative Commons is not a party to its public 436 | licenses. Notwithstanding, Creative Commons may elect to apply one of 437 | its public licenses to material it publishes and in those instances 438 | will be considered the "Licensor." The text of the Creative Commons 439 | public licenses is dedicated to the public domain under the CC0 Public 440 | Domain Dedication. Except for the limited purpose of indicating that 441 | material is shared under a Creative Commons public license or as 442 | otherwise permitted by the Creative Commons policies published at 443 | creativecommons.org/policies, Creative Commons does not authorize the 444 | use of the trademark "Creative Commons" or any other trademark or logo 445 | of Creative Commons without its prior written consent including, 446 | without limitation, in connection with any unauthorized modifications 447 | to any of its public licenses or any other arrangements, 448 | understandings, or agreements concerning use of licensed material. For 449 | the avoidance of doubt, this paragraph does not form part of the 450 | public licenses. 451 | 452 | Creative Commons may be contacted at creativecommons.org. 
-------------------------------------------------------------------------------- /careless_whisper_stream/normalizers/english.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import re 4 | from fractions import Fraction 5 | from typing import Iterator, List, Match, Optional, Union 6 | 7 | from more_itertools import windowed 8 | 9 | from .basic import remove_symbols_and_diacritics 10 | 11 | 12 | class EnglishNumberNormalizer: 13 | """ 14 | Convert any spelled-out numbers into arabic numbers, while handling: 15 | 16 | - remove any commas 17 | - keep the suffixes such as: `1960s`, `274th`, `32nd`, etc. 18 | - spell out currency symbols after the number. e.g. `$20 million` -> `20000000 dollars` 19 | - spell out `one` and `ones` 20 | - interpret successive single-digit numbers as nominal: `one oh one` -> `101` 21 | """ 22 | 23 | def __init__(self): 24 | super().__init__() 25 | 26 | self.zeros = {"o", "oh", "zero"} 27 | self.ones = { 28 | name: i 29 | for i, name in enumerate( 30 | [ 31 | "one", 32 | "two", 33 | "three", 34 | "four", 35 | "five", 36 | "six", 37 | "seven", 38 | "eight", 39 | "nine", 40 | "ten", 41 | "eleven", 42 | "twelve", 43 | "thirteen", 44 | "fourteen", 45 | "fifteen", 46 | "sixteen", 47 | "seventeen", 48 | "eighteen", 49 | "nineteen", 50 | ], 51 | start=1, 52 | ) 53 | } 54 | self.ones_plural = { 55 | "sixes" if name == "six" else name + "s": (value, "s") 56 | for name, value in self.ones.items() 57 | } 58 | self.ones_ordinal = { 59 | "zeroth": (0, "th"), 60 | "first": (1, "st"), 61 | "second": (2, "nd"), 62 | "third": (3, "rd"), 63 | "fifth": (5, "th"), 64 | "twelfth": (12, "th"), 65 | **{ 66 | name + ("h" if name.endswith("t") else "th"): (value, "th") 67 | for name, value in self.ones.items() 68 | if value > 3 and value != 5 and value != 12 69 | }, 70 | } 71 | self.ones_suffixed = {**self.ones_plural, **self.ones_ordinal} 72 | 73 | self.tens = { 74 | "twenty": 20, 75 | "thirty": 30, 76 | "forty": 40, 77 | "fifty": 50, 78 | "sixty": 60, 79 | "seventy": 70, 80 | "eighty": 80, 81 | "ninety": 90, 82 | } 83 | self.tens_plural = { 84 | name.replace("y", "ies"): (value, "s") for name, value in self.tens.items() 85 | } 86 | self.tens_ordinal = { 87 | name.replace("y", "ieth"): (value, "th") 88 | for name, value in self.tens.items() 89 | } 90 | self.tens_suffixed = {**self.tens_plural, **self.tens_ordinal} 91 | 92 | self.multipliers = { 93 | "hundred": 100, 94 | "thousand": 1_000, 95 | "million": 1_000_000, 96 | "billion": 1_000_000_000, 97 | "trillion": 1_000_000_000_000, 98 | "quadrillion": 1_000_000_000_000_000, 99 | "quintillion": 1_000_000_000_000_000_000, 100 | "sextillion": 1_000_000_000_000_000_000_000, 101 | "septillion": 1_000_000_000_000_000_000_000_000, 102 | "octillion": 1_000_000_000_000_000_000_000_000_000, 103 | "nonillion": 1_000_000_000_000_000_000_000_000_000_000, 104 | "decillion": 1_000_000_000_000_000_000_000_000_000_000_000, 105 | } 106 | self.multipliers_plural = { 107 | name + "s": (value, "s") for name, value in self.multipliers.items() 108 | } 109 | self.multipliers_ordinal = { 110 | name + "th": (value, "th") for name, value in self.multipliers.items() 111 | } 112 | self.multipliers_suffixed = { 113 | **self.multipliers_plural, 114 | **self.multipliers_ordinal, 115 | } 116 | self.decimals = {*self.ones, *self.tens, *self.zeros} 117 | 118 | self.preceding_prefixers = { 119 | "minus": "-", 120 | "negative": "-", 121 | "plus": "+", 122 | "positive": "+", 123 | } 124 | 
self.following_prefixers = { 125 | "pound": "£", 126 | "pounds": "£", 127 | "euro": "€", 128 | "euros": "€", 129 | "dollar": "$", 130 | "dollars": "$", 131 | "cent": "¢", 132 | "cents": "¢", 133 | } 134 | self.prefixes = set( 135 | list(self.preceding_prefixers.values()) 136 | + list(self.following_prefixers.values()) 137 | ) 138 | self.suffixers = { 139 | "per": {"cent": "%"}, 140 | "percent": "%", 141 | } 142 | self.specials = {"and", "double", "triple", "point"} 143 | 144 | self.words = set( 145 | [ 146 | key 147 | for mapping in [ 148 | self.zeros, 149 | self.ones, 150 | self.ones_suffixed, 151 | self.tens, 152 | self.tens_suffixed, 153 | self.multipliers, 154 | self.multipliers_suffixed, 155 | self.preceding_prefixers, 156 | self.following_prefixers, 157 | self.suffixers, 158 | self.specials, 159 | ] 160 | for key in mapping 161 | ] 162 | ) 163 | self.literal_words = {"one", "ones"} 164 | 165 | def process_words(self, words: List[str]) -> Iterator[str]: 166 | prefix: Optional[str] = None 167 | value: Optional[Union[str, int]] = None 168 | skip = False 169 | 170 | def to_fraction(s: str): 171 | try: 172 | return Fraction(s) 173 | except ValueError: 174 | return None 175 | 176 | def output(result: Union[str, int]): 177 | nonlocal prefix, value 178 | result = str(result) 179 | if prefix is not None: 180 | result = prefix + result 181 | value = None 182 | prefix = None 183 | return result 184 | 185 | if len(words) == 0: 186 | return 187 | 188 | for prev, current, next in windowed([None] + words + [None], 3): 189 | if skip: 190 | skip = False 191 | continue 192 | 193 | next_is_numeric = next is not None and re.match(r"^\d+(\.\d+)?$", next) 194 | has_prefix = current[0] in self.prefixes 195 | current_without_prefix = current[1:] if has_prefix else current 196 | if re.match(r"^\d+(\.\d+)?$", current_without_prefix): 197 | # arabic numbers (potentially with signs and fractions) 198 | f = to_fraction(current_without_prefix) 199 | assert f is not None 200 | if value is not None: 201 | if isinstance(value, str) and value.endswith("."): 202 | # concatenate decimals / ip address components 203 | value = str(value) + str(current) 204 | continue 205 | else: 206 | yield output(value) 207 | 208 | prefix = current[0] if has_prefix else prefix 209 | if f.denominator == 1: 210 | value = f.numerator # store integers as int 211 | else: 212 | value = current_without_prefix 213 | elif current not in self.words: 214 | # non-numeric words 215 | if value is not None: 216 | yield output(value) 217 | yield output(current) 218 | elif current in self.zeros: 219 | value = str(value or "") + "0" 220 | elif current in self.ones: 221 | ones = self.ones[current] 222 | 223 | if value is None: 224 | value = ones 225 | elif isinstance(value, str) or prev in self.ones: 226 | if ( 227 | prev in self.tens and ones < 10 228 | ): # replace the last zero with the digit 229 | assert value[-1] == "0" 230 | value = value[:-1] + str(ones) 231 | else: 232 | value = str(value) + str(ones) 233 | elif ones < 10: 234 | if value % 10 == 0: 235 | value += ones 236 | else: 237 | value = str(value) + str(ones) 238 | else: # eleven to nineteen 239 | if value % 100 == 0: 240 | value += ones 241 | else: 242 | value = str(value) + str(ones) 243 | elif current in self.ones_suffixed: 244 | # ordinal or cardinal; yield the number right away 245 | ones, suffix = self.ones_suffixed[current] 246 | if value is None: 247 | yield output(str(ones) + suffix) 248 | elif isinstance(value, str) or prev in self.ones: 249 | if prev in self.tens and ones < 10: 
250 | assert value[-1] == "0" 251 | yield output(value[:-1] + str(ones) + suffix) 252 | else: 253 | yield output(str(value) + str(ones) + suffix) 254 | elif ones < 10: 255 | if value % 10 == 0: 256 | yield output(str(value + ones) + suffix) 257 | else: 258 | yield output(str(value) + str(ones) + suffix) 259 | else: # eleven to nineteen 260 | if value % 100 == 0: 261 | yield output(str(value + ones) + suffix) 262 | else: 263 | yield output(str(value) + str(ones) + suffix) 264 | value = None 265 | elif current in self.tens: 266 | tens = self.tens[current] 267 | if value is None: 268 | value = tens 269 | elif isinstance(value, str): 270 | value = str(value) + str(tens) 271 | else: 272 | if value % 100 == 0: 273 | value += tens 274 | else: 275 | value = str(value) + str(tens) 276 | elif current in self.tens_suffixed: 277 | # ordinal or cardinal; yield the number right away 278 | tens, suffix = self.tens_suffixed[current] 279 | if value is None: 280 | yield output(str(tens) + suffix) 281 | elif isinstance(value, str): 282 | yield output(str(value) + str(tens) + suffix) 283 | else: 284 | if value % 100 == 0: 285 | yield output(str(value + tens) + suffix) 286 | else: 287 | yield output(str(value) + str(tens) + suffix) 288 | elif current in self.multipliers: 289 | multiplier = self.multipliers[current] 290 | if value is None: 291 | value = multiplier 292 | elif isinstance(value, str) or value == 0: 293 | f = to_fraction(value) 294 | p = f * multiplier if f is not None else None 295 | if f is not None and p.denominator == 1: 296 | value = p.numerator 297 | else: 298 | yield output(value) 299 | value = multiplier 300 | else: 301 | before = value // 1000 * 1000 302 | residual = value % 1000 303 | value = before + residual * multiplier 304 | elif current in self.multipliers_suffixed: 305 | multiplier, suffix = self.multipliers_suffixed[current] 306 | if value is None: 307 | yield output(str(multiplier) + suffix) 308 | elif isinstance(value, str): 309 | f = to_fraction(value) 310 | p = f * multiplier if f is not None else None 311 | if f is not None and p.denominator == 1: 312 | yield output(str(p.numerator) + suffix) 313 | else: 314 | yield output(value) 315 | yield output(str(multiplier) + suffix) 316 | else: # int 317 | before = value // 1000 * 1000 318 | residual = value % 1000 319 | value = before + residual * multiplier 320 | yield output(str(value) + suffix) 321 | value = None 322 | elif current in self.preceding_prefixers: 323 | # apply prefix (positive, minus, etc.) if it precedes a number 324 | if value is not None: 325 | yield output(value) 326 | 327 | if next in self.words or next_is_numeric: 328 | prefix = self.preceding_prefixers[current] 329 | else: 330 | yield output(current) 331 | elif current in self.following_prefixers: 332 | # apply prefix (dollars, cents, etc.) 
only after a number 333 | if value is not None: 334 | prefix = self.following_prefixers[current] 335 | yield output(value) 336 | else: 337 | yield output(current) 338 | elif current in self.suffixers: 339 | # apply suffix symbols (percent -> '%') 340 | if value is not None: 341 | suffix = self.suffixers[current] 342 | if isinstance(suffix, dict): 343 | if next in suffix: 344 | yield output(str(value) + suffix[next]) 345 | skip = True 346 | else: 347 | yield output(value) 348 | yield output(current) 349 | else: 350 | yield output(str(value) + suffix) 351 | else: 352 | yield output(current) 353 | elif current in self.specials: 354 | if next not in self.words and not next_is_numeric: 355 | # apply special handling only if the next word can be numeric 356 | if value is not None: 357 | yield output(value) 358 | yield output(current) 359 | elif current == "and": 360 | # ignore "and" after hundreds, thousands, etc. 361 | if prev not in self.multipliers: 362 | if value is not None: 363 | yield output(value) 364 | yield output(current) 365 | elif current == "double" or current == "triple": 366 | if next in self.ones or next in self.zeros: 367 | repeats = 2 if current == "double" else 3 368 | ones = self.ones.get(next, 0) 369 | value = str(value or "") + str(ones) * repeats 370 | skip = True 371 | else: 372 | if value is not None: 373 | yield output(value) 374 | yield output(current) 375 | elif current == "point": 376 | if next in self.decimals or next_is_numeric: 377 | value = str(value or "") + "." 378 | else: 379 | # should all have been covered at this point 380 | raise ValueError(f"Unexpected token: {current}") 381 | else: 382 | # all should have been covered at this point 383 | raise ValueError(f"Unexpected token: {current}") 384 | 385 | if value is not None: 386 | yield output(value) 387 | 388 | def preprocess(self, s: str): 389 | # replace " and a half" with " point five" 390 | results = [] 391 | 392 | segments = re.split(r"\band\s+a\s+half\b", s) 393 | for i, segment in enumerate(segments): 394 | if len(segment.strip()) == 0: 395 | continue 396 | if i == len(segments) - 1: 397 | results.append(segment) 398 | else: 399 | results.append(segment) 400 | last_word = segment.rsplit(maxsplit=2)[-1] 401 | if last_word in self.decimals or last_word in self.multipliers: 402 | results.append("point five") 403 | else: 404 | results.append("and a half") 405 | 406 | s = " ".join(results) 407 | 408 | # put a space at number/letter boundary 409 | s = re.sub(r"([a-z])([0-9])", r"\1 \2", s) 410 | s = re.sub(r"([0-9])([a-z])", r"\1 \2", s) 411 | 412 | # but remove spaces which could be a suffix 413 | s = re.sub(r"([0-9])\s+(st|nd|rd|th|s)\b", r"\1\2", s) 414 | 415 | return s 416 | 417 | def postprocess(self, s: str): 418 | def combine_cents(m: Match): 419 | try: 420 | currency = m.group(1) 421 | integer = m.group(2) 422 | cents = int(m.group(3)) 423 | return f"{currency}{integer}.{cents:02d}" 424 | except ValueError: 425 | return m.string 426 | 427 | def extract_cents(m: Match): 428 | try: 429 | return f"¢{int(m.group(1))}" 430 | except ValueError: 431 | return m.string 432 | 433 | # apply currency postprocessing; "$2 and ¢7" -> "$2.07" 434 | s = re.sub(r"([€£$])([0-9]+) (?:and )?¢([0-9]{1,2})\b", combine_cents, s) 435 | s = re.sub(r"[€£$]0.([0-9]{1,2})\b", extract_cents, s) 436 | 437 | # write "one(s)" instead of "1(s)", just for the readability 438 | s = re.sub(r"\b1(s?)\b", r"one\1", s) 439 | 440 | return s 441 | 442 | def __call__(self, s: str): 443 | s = self.preprocess(s) 444 | s = " ".join(word for 
word in self.process_words(s.split()) if word is not None) 445 | s = self.postprocess(s) 446 | 447 | return s 448 | 449 | 450 | class EnglishSpellingNormalizer: 451 | """ 452 | Applies British-American spelling mappings as listed in [1]. 453 | 454 | [1] https://www.tysto.com/uk-us-spelling-list.html 455 | """ 456 | 457 | def __init__(self): 458 | mapping_path = os.path.join(os.path.dirname(__file__), "english.json") 459 | self.mapping = json.load(open(mapping_path)) 460 | 461 | def __call__(self, s: str): 462 | return " ".join(self.mapping.get(word, word) for word in s.split()) 463 | 464 | 465 | class EnglishTextNormalizer: 466 | def __init__(self): 467 | self.ignore_patterns = r"\b(hmm|mm|mhm|mmm|uh|um)\b" 468 | self.replacers = { 469 | # common contractions 470 | r"\bwon't\b": "will not", 471 | r"\bcan't\b": "can not", 472 | r"\blet's\b": "let us", 473 | r"\bain't\b": "aint", 474 | r"\by'all\b": "you all", 475 | r"\bwanna\b": "want to", 476 | r"\bgotta\b": "got to", 477 | r"\bgonna\b": "going to", 478 | r"\bi'ma\b": "i am going to", 479 | r"\bimma\b": "i am going to", 480 | r"\bwoulda\b": "would have", 481 | r"\bcoulda\b": "could have", 482 | r"\bshoulda\b": "should have", 483 | r"\bma'am\b": "madam", 484 | # contractions in titles/prefixes 485 | r"\bmr\b": "mister ", 486 | r"\bmrs\b": "missus ", 487 | r"\bst\b": "saint ", 488 | r"\bdr\b": "doctor ", 489 | r"\bprof\b": "professor ", 490 | r"\bcapt\b": "captain ", 491 | r"\bgov\b": "governor ", 492 | r"\bald\b": "alderman ", 493 | r"\bgen\b": "general ", 494 | r"\bsen\b": "senator ", 495 | r"\brep\b": "representative ", 496 | r"\bpres\b": "president ", 497 | r"\brev\b": "reverend ", 498 | r"\bhon\b": "honorable ", 499 | r"\basst\b": "assistant ", 500 | r"\bassoc\b": "associate ", 501 | r"\blt\b": "lieutenant ", 502 | r"\bcol\b": "colonel ", 503 | r"\bjr\b": "junior ", 504 | r"\bsr\b": "senior ", 505 | r"\besq\b": "esquire ", 506 | # prefect tenses, ideally it should be any past participles, but it's harder.. 
507 | r"'d been\b": " had been", 508 | r"'s been\b": " has been", 509 | r"'d gone\b": " had gone", 510 | r"'s gone\b": " has gone", 511 | r"'d done\b": " had done", # "'s done" is ambiguous 512 | r"'s got\b": " has got", 513 | # general contractions 514 | r"n't\b": " not", 515 | r"'re\b": " are", 516 | r"'s\b": " is", 517 | r"'d\b": " would", 518 | r"'ll\b": " will", 519 | r"'t\b": " not", 520 | r"'ve\b": " have", 521 | r"'m\b": " am", 522 | } 523 | self.standardize_numbers = EnglishNumberNormalizer() 524 | self.standardize_spellings = EnglishSpellingNormalizer() 525 | 526 | def __call__(self, s: str): 527 | s = s.lower() 528 | 529 | s = re.sub(r"[<\[][^>\]]*[>\]]", "", s) # remove words between brackets 530 | s = re.sub(r"\(([^)]+?)\)", "", s) # remove words between parenthesis 531 | s = re.sub(self.ignore_patterns, "", s) 532 | s = re.sub(r"\s+'", "'", s) # when there's a space before an apostrophe 533 | 534 | for pattern, replacement in self.replacers.items(): 535 | s = re.sub(pattern, replacement, s) 536 | 537 | s = re.sub(r"(\d),(\d)", r"\1\2", s) # remove commas between digits 538 | s = re.sub(r"\.([^0-9]|$)", r" \1", s) # remove periods not followed by numbers 539 | s = remove_symbols_and_diacritics(s, keep=".%$¢€£") # keep numeric symbols 540 | 541 | s = self.standardize_numbers(s) 542 | s = self.standardize_spellings(s) 543 | 544 | # now remove prefix/suffix symbols that are not preceded/followed by numbers 545 | s = re.sub(r"[.$¢€£]([^0-9])", r" \1", s) 546 | s = re.sub(r"([^0-9])%", r"\1 ", s) 547 | 548 | s = re.sub(r"\s+", " ", s) # replace any successive whitespaces with a space 549 | 550 | return s 551 | --------------------------------------------------------------------------------