├── training_code
│   ├── __init__.py
│   ├── ds_dict.py
│   ├── collators.py
│   ├── utils.py
│   ├── train.py
│   ├── datasets_classes.py
│   └── whisper_module.py
├── .gitignore
├── careless_whisper_stream
│   ├── version.py
│   ├── __main__.py
│   ├── assets
│   │   └── mel_filters.npz
│   ├── normalizers
│   │   ├── __init__.py
│   │   ├── basic.py
│   │   └── english.py
│   ├── triton_ops.py
│   ├── streaming_transcribe.py
│   ├── utils.py
│   ├── model.py
│   ├── __init__.py
│   ├── tokenizer.py
│   ├── audio.py
│   ├── timing.py
│   └── streaming_model.py
├── tests
│   ├── jfk.wav
│   └── streaming_transcription.py
├── transcribe.py
├── requirements.txt
├── environment.yml
├── README.md
└── LICENSE
/training_code/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.ipynb
2 | __pycache__/
3 | notebooks/
4 | static/
--------------------------------------------------------------------------------
/careless_whisper_stream/version.py:
--------------------------------------------------------------------------------
1 | __version__ = "20231117"
2 |
--------------------------------------------------------------------------------
/careless_whisper_stream/__main__.py:
--------------------------------------------------------------------------------
1 | from .streaming_transcribe import cli
2 |
3 | cli()
4 |
--------------------------------------------------------------------------------
/tests/jfk.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tomer9080/CarelessWhisper-Streaming/HEAD/tests/jfk.wav
--------------------------------------------------------------------------------
/transcribe.py:
--------------------------------------------------------------------------------
1 | from careless_whisper_stream.streaming_transcribe import cli
2 |
3 | if __name__ == "__main__":
4 | cli()
--------------------------------------------------------------------------------
/careless_whisper_stream/assets/mel_filters.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tomer9080/CarelessWhisper-Streaming/HEAD/careless_whisper_stream/assets/mel_filters.npz
--------------------------------------------------------------------------------
/careless_whisper_stream/normalizers/__init__.py:
--------------------------------------------------------------------------------
1 | from .basic import BasicTextNormalizer as BasicTextNormalizer
2 | from .english import EnglishTextNormalizer as EnglishTextNormalizer
3 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | aiohttp
2 | evaluate
3 | huggingface_hub
4 | jiwer
5 | more_itertools
6 | numba
7 | numpy
8 | openai_whisper==20240930
9 | pandas
10 | praatio
11 | pyaudio
12 | pytorch_lightning==2.5.0.post0
13 | regex
14 | soundfile
15 | tiktoken
16 | tqdm
17 | triton
18 | websockets
19 | wandb
20 |
--------------------------------------------------------------------------------
/tests/streaming_transcription.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | sys.path.append("./")
4 |
5 | import torch
6 | import careless_whisper_stream
7 |
8 | device = "cuda" if torch.cuda.is_available() else "cpu"
9 | model = careless_whisper_stream.load_streaming_model("small", 300, False, device)
10 | texts = model.transcribe(simulate_stream=True, wav_file=os.path.join(os.path.dirname(os.path.abspath(__file__)), "jfk.wav"), beam_size=5, ca_kv_cache=True)
11 | print(texts)
--------------------------------------------------------------------------------
/training_code/ds_dict.py:
--------------------------------------------------------------------------------
1 | import os
2 | """
3 | Taken from README.md:
4 | ### Dataset Structure
5 |
6 | Before starting model training using the command-line interface provided below, you must first configure your dataset dictionary file located at `training_code/ds_dict.py`.
7 |
8 | This file defines a Python dictionary named `ds_paths`, where you should specify paths to the `train`, `val`, and `test` partitions of your dataset. Each partition should be a CSV file with the following three columns:
9 |
10 | 1. `wav_path` — Path to the WAV audio file.
11 | 2. `tg_path` — Path to the corresponding `.TextGrid` file containing forced alignment.
12 | 3. `raw_text` — Ground truth transcription.
13 |
14 | > **Note:** The dictionary key (i.e., the name of the dataset) will be used by the training script to identify and load the dataset correctly.
15 | """
16 |
17 | ds_paths = {
18 | 'LIBRI-960-ALIGNED': {
19 | 'train': f"{os.environ.get('HOME')}/path/to/datasets/libri_train_960_train.csv", # used on training steps.
20 | 'val': f"{os.environ.get('HOME')}/path/to/datasets/libri_train_960_val.csv", # used on validation steps.
21 | 'test': f"{os.environ.get('HOME')}/path/to/datasets/libri_train_960_test.csv" # used on evaluation.
22 | },
23 |     # Add your entries below.
24 | }
25 |
--------------------------------------------------------------------------------
/careless_whisper_stream/normalizers/basic.py:
--------------------------------------------------------------------------------
1 | import re
2 | import unicodedata
3 |
4 | import regex
5 |
6 | # non-ASCII letters that are not separated by "NFKD" normalization
7 | ADDITIONAL_DIACRITICS = {
8 | "œ": "oe",
9 | "Œ": "OE",
10 | "ø": "o",
11 | "Ø": "O",
12 | "æ": "ae",
13 | "Æ": "AE",
14 | "ß": "ss",
15 | "ẞ": "SS",
16 | "đ": "d",
17 | "Đ": "D",
18 | "ð": "d",
19 | "Ð": "D",
20 | "þ": "th",
21 | "Þ": "th",
22 | "ł": "l",
23 | "Ł": "L",
24 | }
25 |
26 |
27 | def remove_symbols_and_diacritics(s: str, keep=""):
28 | """
29 | Replace any other markers, symbols, and punctuations with a space,
30 | and drop any diacritics (category 'Mn' and some manual mappings)
31 | """
32 | return "".join(
33 | c
34 | if c in keep
35 | else ADDITIONAL_DIACRITICS[c]
36 | if c in ADDITIONAL_DIACRITICS
37 | else ""
38 | if unicodedata.category(c) == "Mn"
39 | else " "
40 | if unicodedata.category(c)[0] in "MSP"
41 | else c
42 | for c in unicodedata.normalize("NFKD", s)
43 | )
44 |
45 |
46 | def remove_symbols(s: str):
47 | """
48 | Replace any other markers, symbols, punctuations with a space, keeping diacritics
49 | """
50 | return "".join(
51 | " " if unicodedata.category(c)[0] in "MSP" else c
52 | for c in unicodedata.normalize("NFKC", s)
53 | )
54 |
55 |
56 | class BasicTextNormalizer:
57 | def __init__(self, remove_diacritics: bool = False, split_letters: bool = False):
58 | self.clean = (
59 | remove_symbols_and_diacritics if remove_diacritics else remove_symbols
60 | )
61 | self.split_letters = split_letters
62 |
63 | def __call__(self, s: str):
64 | s = s.lower()
65 | s = re.sub(r"[<\[][^>\]]*[>\]]", "", s) # remove words between brackets
66 | s = re.sub(r"\(([^)]+?)\)", "", s) # remove words between parenthesis
67 | s = self.clean(s).lower()
68 |
69 | if self.split_letters:
70 | s = " ".join(regex.findall(r"\X", s, regex.U))
71 |
72 | s = re.sub(
73 | r"\s+", " ", s
74 | ) # replace any successive whitespace characters with a space
75 |
76 | return s
77 |
--------------------------------------------------------------------------------
/training_code/collators.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import torch.nn.functional as F
4 |
5 | class WhisperDataCollatorWithPadding:
6 | def __call__(self, features):
7 |
8 | input_ids, labels, dec_input_ids, labels_classes, unique_ids = [], [], [], [], []
9 | for f in features:
10 | input_ids.append(f["input_ids"])
11 | labels.append(f["labels"])
12 | dec_input_ids.append(f["dec_input_ids"])
13 | labels_classes.append([int(item==50257) for item in f["labels"]])
14 | unique_ids.append(f.get("u_id", 0))
15 |
16 | input_ids = torch.concat([input_id[None, :] for input_id in input_ids])
17 |
18 | label_lengths = [len(lab) for lab in labels]
19 | dec_input_ids_length = [len(e) for e in dec_input_ids]
20 | max_label_len = max(label_lengths+dec_input_ids_length)
21 |
22 | labels = [np.pad(lab, (0, max_label_len - lab_len), 'constant', constant_values=-100) for lab, lab_len in zip(labels, label_lengths)]
23 | # labels = [np.pad(lab, (0, max_label_len - lab_len), 'constant', constant_values=50257) for lab, lab_len in zip(labels, label_lengths)]
24 | labels_classes = [np.pad(lab, (0, max_label_len - lab_len), 'constant', constant_values=-100) for lab, lab_len in zip(labels_classes, label_lengths)]
25 | dec_input_ids = [np.pad(e, (0, max_label_len - e_len), 'constant', constant_values=50257) for e, e_len in zip(dec_input_ids, dec_input_ids_length)] # 50257 is eot token id
26 |
27 | batch = {
28 | "labels": labels,
29 | "dec_input_ids": dec_input_ids,
30 | "labels_classes": labels_classes,
31 | "unique_id": unique_ids
32 | }
33 |
34 | batch = {k: torch.tensor(np.array(v), requires_grad=False) for k, v in batch.items()}
35 |
36 | batch["input_ids"] = input_ids
37 |
38 | return batch
39 |
40 |
41 | def pad_2d_sequences(arrays: list, dim: int = 0, padding_value: int = 0) -> torch.Tensor:
42 | lens = [array.shape[dim] for array in arrays]
43 | max_len = max(lens)
44 |
45 | padded_arrays = [F.pad(array, (0, int(dim == 1) * (max_len - array.shape[1]), 0, int(dim == 0) * (max_len - array.shape[0])), mode="constant", value=padding_value) for array in arrays]
46 |     return torch.cat([padded_array[None] for padded_array in padded_arrays]) # adding batch dim and concatenating
47 |
48 |
49 | class LoRAWhisperDataCollatorWithPadding:
50 | def __call__(self, features):
51 |
52 | input_ids, labels, dec_input_ids, endpoints = [], [], [], []
53 |
54 | for f in features:
55 | input_ids.append(f["input_ids"])
56 | labels.append(f["labels"])
57 | dec_input_ids.append(f["dec_input_ids"])
58 | endpoints.append(f["endpoints"])
59 |
60 | # make a batch
61 | input_ids = torch.concat([input_id[None, :] for input_id in input_ids])
62 | labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=-100)
63 | dec_input_ids = torch.nn.utils.rnn.pad_sequence(dec_input_ids, batch_first=True, padding_value=50257)
64 | endpoints = torch.nn.utils.rnn.pad_sequence(endpoints, batch_first=True, padding_value=-100)
65 |
66 | batch = {
67 | "labels": labels,
68 | "dec_input_ids": dec_input_ids,
69 | "endpoints": endpoints
70 | }
71 |
72 | batch = {k: v.detach() for k, v in batch.items()}
73 |
74 | batch["input_ids"] = input_ids.squeeze(1)
75 |
76 | return batch
77 |
78 |
--------------------------------------------------------------------------------
/careless_whisper_stream/triton_ops.py:
--------------------------------------------------------------------------------
1 | from functools import lru_cache
2 |
3 | import numpy as np
4 | import torch
5 |
6 | try:
7 | import triton
8 | import triton.language as tl
9 | except ImportError:
10 | raise RuntimeError("triton import failed; try `pip install --pre triton`")
11 |
12 |
13 | @triton.jit
14 | def dtw_kernel(
15 | cost, trace, x, x_stride, cost_stride, trace_stride, N, M, BLOCK_SIZE: tl.constexpr
16 | ):
17 | offsets = tl.arange(0, BLOCK_SIZE)
18 | mask = offsets < M
19 |
20 | for k in range(1, N + M + 1): # k = i + j
21 | tl.debug_barrier()
22 |
23 | p0 = cost + (k - 1) * cost_stride
24 | p1 = cost + k * cost_stride
25 | p2 = cost + k * cost_stride + 1
26 |
27 | c0 = tl.load(p0 + offsets, mask=mask)
28 | c1 = tl.load(p1 + offsets, mask=mask)
29 | c2 = tl.load(p2 + offsets, mask=mask)
30 |
31 | x_row = tl.load(x + (k - 1) * x_stride + offsets, mask=mask, other=0)
32 | cost_row = x_row + tl.minimum(tl.minimum(c0, c1), c2)
33 |
34 | cost_ptr = cost + (k + 1) * cost_stride + 1
35 | tl.store(cost_ptr + offsets, cost_row, mask=mask)
36 |
37 | trace_ptr = trace + (k + 1) * trace_stride + 1
38 | tl.store(trace_ptr + offsets, 2, mask=mask & (c2 <= c0) & (c2 <= c1))
39 | tl.store(trace_ptr + offsets, 1, mask=mask & (c1 <= c0) & (c1 <= c2))
40 | tl.store(trace_ptr + offsets, 0, mask=mask & (c0 <= c1) & (c0 <= c2))
41 |
42 |
43 | @lru_cache(maxsize=None)
44 | def median_kernel(filter_width: int):
45 | @triton.jit
46 | def kernel(
47 | y, x, x_stride, y_stride, BLOCK_SIZE: tl.constexpr
48 | ): # x.shape[-1] == filter_width
49 | row_idx = tl.program_id(0)
50 | offsets = tl.arange(0, BLOCK_SIZE)
51 | mask = offsets < y_stride
52 |
53 | x_ptr = x + row_idx * x_stride # noqa: F841
54 | y_ptr = y + row_idx * y_stride
55 |
56 | LOAD_ALL_ROWS_HERE # noqa: F821
57 |
58 | BUBBLESORT_HERE # noqa: F821
59 |
60 | tl.store(y_ptr + offsets, MIDDLE_ROW_HERE, mask=mask) # noqa: F821
61 |
62 | kernel = triton.JITFunction(kernel.fn)
63 | kernel.src = kernel.src.replace(
64 | " LOAD_ALL_ROWS_HERE",
65 | "\n".join(
66 | [
67 | f" row{i} = tl.load(x_ptr + offsets + {i}, mask=mask)"
68 | for i in range(filter_width)
69 | ]
70 | ),
71 | )
72 | kernel.src = kernel.src.replace(
73 | " BUBBLESORT_HERE",
74 | "\n\n".join(
75 | [
76 | "\n\n".join(
77 | [
78 | "\n".join(
79 | [
80 | f" smaller = tl.where(row{j} < row{j + 1}, row{j}, row{j + 1})",
81 | f" larger = tl.where(row{j} > row{j + 1}, row{j}, row{j + 1})",
82 | f" row{j} = smaller",
83 | f" row{j + 1} = larger",
84 | ]
85 | )
86 | for j in range(filter_width - i - 1)
87 | ]
88 | )
89 | for i in range(filter_width // 2 + 1)
90 | ]
91 | ),
92 | )
93 | kernel.src = kernel.src.replace("MIDDLE_ROW_HERE", f"row{filter_width // 2}")
94 |
95 | return kernel
96 |
97 |
98 | def median_filter_cuda(x: torch.Tensor, filter_width: int):
99 | """Apply a median filter of given width along the last dimension of x"""
100 | slices = x.contiguous().unfold(-1, filter_width, 1)
101 | grid = np.prod(slices.shape[:-2])
102 |
103 | kernel = median_kernel(filter_width)
104 | y = torch.empty_like(slices[..., 0])
105 |
106 | BLOCK_SIZE = 1 << (y.stride(-2) - 1).bit_length()
107 | kernel[(grid,)](y, x, x.stride(-2), y.stride(-2), BLOCK_SIZE=BLOCK_SIZE)
108 |
109 | return y
110 |
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: careless_whisper
2 | channels:
3 | - defaults
4 | - conda-forge
5 | - https://repo.anaconda.com/pkgs/main
6 | - https://repo.anaconda.com/pkgs/r
7 | dependencies:
8 | - _libgcc_mutex=0.1
9 | - _openmp_mutex=5.1
10 | - annotated-types=0.6.0
11 | - appdirs=1.4.4
12 | - asttokens=3.0.0
13 | - brotlicffi=1.0.9.2
14 | - ca-certificates=2025.7.14
15 | - certifi=2025.7.14
16 | - cffi=1.17.1
17 | - click=8.1.8
18 | - comm=0.2.2
19 | - debugpy=1.8.14
20 | - decorator=5.2.1
21 | - docker-pycreds=0.4.0
22 | - eval_type_backport=0.2.2
23 | - exceptiongroup=1.3.0
24 | - executing=2.2.0
25 | - gitdb=4.0.7
26 | - gitpython=3.1.43
27 | - importlib-metadata=8.7.0
28 | - ipykernel=6.29.5
29 | - ipython=8.18.1
30 | - jedi=0.19.2
31 | - jupyter_client=8.6.3
32 | - jupyter_core=5.8.1
33 | - krb5=1.21.3
34 | - ld_impl_linux-64=2.40
35 | - libabseil=20250127.0
36 | - libedit=3.1.20230828
37 | - libffi=3.4.4
38 | - libgcc=15.1.0
39 | - libgcc-ng=15.1.0
40 | - libgomp=15.1.0
41 | - libsodium=1.0.20
42 | - libstdcxx=15.1.0
43 | - libstdcxx-ng=11.2.0
44 | - libxcb=1.17.0
45 | - matplotlib-inline=0.1.7
46 | - ncurses=6.4
47 | - nest-asyncio=1.6.0
48 | - openssl=3.5.1
49 | - packaging=25.0
50 | - parso=0.8.4
51 | - pexpect=4.9.0
52 | - pickleshare=0.7.5
53 | - pip=25.1
54 | - platformdirs=4.3.8
55 | - prompt-toolkit=3.0.51
56 | - protobuf=5.29.3
57 | - psutil=7.0.0
58 | - pthread-stubs=0.3
59 | - ptyprocess=0.7.0
60 | - pure_eval=0.2.3
61 | - pydantic=2.11.7
62 | - pydantic-core=2.33.2
63 | - pygments=2.19.2
64 | - pysocks=1.7.1
65 | - python=3.9.16
66 | - python-dateutil=2.9.0.post0
67 | - python_abi=3.9
68 | - pyyaml=6.0.2
69 | - pyzmq=27.0.0
70 | - readline=8.2
71 | - requests=2.32.4
72 | - sentry-sdk=2.18.0
73 | - setproctitle=1.2.2
74 | - setuptools=78.1.1
75 | - six=1.17.0
76 | - smmap=5.0.2
77 | - sqlite=3.45.3
78 | - stack_data=0.6.3
79 | - tk=8.6.14
80 | - tornado=6.5.1
81 | - traitlets=5.14.3
82 | - typing=3.10.0.0
83 | - typing-inspection=0.4.0
84 | - typing_extensions=4.14.1
85 | - urllib3=2.5.0
86 | - wandb=0.21.0
87 | - wcwidth=0.2.13
88 | - wheel=0.45.1
89 | - xorg-libx11=1.8.12
90 | - xorg-libxau=1.0.12
91 | - xorg-libxdmcp=1.1.5
92 | - xorg-xorgproto=2024.1
93 | - xz=5.6.4
94 | - yaml=0.2.5
95 | - zeromq=4.3.5
96 | - zipp=3.23.0
97 | - zlib=1.2.13
98 | - pip:
99 | - aiohappyeyeballs==2.6.1
100 | - aiohttp==3.12.14
101 | - aiosignal==1.4.0
102 | - async-timeout==5.0.1
103 | - attrs==25.3.0
104 | - audioread==3.0.1
105 | - charset-normalizer==3.4.2
106 | - contourpy==1.3.0
107 | - cycler==0.12.1
108 | - datasets==4.0.0
109 | - dill==0.3.8
110 | - evaluate==0.4.3
111 | - filelock==3.13.1
112 | - fonttools==4.59.0
113 | - frozenlist==1.7.0
114 | - fsspec==2024.6.1
115 | - hf-xet==1.1.5
116 | - huggingface-hub==0.33.4
117 | - idna==3.10
118 | - importlib-resources==6.5.2
119 | - jinja2==3.1.4
120 | - jiwer==4.0.0
121 | - joblib==1.5.1
122 | - kiwisolver==1.4.7
123 | - lazy-loader==0.4
124 | - librosa==0.10.2.post1
125 | - lightning-utilities==0.14.3
126 | - llvmlite==0.43.0
127 | - markupsafe==2.1.5
128 | - matplotlib==3.9.4
129 | - more-itertools==10.7.0
130 | - mpmath==1.3.0
131 | - msgpack==1.1.1
132 | - multidict==6.6.3
133 | - multiprocess==0.70.16
134 | - networkx==3.2.1
135 | - numba==0.60.0
136 | - numpy==1.26.0
137 | - openai-whisper==20240930
138 | - pandas==2.3.1
139 | - pillow==11.0.0
140 | - pooch==1.8.2
141 | - praatio==6.2.0
142 | - propcache==0.3.2
143 | - pyarrow==20.0.0
144 | - pyaudio==0.2.11
145 | - pycparser==2.22
146 | - pyparsing==3.2.3
147 | - pytorch-lightning==2.5.0.post0
148 | - pytz==2025.2
149 | - rapidfuzz==3.13.0
150 | - regex==2024.11.6
151 | - scikit-learn==1.6.1
152 | - scipy==1.13.1
153 | - soundfile==0.12.1
154 | - soxr==0.5.0.post1
155 | - sympy==1.13.1
156 | - threadpoolctl==3.6.0
157 | - tiktoken==0.8.0
158 | - tqdm==4.66.5
159 | - triton==3.2.0
160 | - typing-extensions==4.12.2
161 | - tzdata==2025.2
162 | - websockets==15.0.1
163 | - xxhash==3.5.0
164 | - yarl==1.20.1
165 |
--------------------------------------------------------------------------------
/training_code/utils.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from dataclasses import dataclass
3 |
4 |
5 | @dataclass(frozen=False)
6 | class Config:
7 | # DL Args
8 | learning_rate: float = 0.0005
9 | weight_decay: float = 0.01
10 | adam_epsilon: float = 1e-6
11 | warmup_steps: int = 100
12 | batch_size: int = 16
13 | num_worker: int = 16
14 | num_train_epochs: int = 10
15 | gradient_accumulation_steps: int = 1
16 | sample_rate: int = 16000
17 | ckpt: str = None
18 | no_logger: bool = False
19 | dataset: str = None
20 | name: str = None
21 | top_k: int = -1
22 | early_stop: bool = False
23 | fast_dev_run: int = None
24 | strategy: str = "ddp"
25 | seed: int = 3407
26 | custom_len: int = 0
27 |
28 | # Whisper args
29 | lang: str = "en"
30 | size: str = "tiny"
31 |
32 | # LoRA + streaming args
33 | lora: bool = False
34 | lora_ckpt: str = None
35 | rank: int = 16
36 | gran: int = 15
37 | extra_gran_blocks: int = 0
38 | sim_stream: bool = False
39 | uniform_sampling: bool = False
40 | streaming_train: bool = False
41 | streaming_fraction: float = 1
42 | streaming_random: bool = False
43 | multilingual: bool = False
44 |
45 | def parse_cmdl():
46 | # parser
47 | parser = argparse.ArgumentParser(description="Training whisper models, using different configurations")
48 |
49 | # switches for train.py
50 | parser.add_argument('--lora', action="store_true", help="run LoRA training")
51 | parser.add_argument('--no_logger', action="store_true", help="set logger to False")
52 |
53 | # variables for trainer
54 | parser.add_argument('--name', type=str, help="Trained model name", default="model")
55 | parser.add_argument('--size', type=str, help="Whisper size - can use only [tiny, base, small, medium, large, large-v2]", default="tiny")
56 | parser.add_argument('--ckpt', type=str, help="ckpt loading to resume training from", default=None)
57 | parser.add_argument('--fast_dev_run', type=int, help="run few dev runs for sanity checks on lightning trainer", default=None)
58 | parser.add_argument('--top_k', type=int, help="Top K checkpoints to save, use 1 for the best, -1 for last", default=1)
59 | parser.add_argument('--early_stop', action="store_true", help="Use early stopping callback")
60 | parser.add_argument('--custom_len', type=int, help="Number of samples to train on", default=0)
61 |
62 | # DL Hyper Parameters
63 | parser.add_argument('--epochs', type=int, help="Number of training epochs", default=10)
64 | parser.add_argument('--batch_size', type=int, help="Batch size for training and evaluation. Better be 2^n, where n is a positive integer.", default=16)
65 | parser.add_argument('--dataset', type=str, help="Name of dataset to load", default='TIMIT-WORD')
66 | parser.add_argument('--learning_rate', type=float, help="Custom learning rate", default=0.0001)
67 | parser.add_argument('--gacc', type=int, help="Number of gradient accumulation steps", default=1)
68 | parser.add_argument('--weight_decay', type=float, help="Weight decay factor", default=0.01)
69 | parser.add_argument('--adam_epsilon', type=float, help="Adam epsilon", default=1e-6)
70 | parser.add_argument('--warmup_steps', type=int, help="Scheduler warmup steps", default=100)
71 | parser.add_argument('--num_worker', type=int, help="Data loader workers", default=16)
72 | parser.add_argument('--strategy', type=str, help="Trainer strategy [ddp, fsdp, ddp_find_unused_parameters_true]", default="ddp")
73 |
74 | # LoRA
75 | parser.add_argument('--lora_ckpt', type=str, help="ckpt loading (for LoRA training mode only)", default=None)
76 | parser.add_argument('--rank', type=int, help="LoRA rank", default=16)
77 | parser.add_argument('--gran', type=int, help="Granularity in encoder frames to calc attention on", default=15)
78 | parser.add_argument('--extra_gran_blocks', type=int, help="How many extra granularity blocks we add on encoder causal block matrix", default=0)
79 | parser.add_argument('--streaming_fraction', type=float, help="Fraction of the available streaming sample points to train on.", default=1)
80 | parser.add_argument('--simulate_stream', action="store_true", help="Spectrogram input is simulated to a stream scenario")
81 | parser.add_argument('--streaming_train', action="store_true", help="Train sequentially on a stream of data.")
82 | parser.add_argument('--streaming_random', action="store_true", help="Train using random sample points, not sequentially!")
83 | parser.add_argument('--multilingual', action="store_true", help="Train using multilingual dataset, assuming lang field is available.")
84 |
85 | return parser.parse_args()
86 |
--------------------------------------------------------------------------------
/training_code/train.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import os
3 | import torch
4 | from pathlib import Path
5 | from ds_dict import ds_paths
6 | from whisper_module import LoRAStreamedWhisper
7 | from training_code.utils import Config, parse_cmdl
8 | from pytorch_lightning.loggers import WandbLogger
9 | from pytorch_lightning import Trainer, seed_everything
10 | from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint, EarlyStopping
11 |
12 |
13 | SEED = 3407
14 | DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
15 | # DEVICE = "cpu"
16 | seed_everything(SEED, workers=True)
17 | torch.set_float32_matmul_precision("high")
18 |
19 | logs_root = "/mlspeech/data/tomer/streaming_whisper/models/logs"
20 | ckpt_root = "/mlspeech/data/tomer/streaming_whisper/models/ckpts"
21 |
22 | whisper_lrs: dict[str, float] = {'tiny': 1.5e-3, 'base': 1e-3, 'small': 5e-4, 'medium': 2.5e-4, 'large': 1.75e-4, 'large-v2': 2e-4}
23 |
24 | project_names = {
25 | "lora": "LoRA_whisper_stream",
26 | }
27 |
28 | def train_model(log_output_dir, check_output_dir, model_name, train_set, val_set, train_name, project_name, cfg: Config) -> None:
29 |
30 | Path(log_output_dir).mkdir(exist_ok=True)
31 | Path(check_output_dir).mkdir(exist_ok=True)
32 |
33 | wandblogger = WandbLogger(
34 | save_dir=log_output_dir,
35 | name=train_name,
36 | project=project_names[project_name]
37 | )
38 |
39 | checkpoint_callback = ModelCheckpoint(
40 | dirpath=f"{check_output_dir}/checkpoint",
41 | filename="checkpoint-{epoch:04d}",
42 |         save_top_k=cfg.top_k, # save the best top_k checkpoints
43 | monitor="val/wer"
44 | )
45 |
46 | callback_list = [checkpoint_callback, LearningRateMonitor(logging_interval="epoch")]
47 |
48 | if cfg.early_stop:
49 | early_stop_callback = EarlyStopping(
50 | monitor="val/wer",
51 | min_delta=0.00,
52 | patience=2,
53 | mode="min",
54 | )
55 | callback_list.append(early_stop_callback)
56 |
57 | # Model mux
58 | if cfg.lora and cfg.streaming_train:
59 | model = LoRAStreamedWhisper(cfg, model_name, cfg.lang, train_set, val_set, rank=cfg.rank, enc_emb_gran=cfg.gran, enc_context=cfg.extra_gran_blocks, sim_stream=cfg.sim_stream)
60 |
61 | trainer = Trainer(
62 | accelerator=DEVICE,
63 | max_epochs=cfg.num_train_epochs,
64 | callbacks=callback_list,
65 | logger=wandblogger if not cfg.no_logger else False,
66 | deterministic=True,
67 | num_sanity_val_steps=1,
68 | strategy=cfg.strategy,
69 | fast_dev_run=cfg.fast_dev_run,
70 | # precision="16"
71 | # accumulate_grad_batches=cfg.gradient_accumulation_steps,
72 | )
73 |
74 | if cfg.ckpt is None: trainer.fit(model)
75 | else: trainer.fit(model, ckpt_path=cfg.ckpt)
76 |
77 |
78 | if __name__ == "__main__":
79 |
80 | project_name = None
81 |
82 | args = parse_cmdl()
83 |
84 | # Training config
85 | cfg = Config(
86 | learning_rate=args.learning_rate,
87 | weight_decay=args.weight_decay,
88 | adam_epsilon=args.adam_epsilon,
89 | warmup_steps=args.warmup_steps,
90 | batch_size=args.batch_size,
91 | num_worker=args.num_worker,
92 | num_train_epochs=args.epochs,
93 | gradient_accumulation_steps=args.gacc,
94 | no_logger=args.no_logger,
95 | dataset=args.dataset,
96 | name=args.name,
97 | top_k=args.top_k,
98 | sample_rate=16_000,
99 | ckpt=args.ckpt,
100 | size=args.size,
101 | lora=args.lora,
102 | lora_ckpt=args.lora_ckpt,
103 | rank=args.rank,
104 | gran=args.gran,
105 | extra_gran_blocks=args.extra_gran_blocks,
106 | sim_stream=args.simulate_stream,
107 | fast_dev_run=args.fast_dev_run,
108 | early_stop=args.early_stop,
109 | strategy=args.strategy,
110 | streaming_train=args.streaming_train,
111 | streaming_random=args.streaming_random,
112 | streaming_fraction=args.streaming_fraction,
113 | seed=SEED,
114 | multilingual=args.multilingual,
115 | custom_len=args.custom_len
116 | )
117 |
118 | if cfg.streaming_train:
119 | assert cfg.sim_stream == cfg.streaming_train, "When running in full stream mode you must simulate streaming!"
120 | cfg.sim_stream = True
121 |
122 | lr_addition = f"_LR-{cfg.learning_rate}"
123 | effective_bsize = cfg.batch_size * cfg.gradient_accumulation_steps
124 |
125 | if cfg.lora and cfg.streaming_train:
126 | dir_name = f"LoRA_streamed_whisper_{cfg.size}_{cfg.dataset}_{effective_bsize}_{cfg.name}{lr_addition}_r{cfg.rank}_g{cfg.gran}_eg{cfg.extra_gran_blocks}_top{cfg.top_k}_full-stream{cfg.streaming_train}_random-order{cfg.streaming_random}_fraction{cfg.streaming_fraction}"
127 | project_name = "lora"
128 |
129 | # Run trainer
130 | train_model(
131 | log_output_dir=os.path.join(logs_root, dir_name),
132 | check_output_dir=os.path.join(ckpt_root, dir_name),
133 | model_name=args.size,
134 | train_set=ds_paths[args.dataset]['train'],
135 | val_set=ds_paths[args.dataset]['val'],
136 | train_name=dir_name,
137 | project_name=project_name,
138 | cfg=cfg
139 | )
140 |
141 |
--------------------------------------------------------------------------------
/training_code/datasets_classes.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import pickle
3 | import pandas as pd
4 | import careless_whisper_stream
5 | import careless_whisper_stream.tokenizer
6 | from praatio import textgrid
7 | from dataclasses import dataclass
8 | from careless_whisper_stream.tokenizer import Tokenizer
9 | from careless_whisper_stream.audio import SpectrogramStream
10 |
11 | class WAVsDataset(torch.utils.data.Dataset):
12 | def __init__(self, ds_path: str,
13 | sep="\t",
14 | tokenizer: careless_whisper_stream.tokenizer = None,
15 | no_labels: bool = False,
16 | custom_len: int = 0,
17 | get_streamed_mel: bool = False) -> None:
18 | super().__init__()
19 |
20 | if not no_labels:
21 | self.tokenizer = tokenizer if tokenizer else careless_whisper_stream.tokenizer.get_tokenizer(True, language="en", task="transcribe")
22 | self.ds_df = pd.read_csv(ds_path, sep=sep)
23 | self.sr = 16_000
24 | self.no_labels = no_labels
25 | self.custom_len = custom_len
26 | self.get_streamed_mel = get_streamed_mel
27 |
28 | def __len__(self):
29 | return int(self.custom_len) if 0 < self.custom_len < len(self.ds_df) else int(len(self.ds_df))
30 |
31 | def _calc_mel(self, audio):
32 | if self.get_streamed_mel:
33 | spec_streamer = SpectrogramStream()
34 | return spec_streamer._simulate_streaming_log_spec(torch.tensor(audio)).squeeze(0)
35 |
36 | return careless_whisper_stream.log_mel_spectrogram(audio)
37 |
38 | def __getitem__(self, idx):
39 | item = self.ds_df.iloc[idx]
40 |
41 | audio = careless_whisper_stream.load_audio(item["wav_path"], sr=self.sr)
42 | audio = careless_whisper_stream.pad_or_trim(audio.flatten())
43 | mel = self._calc_mel(audio)
44 |
45 | if self.no_labels: return dict(input_ids=mel)
46 |
47 | text = item["raw_text"]
48 | text = [*self.tokenizer.sot_sequence_including_notimestamps] + self.tokenizer.encode(text)
49 | labels = text[1:] + [self.tokenizer.eot]
50 |
51 | return dict(input_ids=mel, labels=labels, dec_input_ids=text)
52 |
53 |
54 | @dataclass
55 | class Interval:
56 | label: str = None
57 | start: float = 0.0
58 | end: float = 0.0
59 |
60 | class AlignedTextGridDataset(torch.utils.data.Dataset):
61 | def __init__(self,
62 | ds_path: str,
63 | tokenizer: Tokenizer = None,
64 | sample_rate: int = 16_000,
65 | custom_len: int = 0,
66 | get_streamed_mel: bool = False,
67 | gran: int = 15,
68 | extra_gran_blocks: int = 0,
69 | n_mels: int = 80,
70 | multilingual: bool = False): # most of the times we train just on english librispeech
71 | super().__init__()
72 |
73 | self.tokenizer = tokenizer if tokenizer else careless_whisper_stream.tokenizer.get_tokenizer(True, language="en", task="transcribe")
74 | self.ds_df = pd.read_csv(ds_path)
75 | self.sr = sample_rate
76 | self.custom_len = custom_len
77 | self.get_streamed_mel = get_streamed_mel
78 | self.gran = gran
79 | self.extra_gran_blocks = extra_gran_blocks
80 | self.n_mels = n_mels
81 | self.multilingual = multilingual
82 |
83 | def __len__(self):
84 | return int(self.custom_len) if 0 < self.custom_len < len(self.ds_df) else len(self.ds_df)
85 |
86 | def _calc_mel(self, audio):
87 | if self.get_streamed_mel:
88 | spec_streamer = SpectrogramStream(n_mels=self.n_mels)
89 | return spec_streamer._simulate_streaming_log_spec(torch.tensor(audio))
90 |
91 | return careless_whisper_stream.log_mel_spectrogram(audio)
92 |
93 | def _get_intervals_from_wrd_file(self, path: str):
94 | with open(path, "r") as file:
95 | lines = file.readlines()
96 |
97 | intervals = []
98 | for line in lines:
99 | start, end, label = line.strip().split()
100 | intervals.append(Interval(label, int(start) / self.sr, int(end) / self.sr))
101 |
102 | return intervals
103 |
104 | def __getitem__(self, index):
105 | item = self.ds_df.iloc[index]
106 |
107 | audio = careless_whisper_stream.pad_or_trim(careless_whisper_stream.load_audio(item["wav_path"], sr=self.sr))
108 | mel = self._calc_mel(audio)
109 |
110 | if ".wrd" in item["tg_path"]:
111 | text_intervals = self._get_intervals_from_wrd_file(item["tg_path"])
112 | else:
113 | tg = textgrid.openTextgrid(item["tg_path"], includeEmptyIntervals=False)
114 | text_intervals = tg.getTier("words")
115 |
116 | tokenizer = self.tokenizer if not self.multilingual else careless_whisper_stream.tokenizer.get_tokenizer(True, language=item["lang"], task="transcribe")
117 |
118 | endpoints = [0, 0, 0]
119 | tokens = []
120 | for i, interval in enumerate(text_intervals):
121 | curr_tokens = self.tokenizer.encode(interval.label if i == 0 else " " + interval.label)
122 | n_diff = (interval.end - interval.start) / len(curr_tokens)
123 | endpoints.extend([interval.start + (i + 1) * n_diff for i in range(len(curr_tokens))])
124 | tokens.extend(curr_tokens)
125 |
126 | text = [*tokenizer.sot_sequence_including_notimestamps] + tokens
127 | labels = text[1:] + [self.tokenizer.eot]
128 | endpoints.append(endpoints[-1] + 0.5)
129 |
130 | assert len(endpoints) == len(labels) == len(text)
131 |
132 | return dict(input_ids=mel,
133 | dec_input_ids=torch.tensor(text),
134 | labels=torch.tensor(labels),
135 | endpoints=torch.tensor(endpoints))
136 |
137 |
138 | class TIMIT(torch.utils.data.Dataset):
139 | def __init__(self, ds_path: str, tokenizer: Tokenizer = None, n_state: int = 384) -> None:
140 |
141 | self.tokenizer = tokenizer if tokenizer else careless_whisper_stream.tokenizer.get_tokenizer(True, language="en", task="transcribe")
142 |
143 | with open(ds_path, 'rb') as file:
144 | self.dataset = pickle.load(file)
145 |
146 | self.n_state = n_state
147 |
148 | def __len__(self):
149 | return len(self.dataset)
150 |
151 | def __getitem__(self, index):
152 | audio, sr, text, _, _ = self.dataset[index]
153 | audio_len = audio.shape[-1]
154 | assert sr == 16000
155 | audio = careless_whisper_stream.pad_or_trim(torch.Tensor(audio).flatten())
156 | mel = careless_whisper_stream.log_mel_spectrogram(audio)
157 |
158 | text = [*self.tokenizer.sot_sequence_including_notimestamps] + self.tokenizer.encode(text)
159 | labels = text[1:] + [self.tokenizer.eot]
160 |
161 | num_frames = ((audio_len // 16000) * 50) + 2
162 | mask = torch.ones(1, 1500, self.n_state)
163 | mask[0, num_frames:, :] = 0
164 |
165 | return dict(
166 | input_ids=mel,
167 | labels=labels,
168 | dec_input_ids=text,
169 | mask=mask,
170 | )
171 |
--------------------------------------------------------------------------------
/careless_whisper_stream/streaming_transcribe.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('.')
3 |
4 | import argparse
5 | from typing import TYPE_CHECKING, List
6 |
7 | import numpy as np
8 | import torch
9 |
10 | from .audio import (
11 | SAMPLE_RATE,
12 | SpectrogramStream,
13 | MyStream
14 | )
15 | from .streaming_decoding import DecodingOptions
16 |
17 | if TYPE_CHECKING:
18 | from .streaming_model import StreamingWhisper
19 |
20 |
21 | def transcribe(
22 | model: "StreamingWhisper" = None,
23 | output_filename: str = None,
24 | channels: int = 2,
25 | language: str = "en",
26 | simulate_stream: bool = False,
27 | wav_file: str = None,
28 | single_frame_mel: bool = True,
29 | temperature: float = 0,
30 | beam_size: int = 5,
31 | stream_decode: bool = True,
32 | ca_kv_cache: bool = False,
33 | sa_kv_cache: bool = False,
34 | use_latency: bool = False,
35 | pad_trim: bool = False,
36 | max_sec_context: int = 30,
37 | streaming_timestamps: bool = False,
38 | force_first_tokens_timestamps: bool = False,
39 | **kwargs
40 | ) -> List[str]:
41 | """
42 |     Open a stream and transcribe it using a streaming Whisper model.
43 |
44 |     A very thin implementation of the transcribe function, compared to the original Whisper implementation.
45 |
46 |     Parameters
47 |     ----------
48 |     model: StreamingWhisper
49 |         The streaming Whisper model instance
50 |
51 |     Returns
52 |     -------
53 |     A list of per-chunk decoding results, containing all of the text transcribed until the stream stopped.
54 | """
55 | model.reset(use_stream=True) # we first reset the model before starting a stream, cleaning any cache.
56 | model.eval()
57 |
58 | # Instantiate streaming instance and open a stream
59 | ms_gran = model.encoder.gran * 20
60 | stream_instance = MyStream(ms_gran,
61 | channels=channels,
62 | filename=output_filename,
63 | simulate_stream=simulate_stream,
64 | wav_file=wav_file,
65 | use_latency=use_latency,
66 | pad_trim=pad_trim,)
67 |
68 | stream_instance.open_stream()
69 |
70 | # frames - used only when filename is given, in order to save a long wav at the end of the conversation.
71 | frames = []
72 |
73 |     # build the decoding options for the streaming decoder
74 | decoding_options = DecodingOptions(
75 | language=language,
76 | gran=model.encoder.gran,
77 | single_frame_mel=single_frame_mel,
78 | without_timestamps=True,
79 | beam_size=beam_size if temperature == 0 else None,
80 | temperature=temperature,
81 | length_penalty=None,
82 | look_ahead_blocks=model.encoder.extra_gran_blocks,
83 | patience=None,
84 | stream_decode=stream_decode,
85 | use_kv_cache=sa_kv_cache,
86 | use_ca_kv_cache=ca_kv_cache,
87 | streaming_timestamps=streaming_timestamps,
88 | force_first_tokens_timestamps=force_first_tokens_timestamps,
89 | **kwargs
90 | )
91 |
92 | streamed_spectrogram = SpectrogramStream(n_mels=model.dims.n_mels) # default values are whisper default values
93 |
94 | texts = []
95 | reset_len = (max_sec_context * SAMPLE_RATE) + 360 # 360 is for the mel padding
96 | try:
97 | for frame in stream_instance.read():
98 | # save frames for optional save
99 | frames.extend(frame)
100 |
101 | if len(frames) > reset_len: # When we surpass the max_sec_context - reset model (positional embeddings constrain us)
102 | frame = np.concatenate((frames[-360:], frame))
103 | frames = []
104 | frames.extend(frame.tolist())
105 | model.reset(use_stream=True)
106 | streamed_spectrogram.reset()
107 |
108 | frame_tensor = torch.from_numpy(frame).pin_memory()
109 | mel_frame = streamed_spectrogram.calc_mel_with_new_frame(frame_tensor.to(model.device, non_blocking=True))
110 |
111 | # decode given the new mel frame and print results
112 | result = model.decode(mel_frame.squeeze(0), decoding_options)
113 |
114 | print(result.text)
115 |
116 | texts.append(result)
117 |
118 | except KeyboardInterrupt:
119 | stream_instance.close_stream(frames)
120 |
121 | print("Finished capturing audio.")
122 |
123 | return texts
124 |
125 |
126 | def cli():
127 | parser = argparse.ArgumentParser(description="Transcribe streaming audio with customizable options")
128 |
129 | # Model choices
130 | parser.add_argument("--model", type=str, default="small", help="Model size to transcribe with")
131 | parser.add_argument("--device", type=str, default="cpu", help="Device to run model inference on.")
132 | parser.add_argument("--chunk_size", type=int, default=300, help="Chunk size for streaming")
133 | parser.add_argument("--multilingual", action="store_true", help="Use a multilingual checkpoint if exists.", default=False)
134 |
135 | # Local streaming args
136 | parser.add_argument("--output_filename", type=str, help="Path to the output audio file when using local streaming")
137 | parser.add_argument("--channels", type=int, default=2, help="Number of audio channels - relevant for local streaming")
138 |
139 | # Streaming simulation wav file
140 | parser.add_argument("--wav_file", type=str, help="Optional WAV file path to stream, using a stream simulation")
141 |
142 | # Streaming behavior
143 | parser.add_argument("--simulate_stream", action="store_true", help="Simulate a stream from a file")
144 | parser.add_argument("--single_frame_mel", action="store_true", default=True, help="Use single frame MELs")
145 | parser.add_argument("--stream_decode", action="store_true", default=True, help="Use streaming decode")
146 | parser.add_argument("--ca_kv_cache", action="store_true", help="Use cross-attention key-value cache")
147 | parser.add_argument("--sa_kv_cache", action="store_true", help="Use self-attention key-value cache")
148 | parser.add_argument("--wait_for_all", action="store_true", help="Wait for all results before outputting")
149 | parser.add_argument("--use_latency", action="store_true", help="Add latency for streaming simulation")
150 | parser.add_argument("--pad_trim", action="store_true", default=False, help="Enable padding and trimming")
151 | parser.add_argument("--streaming_timestamps", action="store_true", help="Use timestamps in streaming")
152 | parser.add_argument("--force_first_tokens_timestamps", action="store_true", help="Force timestamps on first tokens")
153 |
154 | # Model behavior
155 | parser.add_argument("--temperature", type=float, default=0.0, help="Sampling temperature")
156 | parser.add_argument("--beam_size", type=int, default=5, help="Beam size for beam search decoding")
157 | parser.add_argument("--language", type=str, default="en", help="Language of transcription")
158 | parser.add_argument("--max_sec_context", type=int, default=30, help="Max context window size in seconds")
159 |
160 | args = parser.parse_args().__dict__
161 |
162 | from . import load_streaming_model
163 |
164 | model_size: str = args.pop("model")
165 | chunk_size: int = args.pop("chunk_size")
166 | multilingual: bool = args.pop("multilingual")
167 | device: str = args.pop("device")
168 |
169 | model = load_streaming_model(model_size, chunk_size, multilingual, device)
170 |
171 | texts = transcribe(model, **args)
172 | return texts
173 |
174 | if __name__ == "__main__":
175 | cli()
176 |
177 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # CarelessWhisper - Causal Whisper Streaming Model
2 | Causal Whisper Streaming is a fine-tuned version of OpenAI Whisper that can handle causal (streaming) input and perform real-time transcription.
3 |
4 | [](https://arxiv.org/abs/2508.12301) [](https://huggingface.co/spaces/MLSpeech/CarelessWhisper-causal-streaming)
5 |
6 | ## 📄 Paper
7 |
8 | For more details, see our [paper](https://arxiv.org/abs/2508.12301).
9 |
10 | ## 🔧 Setup
11 | We used Python 3.9.16, PyTorch 2.6.0, and PyTorch-Lightning 2.5.0 to train and test our models.
12 | Portions of this code are adapted from [OpenAI's Whisper](https://github.com/openai/whisper).
13 |
14 | To set up the project environment using `conda`, follow these steps:
15 |
16 | 1. **Clone the repository**
17 | ```bash
18 | git clone https://github.com/tomer9080/CarelessWhisper-streaming
19 | cd CarelessWhisper-streaming
20 | ```
21 |
22 | > 💡 Make sure you have [Miniconda](https://docs.conda.io/en/latest/miniconda.html) or [Anaconda](https://www.anaconda.com/products/distribution) installed before proceeding.
23 |
24 | 2. **Create the conda environment**
25 | ```bash
26 | conda env create -f environment.yml
27 | ```
28 |
29 | 3. **Activate the environment**
30 | ```bash
31 | conda activate careless_whisper
32 | ```
33 |
34 | 4. **Install the appropriate PyTorch version**
35 | Depending on your hardware and CUDA version, install PyTorch by following the instructions at [https://pytorch.org/get-started/locally](https://pytorch.org/get-started/locally).
36 | This project was tested with CUDA 12.4, but it should also work with compatible earlier or later versions.
37 |
38 | After installing all of the dependencies, you can try to run inference.
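
A quick, optional sanity check that the installed PyTorch build matches your hardware (a minimal sketch; run it inside the activated environment):

```python
import torch

print(torch.__version__)           # should report the version you installed, e.g. 2.6.0
print(torch.cuda.is_available())   # True if the CUDA build can see your GPU
```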
39 |
40 | ## 🤖 Available Models
41 | We fine-tuned three different sizes of Whisper, all of which support English-only transcription.
42 | A `large-v2` checkpoint fine-tuned on multilingual data is also available; it supports English, French, Spanish, German, and Portuguese with a chunk size of 300 milliseconds.
43 |
44 | | Size | Chunk Size [msec] | Multilingual Chunk Size [msec] |
45 | |:----:|:-----------------:|:------------------------------:|
46 | | base | 40, 100, 200, 300 | N/A |
47 | | small| 40, 100, 200, 300, 1000| N/A |
48 | |large-v2| 40, 100, 200, 300, 1000| 300 |
49 |
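For example, the multilingual row in the table corresponds to loading `large-v2` with a 300 msec chunk and `multilingual=True`. A minimal sketch using the same `load_streaming_model` call shown in the Python usage section below:

```python
import torch
import careless_whisper_stream

device = "cuda" if torch.cuda.is_available() else "cpu"

# large-v2 checkpoint, 300 msec chunks, multilingual support (per the table above)
model = careless_whisper_stream.load_streaming_model(name="large-v2",
                                                     gran=300,
                                                     multilingual=True,
                                                     device=device)
```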
50 |
51 | ## 🎤 Running Inference
52 | To run inference, download the repository content and run from the repository root according to the following sections.
53 |
54 | > **Note:** The models are hosted on the [Hugging Face Hub](https://huggingface.co/), which requires an access token.
55 | > Make sure you are logged in with your token to access the models.
56 |
57 | ### How to Apply Your Hugging Face 🤗 Access Token
58 |
59 | 1. **Create a Hugging Face account** (if you don’t have one) at [https://huggingface.co/join](https://huggingface.co/join).
60 |
61 | 2. **Generate an access token:**
62 | - Go to your Hugging Face account settings: [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens)
63 | - Click on **"New token"**, give it a name, select the appropriate scopes (usually `read` is enough), and create it.
64 |
65 | 3. **Login using the Hugging Face CLI:**
66 | Install the CLI if you don’t have it:
67 | ```bash
68 | pip install huggingface_hub
69 | ```
70 | Then login:
71 | ```bash
72 | huggingface-cli login
73 | ```
74 | Paste your token when prompted.
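
Alternatively, you can authenticate directly from Python with the `huggingface_hub` package (a sketch; the token string is a placeholder):

```python
from huggingface_hub import login

# Paste your personal access token here (placeholder value shown)
login(token="hf_xxxxxxxxxxxxxxxxxxxx")
```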
75 |
76 |
77 | ### 🖥️ CLI Usage
78 | The transcription model can be launched from the command line as follows:
79 | ```bash
80 | # Using a local microphone for streaming transcription, dumping the recording to out.wav
81 | python transcribe.py \
82 | --output_filename out.wav \
83 | --channels 2 \
84 | --model small \
85 | --chunk_size 300 \
86 | --device cuda \
87 | --beam_size 5 \
88 | --ca_kv_cache
89 | ```
90 |
91 | A simulation of a stream on a wav file is also available:
92 | ```bash
93 | # Simulating a stream on a wav file
94 | python transcribe.py \
95 | --model small \
96 | --chunk_size 300 \
97 | --device cuda \
98 | --beam_size 5 \
99 | --ca_kv_cache \
100 | --wav_file /path/to/audio.wav \
101 | --simulate_stream \
102 | --use_latency
103 | ```
104 |
105 | ### 🐍 Python Usage
106 | If you prefer using Python, a code snippet utilizing a microphone or a WAV file is provided below:
107 |
108 | ```python
109 | import torch
110 | import careless_whisper_stream
111 |
112 | model_size = "small" # model size
113 | chunk_size = 300 # chunk size in milliseconds
114 | multilingual = False # currently only large-v2 with a 300 msec chunk size supports languages other than English.
115 | device = "cuda" if torch.cuda.is_available() else "cpu"
116 |
117 | model = careless_whisper_stream.load_streaming_model(name=model_size,
118 | gran=chunk_size,
119 | multilingual=multilingual,
120 | device=device)
121 |
122 | # using a local microphone recording
123 | texts_microphone = model.transcribe(output_filename="/path/to/dump/file.wav",
124 | channels=2,
125 | beam_size=5,
126 | ca_kv_cache=True)
127 |
128 | # Simulating on a wav file
129 | texts_wav_simulation = model.transcribe(simulate_stream=True,
130 | wav_file="/path/to/file/you/want/to/transcribe.wav",
131 | beam_size=5,
132 | ca_kv_cache=True)
133 | ```
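
Both calls return the list of per-chunk decoding results collected while the stream was open. Assuming each result exposes a `.text` field, as used in `careless_whisper_stream/streaming_transcribe.py`, the incremental transcripts can be printed by continuing the snippet above:

```python
for result in texts_wav_simulation:
    print(result.text)
```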
134 |
135 | ## 🦾 Training
136 | To train using LoRA, you can use our existing code. Make sure all the requirements are installed.
137 |
138 | ### 📂 Dataset Structure
139 |
140 | Before starting model training using the command-line interface provided below, you must first configure your dataset dictionary file located at `training_code/ds_dict.py`.
141 |
142 | This file defines a Python dictionary named `ds_paths`, where you should specify paths to the `train`, `val`, and `test` partitions of your dataset. Each partition should be a CSV file with the following three columns:
143 |
144 | 1. `wav_path` — Path to the WAV audio file.
145 | 2. `tg_path` — Path to the corresponding `.TextGrid` file containing forced alignment.
146 | 3. `raw_text` — Ground truth transcription.
147 |
148 | > **Note:** The dictionary key (i.e., the name of the dataset) will be used by the training script to identify and load the dataset correctly.
149 |
150 | You can find an example entry in `training_code/ds_dict.py`.
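
As a minimal sketch of what one partition could contain (the paths and transcript below are hypothetical), the CSV only needs those three columns:

```python
import pandas as pd

# Hypothetical rows -- replace the paths and text with your own data.
partition = pd.DataFrame({
    "wav_path": ["/data/librispeech/train/84-121123-0000.wav"],
    "tg_path": ["/data/librispeech/train/84-121123-0000.TextGrid"],
    "raw_text": ["go do you hear"],
})
partition.to_csv("libri_train_960_train.csv", index=False)
```

Note that `AlignedTextGridDataset` reads these files with `pandas.read_csv` defaults (comma-separated), while `WAVsDataset` defaults to a tab separator.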
151 |
152 | > **Note:** We used [Montreal Forced Aligner (MFA)](https://montreal-forced-aligner.readthedocs.io/en/latest/index.html) to force-align our dataset.
153 |
154 | To run the same force-alignment process as described in the paper, use:
155 |
156 | ```bash
157 | mfa align --clean /dataset/root/path english_us_arpa english_us_arpa /aligned_dataset/root/path
158 | ```
159 |
160 | For more details on how to run the `mfa` command, visit the [MFA site](https://montreal-forced-aligner.readthedocs.io/en/latest/index.html).
161 |
162 | ### 🖥️ CLI Interface
163 | ```bash
164 | python training_code/train.py \
165 | --lora \
166 | --streaming_train \
167 | --simulate_stream \
168 | --dataset LIBRI-960-ALIGNED \
169 | --name example_training_base_model \
170 | --size base \
171 | --batch_size 32 \
172 | --epochs 10 \
173 | --learning_rate 1e-5 \
174 | --rank 32 \
175 | --gran 15 \
176 | --extra_gran_blocks 1 \
177 | --streaming_fraction 0.25 \
178 | --top_k 5
179 | ```
180 |
181 | For more options and training configurations, run:
182 | ```bash
183 | python training_code/train.py --help
184 | ```
185 |
186 | ## 📜 License
187 |
188 | This repository uses a dual license:
189 |
190 | [](https://opensource.org/licenses/MIT)
191 | Portions derived from [OpenAI Whisper](https://github.com/openai/whisper) are licensed under the **MIT License**.
192 |
193 | [](https://creativecommons.org/licenses/by-nc/4.0/)
194 | All other original code in this repository is licensed under the **Creative Commons Attribution-NonCommercial 4.0 International License (CC BY-NC 4.0)**.
195 |
196 | See the [LICENSE](./LICENSE) file for full details.
197 |
--------------------------------------------------------------------------------
/careless_whisper_stream/utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import re
4 | import sys
5 | import zlib
6 | from typing import Callable, List, Optional, TextIO
7 |
8 | system_encoding = sys.getdefaultencoding()
9 |
10 | if system_encoding != "utf-8":
11 |
12 | def make_safe(string):
13 | # replaces any character not representable using the system default encoding with an '?',
14 | # avoiding UnicodeEncodeError (https://github.com/openai/whisper/discussions/729).
15 | return string.encode(system_encoding, errors="replace").decode(system_encoding)
16 |
17 | else:
18 |
19 | def make_safe(string):
20 | # utf-8 can encode any Unicode code point, so no need to do the round-trip encoding
21 | return string
22 |
23 |
24 | def exact_div(x, y):
25 | assert x % y == 0
26 | return x // y
27 |
28 |
29 | def str2bool(string):
30 | str2val = {"True": True, "False": False}
31 | if string in str2val:
32 | return str2val[string]
33 | else:
34 | raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}")
35 |
36 |
37 | def optional_int(string):
38 | return None if string == "None" else int(string)
39 |
40 |
41 | def optional_float(string):
42 | return None if string == "None" else float(string)
43 |
44 |
45 | def compression_ratio(text) -> float:
46 | text_bytes = text.encode("utf-8")
47 | return len(text_bytes) / len(zlib.compress(text_bytes))
48 |
49 |
50 | def format_timestamp(
51 | seconds: float, always_include_hours: bool = False, decimal_marker: str = "."
52 | ):
53 | assert seconds >= 0, "non-negative timestamp expected"
54 | milliseconds = round(seconds * 1000.0)
55 |
56 | hours = milliseconds // 3_600_000
57 | milliseconds -= hours * 3_600_000
58 |
59 | minutes = milliseconds // 60_000
60 | milliseconds -= minutes * 60_000
61 |
62 | seconds = milliseconds // 1_000
63 | milliseconds -= seconds * 1_000
64 |
65 | hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
66 | return (
67 | f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
68 | )
69 |
70 |
71 | def get_start(segments: List[dict]) -> Optional[float]:
72 | return next(
73 | (w["start"] for s in segments for w in s["words"]),
74 | segments[0]["start"] if segments else None,
75 | )
76 |
77 |
78 | def get_end(segments: List[dict]) -> Optional[float]:
79 | return next(
80 | (w["end"] for s in reversed(segments) for w in reversed(s["words"])),
81 | segments[-1]["end"] if segments else None,
82 | )
83 |
84 |
85 | class ResultWriter:
86 | extension: str
87 |
88 | def __init__(self, output_dir: str):
89 | self.output_dir = output_dir
90 |
91 | def __call__(
92 | self, result: dict, audio_path: str, options: Optional[dict] = None, **kwargs
93 | ):
94 | audio_basename = os.path.basename(audio_path)
95 | audio_basename = os.path.splitext(audio_basename)[0]
96 | output_path = os.path.join(
97 | self.output_dir, audio_basename + "." + self.extension
98 | )
99 |
100 | with open(output_path, "w", encoding="utf-8") as f:
101 | self.write_result(result, file=f, options=options, **kwargs)
102 |
103 | def write_result(
104 | self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
105 | ):
106 | raise NotImplementedError
107 |
108 |
109 | class WriteTXT(ResultWriter):
110 | extension: str = "txt"
111 |
112 | def write_result(
113 | self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
114 | ):
115 | for segment in result["segments"]:
116 | print(segment["text"].strip(), file=file, flush=True)
117 |
118 |
119 | class SubtitlesWriter(ResultWriter):
120 | always_include_hours: bool
121 | decimal_marker: str
122 |
123 | def iterate_result(
124 | self,
125 | result: dict,
126 | options: Optional[dict] = None,
127 | *,
128 | max_line_width: Optional[int] = None,
129 | max_line_count: Optional[int] = None,
130 | highlight_words: bool = False,
131 | max_words_per_line: Optional[int] = None,
132 | ):
133 | options = options or {}
134 | max_line_width = max_line_width or options.get("max_line_width")
135 | max_line_count = max_line_count or options.get("max_line_count")
136 | highlight_words = highlight_words or options.get("highlight_words", False)
137 | max_words_per_line = max_words_per_line or options.get("max_words_per_line")
138 | preserve_segments = max_line_count is None or max_line_width is None
139 | max_line_width = max_line_width or 1000
140 | max_words_per_line = max_words_per_line or 1000
141 |
142 | def iterate_subtitles():
143 | line_len = 0
144 | line_count = 1
145 | # the next subtitle to yield (a list of word timings with whitespace)
146 | subtitle: List[dict] = []
147 | last: float = get_start(result["segments"]) or 0.0
148 | for segment in result["segments"]:
149 | chunk_index = 0
150 | words_count = max_words_per_line
151 | while chunk_index < len(segment["words"]):
152 | remaining_words = len(segment["words"]) - chunk_index
153 | if max_words_per_line > len(segment["words"]) - chunk_index:
154 | words_count = remaining_words
155 | for i, original_timing in enumerate(
156 | segment["words"][chunk_index : chunk_index + words_count]
157 | ):
158 | timing = original_timing.copy()
159 | long_pause = (
160 | not preserve_segments and timing["start"] - last > 3.0
161 | )
162 | has_room = line_len + len(timing["word"]) <= max_line_width
163 | seg_break = i == 0 and len(subtitle) > 0 and preserve_segments
164 | if (
165 | line_len > 0
166 | and has_room
167 | and not long_pause
168 | and not seg_break
169 | ):
170 | # line continuation
171 | line_len += len(timing["word"])
172 | else:
173 | # new line
174 | timing["word"] = timing["word"].strip()
175 | if (
176 | len(subtitle) > 0
177 | and max_line_count is not None
178 | and (long_pause or line_count >= max_line_count)
179 | or seg_break
180 | ):
181 | # subtitle break
182 | yield subtitle
183 | subtitle = []
184 | line_count = 1
185 | elif line_len > 0:
186 | # line break
187 | line_count += 1
188 | timing["word"] = "\n" + timing["word"]
189 | line_len = len(timing["word"].strip())
190 | subtitle.append(timing)
191 | last = timing["start"]
192 | chunk_index += max_words_per_line
193 | if len(subtitle) > 0:
194 | yield subtitle
195 |
196 | if len(result["segments"]) > 0 and "words" in result["segments"][0]:
197 | for subtitle in iterate_subtitles():
198 | subtitle_start = self.format_timestamp(subtitle[0]["start"])
199 | subtitle_end = self.format_timestamp(subtitle[-1]["end"])
200 | subtitle_text = "".join([word["word"] for word in subtitle])
201 | if highlight_words:
202 | last = subtitle_start
203 | all_words = [timing["word"] for timing in subtitle]
204 | for i, this_word in enumerate(subtitle):
205 | start = self.format_timestamp(this_word["start"])
206 | end = self.format_timestamp(this_word["end"])
207 | if last != start:
208 | yield last, start, subtitle_text
209 |
210 | yield start, end, "".join(
211 | [
212 | re.sub(r"^(\s*)(.*)$", r"\1\2", word)
213 | if j == i
214 | else word
215 | for j, word in enumerate(all_words)
216 | ]
217 | )
218 | last = end
219 | else:
220 | yield subtitle_start, subtitle_end, subtitle_text
221 | else:
222 | for segment in result["segments"]:
223 | segment_start = self.format_timestamp(segment["start"])
224 | segment_end = self.format_timestamp(segment["end"])
225 | segment_text = segment["text"].strip().replace("-->", "->")
226 | yield segment_start, segment_end, segment_text
227 |
228 | def format_timestamp(self, seconds: float):
229 | return format_timestamp(
230 | seconds=seconds,
231 | always_include_hours=self.always_include_hours,
232 | decimal_marker=self.decimal_marker,
233 | )
234 |
235 |
236 | class WriteVTT(SubtitlesWriter):
237 | extension: str = "vtt"
238 | always_include_hours: bool = False
239 | decimal_marker: str = "."
240 |
241 | def write_result(
242 | self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
243 | ):
244 | print("WEBVTT\n", file=file)
245 | for start, end, text in self.iterate_result(result, options, **kwargs):
246 | print(f"{start} --> {end}\n{text}\n", file=file, flush=True)
247 |
248 |
249 | class WriteSRT(SubtitlesWriter):
250 | extension: str = "srt"
251 | always_include_hours: bool = True
252 | decimal_marker: str = ","
253 |
254 | def write_result(
255 | self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
256 | ):
257 | for i, (start, end, text) in enumerate(
258 | self.iterate_result(result, options, **kwargs), start=1
259 | ):
260 | print(f"{i}\n{start} --> {end}\n{text}\n", file=file, flush=True)
261 |
262 |
263 | class WriteTSV(ResultWriter):
264 | """
265 | Write a transcript to a file in TSV (tab-separated values) format containing lines like:
266 |     <start time in integer milliseconds>\t<end time in integer milliseconds>\t<transcript text>
267 |
268 | Using integer milliseconds as start and end times means there's no chance of interference from
269 | an environment setting a language encoding that causes the decimal in a floating point number
270 |     to appear as a comma; it is also faster and more efficient to parse & store, e.g., in C++.
271 | """
272 |
273 | extension: str = "tsv"
274 |
275 | def write_result(
276 | self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
277 | ):
278 | print("start", "end", "text", sep="\t", file=file)
279 | for segment in result["segments"]:
280 | print(round(1000 * segment["start"]), file=file, end="\t")
281 | print(round(1000 * segment["end"]), file=file, end="\t")
282 | print(segment["text"].strip().replace("\t", " "), file=file, flush=True)
283 |
284 |
285 | class WriteJSON(ResultWriter):
286 | extension: str = "json"
287 |
288 | def write_result(
289 | self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
290 | ):
291 | json.dump(result, file)
292 |
293 |
294 | def get_writer(
295 | output_format: str, output_dir: str
296 | ) -> Callable[[dict, TextIO, dict], None]:
297 | writers = {
298 | "txt": WriteTXT,
299 | "vtt": WriteVTT,
300 | "srt": WriteSRT,
301 | "tsv": WriteTSV,
302 | "json": WriteJSON,
303 | }
304 |
305 | if output_format == "all":
306 | all_writers = [writer(output_dir) for writer in writers.values()]
307 |
308 | def write_all(
309 | result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
310 | ):
311 | for writer in all_writers:
312 | writer(result, file, options, **kwargs)
313 |
314 | return write_all
315 |
316 | return writers[output_format](output_dir)
317 |
--------------------------------------------------------------------------------
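
A minimal sketch of driving the subtitle writers above; the module path in the import is an assumption (these writers appear to live in careless_whisper_stream/utils.py), and the result dict follows the shape iterate_result() expects:

    import sys
    from careless_whisper_stream.utils import get_writer  # assumed location of this file

    result = {"segments": [{"start": 0.0, "end": 2.5,
                            "text": " Ask not what your country can do for you."}]}
    writer = get_writer("srt", output_dir=".")
    writer.write_result(result, file=sys.stdout)
    # 1
    # 00:00:00,000 --> 00:00:02,500
    # Ask not what your country can do for you.
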
/careless_whisper_stream/model.py:
--------------------------------------------------------------------------------
1 | import base64
2 | import gzip
3 | from dataclasses import dataclass
4 | from typing import Dict, Iterable, Optional, Tuple
5 |
6 | import numpy as np
7 | import torch
8 | import torch.nn.functional as F
9 | from torch import Tensor, nn
10 |
11 | from .decoding import decode as decode_function
12 | from .decoding import detect_language as detect_language_function
13 | from .transcribe import transcribe as transcribe_function
14 |
15 |
16 | @dataclass
17 | class ModelDimensions:
18 | n_mels: int
19 | n_audio_ctx: int
20 | n_audio_state: int
21 | n_audio_head: int
22 | n_audio_layer: int
23 | n_vocab: int
24 | n_text_ctx: int
25 | n_text_state: int
26 | n_text_head: int
27 | n_text_layer: int
28 |
29 |
30 | class LayerNorm(nn.LayerNorm):
31 | def forward(self, x: Tensor) -> Tensor:
32 | return super().forward(x.float()).type(x.dtype)
33 |
34 |
35 | class Linear(nn.Linear):
36 | def forward(self, x: Tensor) -> Tensor:
37 | return F.linear(
38 | x,
39 | self.weight.to(x.dtype),
40 | None if self.bias is None else self.bias.to(x.dtype),
41 | )
42 |
43 |
44 | class Conv1d(nn.Conv1d):
45 | def _conv_forward(
46 | self, x: Tensor, weight: Tensor, bias: Optional[Tensor]
47 | ) -> Tensor:
48 | return super()._conv_forward(
49 | x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype)
50 | )
51 |
52 |
53 | def sinusoids(length, channels, max_timescale=10000):
54 | """Returns sinusoids for positional embedding"""
55 | assert channels % 2 == 0
56 | log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
57 | inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
58 | scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
59 | return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
60 |
61 |
62 | class MultiHeadAttention(nn.Module):
63 | def __init__(self, n_state: int, n_head: int):
64 | super().__init__()
65 | self.n_head = n_head
66 | self.query = Linear(n_state, n_state)
67 | self.key = Linear(n_state, n_state, bias=False)
68 | self.value = Linear(n_state, n_state)
69 | self.out = Linear(n_state, n_state)
70 |
71 | def forward(
72 | self,
73 | x: Tensor,
74 | xa: Optional[Tensor] = None,
75 | mask: Optional[Tensor] = None,
76 |         kv_cache: Optional[Dict[nn.Module, Tensor]] = None,
77 | ):
78 | q = self.query(x)
79 |
80 | if kv_cache is None or xa is None or self.key not in kv_cache:
81 | # hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
82 | # otherwise, perform key/value projections for self- or cross-attention as usual.
83 | k = self.key(x if xa is None else xa)
84 | v = self.value(x if xa is None else xa)
85 | else:
86 | # for cross-attention, calculate keys and values once and reuse in subsequent calls.
87 | k = kv_cache[self.key]
88 | v = kv_cache[self.value]
89 |
90 | wv, qk = self.qkv_attention(q, k, v, mask)
91 |
92 | return self.out(wv), qk
93 |
94 | def qkv_attention(
95 | self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
96 | ):
97 | # print(f"q shape: {q.shape}")
98 | n_batch, n_ctx, n_state = q.shape
99 | scale = (n_state // self.n_head) ** -0.25
100 | q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
101 | k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale
102 | v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
103 |
104 | qk = q @ k
105 | if mask is not None:
106 | qk = qk + mask[:n_ctx, :n_ctx]
107 | qk = qk.float()
108 |
109 | w = F.softmax(qk, dim=-1).to(q.dtype)
110 | return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()
111 |
112 |
113 | class ResidualAttentionBlock(nn.Module):
114 | def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
115 | super().__init__()
116 |
117 | self.attn = MultiHeadAttention(n_state, n_head)
118 | self.attn_ln = LayerNorm(n_state)
119 |
120 | self.cross_attn = (
121 | MultiHeadAttention(n_state, n_head) if cross_attention else None
122 | )
123 | self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None
124 |
125 | n_mlp = n_state * 4
126 | self.mlp = nn.Sequential(
127 | Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state)
128 | )
129 | self.mlp_ln = LayerNorm(n_state)
130 |
131 | def forward(
132 | self,
133 | x: Tensor,
134 | xa: Optional[Tensor] = None,
135 | mask: Optional[Tensor] = None,
136 | kv_cache: Optional[dict] = None,
137 | ):
138 | # SA
139 | x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0]
140 |
141 | # CA
142 | if self.cross_attn:
143 | x = x + self.cross_attn(self.cross_attn_ln(x), xa)[0]
144 |
145 | # MLP
146 | x = x + self.mlp(self.mlp_ln(x))
147 |
148 | return x
149 |
150 |
151 | class AudioEncoder(nn.Module):
152 | def __init__(
153 | self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
154 | ):
155 | super().__init__()
156 | self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
157 | self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
158 | self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))
159 |
160 | self.n_head = n_head
161 | self.n_layer = n_layer
162 | self.n_state = n_state
163 |
164 | self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
165 | [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
166 | )
167 | self.ln_post = LayerNorm(n_state)
168 |
169 | def forward(self, x: Tensor):
170 | """
171 | x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
172 | the mel spectrogram of the audio
173 | """
174 | x = F.gelu(self.conv1(x))
175 | x = F.gelu(self.conv2(x))
176 | x = x.permute(0, 2, 1)
177 |
178 | assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
179 | x = (x + self.positional_embedding).to(x.dtype)
180 |
181 | for block in self.blocks:
182 | x = block(x)
183 |
184 | x = self.ln_post(x)
185 | return x
186 |
187 |
188 | class TextDecoder(nn.Module):
189 | def __init__(
190 | self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
191 | ):
192 | super().__init__()
193 |
194 | self.token_embedding = nn.Embedding(n_vocab, n_state)
195 | self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))
196 |
197 | self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
198 | [
199 | ResidualAttentionBlock(n_state, n_head, cross_attention=True)
200 | for _ in range(n_layer)
201 | ]
202 | )
203 | self.ln = LayerNorm(n_state)
204 |
205 | mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
206 | self.register_buffer("mask", mask, persistent=False)
207 |
208 | def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
209 | """
210 | x : torch.LongTensor, shape = (batch_size, <= n_ctx)
211 | the text tokens
212 | xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)
213 | the encoded audio features to be attended on
214 |         kv_cache : Optional[dict], cached key/value projections from previous decoding steps (see install_kv_cache_hooks)
215 | """
216 | offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
217 | x = (
218 | self.token_embedding(x)
219 | + self.positional_embedding[offset : offset + x.shape[-1]]
220 | )
221 | x = x.to(xa.dtype)
222 |
223 | for block in self.blocks:
224 | x = block(x, xa, mask=self.mask, kv_cache=kv_cache)
225 |
226 | x = self.ln(x)
227 | logits = (
228 | x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
229 | ).float()
230 |
231 | return logits
232 |
233 |
234 | class Whisper(nn.Module):
235 | def __init__(self, dims: ModelDimensions):
236 | super().__init__()
237 | self.dims = dims
238 | self.encoder = AudioEncoder(
239 | self.dims.n_mels,
240 | self.dims.n_audio_ctx,
241 | self.dims.n_audio_state,
242 | self.dims.n_audio_head,
243 | self.dims.n_audio_layer,
244 | )
245 | self.decoder = TextDecoder(
246 | self.dims.n_vocab,
247 | self.dims.n_text_ctx,
248 | self.dims.n_text_state,
249 | self.dims.n_text_head,
250 | self.dims.n_text_layer,
251 | )
252 | # use the last half among the decoder layers for time alignment by default;
253 | # to use a specific set of heads, see `set_alignment_heads()` below.
254 | all_heads = torch.zeros(
255 | self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool
256 | )
257 | all_heads[self.dims.n_text_layer // 2 :] = True
258 | # self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False)
259 |         self.register_buffer("alignment_heads", all_heads, persistent=False)  # kept dense: sparse buffers are not usable when training with Lightning
260 |
261 | def set_alignment_heads(self, dump: bytes):
262 | array = np.frombuffer(
263 | gzip.decompress(base64.b85decode(dump)), dtype=bool
264 | ).copy()
265 | mask = torch.from_numpy(array).reshape(
266 | self.dims.n_text_layer, self.dims.n_text_head
267 | )
268 | # self.register_buffer("alignment_heads", mask.to_sparse(), persistent=False)
269 |         self.register_buffer("alignment_heads", mask, persistent=False)  # kept dense: sparse buffers are not usable when training with Lightning
270 |
271 | def embed_audio(self, mel: torch.Tensor):
272 | return self.encoder(mel)
273 |
274 | def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor):
275 | return self.decoder(tokens, audio_features)
276 |
277 | def forward(
278 | self, mel: torch.Tensor, tokens: torch.Tensor
279 | ) -> Dict[str, torch.Tensor]:
280 | return self.decoder(tokens, self.encoder(mel))
281 |
282 | @property
283 | def device(self):
284 | return next(self.parameters()).device
285 |
286 | @property
287 | def is_multilingual(self):
288 | return self.dims.n_vocab >= 51865
289 |
290 | @property
291 | def num_languages(self):
292 | return self.dims.n_vocab - 51765 - int(self.is_multilingual)
293 |
294 | def install_kv_cache_hooks(self, cache: Optional[dict] = None):
295 | """
296 | The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value
297 | tensors calculated for the previous positions. This method returns a dictionary that stores
298 | all caches, and the necessary hooks for the key and value projection modules that save the
299 | intermediate tensors to be reused during later calculations.
300 |
301 | Returns
302 | -------
303 | cache : Dict[nn.Module, torch.Tensor]
304 |             A dictionary object mapping the key/value projection modules to their cached tensors
305 |         hooks : List[RemovableHandle]
306 |             List of PyTorch RemovableHandle objects that can be used to remove the hooks
307 | """
308 | cache = {**cache} if cache is not None else {}
309 | hooks = []
310 |
311 | def save_to_cache(module, _, output):
312 | if module not in cache or output.shape[1] > self.dims.n_text_ctx:
313 | # save as-is, for the first token or cross attention
314 | cache[module] = output
315 | else:
316 | cache[module] = torch.cat([cache[module], output], dim=1).detach()
317 | return cache[module]
318 |
319 | def install_hooks(layer: nn.Module):
320 | if isinstance(layer, MultiHeadAttention):
321 | hooks.append(layer.key.register_forward_hook(save_to_cache))
322 | hooks.append(layer.value.register_forward_hook(save_to_cache))
323 |
324 | self.decoder.apply(install_hooks)
325 | return cache, hooks
326 |
327 | detect_language = detect_language_function
328 | transcribe = transcribe_function
329 | decode = decode_function
330 |
--------------------------------------------------------------------------------
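
A short sketch of how the KV-cache hooks above are meant to be used for incremental decoding. The dimensions match the "tiny" configuration and the start-of-transcript token id is a placeholder; both are illustrative assumptions rather than values read from a checkpoint:

    import torch
    from careless_whisper_stream.model import ModelDimensions, Whisper

    dims = ModelDimensions(n_mels=80, n_audio_ctx=1500, n_audio_state=384, n_audio_head=6,
                           n_audio_layer=4, n_vocab=51865, n_text_ctx=448, n_text_state=384,
                           n_text_head=6, n_text_layer=4)
    model = Whisper(dims).eval()

    with torch.no_grad():
        mel = torch.randn(1, dims.n_mels, 3000)             # 30 s worth of log-mel frames
        audio_features = model.embed_audio(mel)

        cache, hooks = model.install_kv_cache_hooks()
        tokens = torch.tensor([[50258]])                    # placeholder start-of-transcript id
        logits = model.decoder(tokens, audio_features, kv_cache=cache)      # first step fills the cache
        next_token = logits[:, -1].argmax(dim=-1, keepdim=True)
        logits = model.decoder(next_token, audio_features, kv_cache=cache)  # later steps reuse cached k/v
    for hook in hooks:
        hook.remove()
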
/careless_whisper_stream/__init__.py:
--------------------------------------------------------------------------------
1 | import hashlib
2 | import io
3 | import os
4 | import urllib
5 | import warnings
6 | from typing import List, Optional, Union
7 |
8 | import torch
9 | from tqdm import tqdm
10 |
11 | from .audio import load_audio, pad_or_trim, log_mel_spectrogram
12 | from .model import ModelDimensions, Whisper
13 | from .streaming_model import StreamingWhisper
14 | from .version import __version__
15 |
16 | _MODELS = {
17 | "tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
18 | "tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",
19 | "base.en": "https://openaipublic.azureedge.net/main/whisper/models/25a8566e1d0c1e2231d1c762132cd20e0f96a85d16145c3a00adf5d1ac670ead/base.en.pt",
20 | "base": "https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt",
21 | "small.en": "https://openaipublic.azureedge.net/main/whisper/models/f953ad0fd29cacd07d5a9eda5624af0f6bcf2258be67c92b79389873d91e0872/small.en.pt",
22 | "small": "https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt",
23 | "medium.en": "https://openaipublic.azureedge.net/main/whisper/models/d7440d1dc186f76616474e0ff0b3b6b879abc9d1a4926b7adfa41db2d497ab4f/medium.en.pt",
24 | "medium": "https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt",
25 | "large-v1": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large-v1.pt",
26 | "large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
27 | "large-v3": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
28 | "large": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
29 | "large-v3-turbo": "https://openaipublic.azureedge.net/main/whisper/models/aff26ae408abcba5fbf8813c21e62b0941638c5f6eebfb145be0c9839262a19a/large-v3-turbo.pt",
30 | "turbo": "https://openaipublic.azureedge.net/main/whisper/models/aff26ae408abcba5fbf8813c21e62b0941638c5f6eebfb145be0c9839262a19a/large-v3-turbo.pt",
31 | }
32 |
33 | _STREAMING_MODELS_HF = {
34 | "base": {
35 | "300": "base_300.pt",
36 | "200": "base_200.pt",
37 | "100": "base_100.pt",
38 | "40": "base_40.pt",
39 | },
40 | "small": {
41 | "1000": "small_1000.pt",
42 | "300": "small_300.pt",
43 | "200": "small_200.pt",
44 | "100": "small_100.pt",
45 | "40": "small_40.pt",
46 | },
47 | "large-v2": {
48 | "1000": "large-v2_1000.pt",
49 | "300": "large-v2_300.pt",
50 | "200": "large-v2_200.pt",
51 | "100": "large-v2_100.pt",
52 | "40": "large-v2_40.pt",
53 | "300-multi": "large-v2_300_multi.pt",
54 | }
55 | }
56 |
57 | # base85-encoded (n_layers, n_heads) boolean arrays indicating the cross-attention heads that are
58 | # highly correlated to the word-level timing, i.e. the alignment between audio and text tokens.
59 | _ALIGNMENT_HEADS = {
60 | "tiny.en": b"ABzY8J1N>@0{>%R00Bk>$p{7v037`oCl~+#00",
61 | "tiny": b"ABzY8bu8Lr0{>%RKn9Fp%m@SkK7Kt=7ytkO",
62 | "base.en": b"ABzY8;40c<0{>%RzzG;p*o+Vo09|#PsxSZm00",
63 | "base": b"ABzY8KQ!870{>%RzyTQH3`Q^yNP!>##QT-?_)10{>%RpeA61k&I|OI3I$65C{;;pbCHh0B{qLQ;+}v00",
65 | "small": b"ABzY8DmU6=0{>%Rpa?J`kvJ6qF(V^F86#Xh7JUGMK}P%R7%R7}kK1fFL7w6%<-Pf*t^=N)Qr&0RR9",
68 | "large-v1": b"ABzY8r9j$a0{>%R7#4sLmoOs{s)o3~84-RPdcFk!JR%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj",
70 | "large-v3": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
71 | "large": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
72 | "large-v3-turbo": b"ABzY8j^C+e0{>%RARaKHP%t(lGR*)0g!tONPyhe`",
73 | "turbo": b"ABzY8j^C+e0{>%RARaKHP%t(lGR*)0g!tONPyhe`",
74 | }
75 |
76 |
77 | def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
78 | os.makedirs(root, exist_ok=True)
79 |
80 | expected_sha256 = url.split("/")[-2]
81 | download_target = os.path.join(root, os.path.basename(url))
82 |
83 | if os.path.exists(download_target) and not os.path.isfile(download_target):
84 | raise RuntimeError(f"{download_target} exists and is not a regular file")
85 |
86 | if os.path.isfile(download_target):
87 | with open(download_target, "rb") as f:
88 | model_bytes = f.read()
89 | if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
90 | return model_bytes if in_memory else download_target
91 | else:
92 | warnings.warn(
93 | f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
94 | )
95 |
96 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
97 | with tqdm(
98 | total=int(source.info().get("Content-Length")),
99 | ncols=80,
100 | unit="iB",
101 | unit_scale=True,
102 | unit_divisor=1024,
103 | ) as loop:
104 | while True:
105 | buffer = source.read(8192)
106 | if not buffer:
107 | break
108 |
109 | output.write(buffer)
110 | loop.update(len(buffer))
111 |
112 | model_bytes = open(download_target, "rb").read()
113 | if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
114 | raise RuntimeError(
115 |             "Model has been downloaded but the SHA256 checksum does not match. Please retry loading the model."
116 | )
117 |
118 | return model_bytes if in_memory else download_target
119 |
120 |
121 | def available_models() -> List[str]:
122 | """Returns the names of available models"""
123 | return list(_MODELS.keys())
124 |
125 |
126 | def load_model(
127 | name: str,
128 | device: Optional[Union[str, torch.device]] = None,
129 | download_root: str = None,
130 | in_memory: bool = False,
131 | ) -> Whisper:
132 | """
133 | Load a Whisper ASR model
134 |
135 | Parameters
136 | ----------
137 | name : str
138 | one of the official model names listed by `whisper.available_models()`, or
139 | path to a model checkpoint containing the model dimensions and the model state_dict.
140 | device : Union[str, torch.device]
141 | the PyTorch device to put the model into
142 | download_root: str
143 | path to download the model files; by default, it uses "~/.cache/whisper"
144 | in_memory: bool
145 | whether to preload the model weights into host memory
146 |
147 | Returns
148 | -------
149 | model : Whisper
150 | The Whisper ASR model instance
151 | """
152 |
153 | if device is None:
154 | device = "cuda" if torch.cuda.is_available() else "cpu"
155 | if download_root is None:
156 | default = os.path.join(os.path.expanduser("~"), ".cache")
157 | download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper")
158 |
159 | if name in _MODELS:
160 | checkpoint_file = _download(_MODELS[name], download_root, in_memory)
161 | alignment_heads = _ALIGNMENT_HEADS[name]
162 | elif os.path.isfile(name):
163 | checkpoint_file = open(name, "rb").read() if in_memory else name
164 | alignment_heads = None
165 | else:
166 | raise RuntimeError(
167 | f"Model {name} not found; available models = {available_models()}"
168 | )
169 |
170 | with (
171 | io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
172 | ) as fp:
173 | checkpoint = torch.load(fp, map_location=device)
174 | del checkpoint_file
175 |
176 | dims = ModelDimensions(**checkpoint["dims"])
177 | model = Whisper(dims)
178 | model.load_state_dict(checkpoint["model_state_dict"])
179 |
180 | if alignment_heads is not None:
181 | model.set_alignment_heads(alignment_heads)
182 |
183 | return model.to(device)
184 |
185 |
186 | def load_streaming_model_for_train(
187 | name: str,
188 | advisor_ckpt_path: str = None,
189 | ft_model_ckpt_path: str = None,
190 | device: Optional[Union[str, torch.device]] = None,
191 | download_root: str = None,
192 | in_memory: bool = False,
193 | cache_gran: bool = True,
194 | gran: int = 15,
195 | rank: int = 8,
196 | extra_gran_blocks: int = 0,
197 | n_advisor_class: int = 4,
198 |     **kwargs
199 | ) -> StreamingWhisper:
200 | """
201 | Load a StreamingWhisper ASR model
202 |
203 | Parameters
204 | ----------
205 | name : str
206 | one of the official model names listed by `whisper.available_models()`, or
207 | path to a model checkpoint containing the model dimensions and the model state_dict.
208 | device : Union[str, torch.device]
209 | the PyTorch device to put the model into
210 | download_root: str
211 | path to download the model files; by default, it uses "~/.cache/whisper"
212 | in_memory: bool
213 | whether to preload the model weights into host memory
214 |
215 | Returns
216 | -------
217 |     model : StreamingWhisper
218 |         The StreamingWhisper ASR model instance
219 | """
220 | if ft_model_ckpt_path is None:
221 | if device is None:
222 | device = "cuda" if torch.cuda.is_available() else "cpu"
223 | if download_root is None:
224 | default = os.path.join(os.path.expanduser("~"), ".cache")
225 | download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper")
226 |
227 | if name in _MODELS:
228 | checkpoint_file = _download(_MODELS[name], download_root, in_memory)
229 | alignment_heads = _ALIGNMENT_HEADS[name]
230 | elif os.path.isfile(name):
231 | checkpoint_file = open(name, "rb").read() if in_memory else name
232 | alignment_heads = None
233 | else:
234 | raise RuntimeError(
235 | f"Model {name} not found; available models = {available_models()}"
236 | )
237 |
238 | with (
239 | io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
240 | ) as fp:
241 | checkpoint = torch.load(fp, map_location=device)
242 | del checkpoint_file
243 | else:
244 | checkpoint = torch.load(ft_model_ckpt_path, weights_only=False)
245 |
246 | decoder_advisor_chkpt = torch.load(advisor_ckpt_path, weights_only=False) if advisor_ckpt_path is not None else {"state_dict": {}}
247 | advisor_state_dict = {k: v for k, v in decoder_advisor_chkpt["state_dict"].items() if "decoder_advisor" in k}
248 |
249 | whisper_dict = checkpoint["model_state_dict"] if "model_state_dict" in checkpoint.keys() else checkpoint["state_dict"]
250 |
251 | whisper_dict = {k.replace("weight", "base_layer.weight") if "attn." in k and "weight" in k
252 | else k.replace("bias", "base_layer.bias") if "attn." in k and "bias" in k
253 |                     else k: v for k, v in whisper_dict.items()}  # move pretrained attention weights under "base_layer" to match the wrapped attention modules in StreamingWhisper
254 |
255 | streaming_whisper_state_dict = {**advisor_state_dict, **whisper_dict}
256 |
257 | dims = ModelDimensions(**checkpoint["dims"])
258 |
259 | model = StreamingWhisper(dims,
260 | cache_gran=cache_gran,
261 | gran=gran,
262 | rank=rank,
263 | extra_gran_blocks=extra_gran_blocks)
264 |
265 | model.load_state_dict(streaming_whisper_state_dict, strict=False)
266 |
267 | # for n, p in model.named_parameters():
268 | # print(n, p)
269 |
270 | if ft_model_ckpt_path is None and alignment_heads is not None:
271 | model.set_alignment_heads(alignment_heads)
272 |
273 | return model.to(device)
274 |
275 |
276 | def load_streaming_model(
277 | name: str,
278 | gran: int = 300,
279 | multilingual: bool = False,
280 | device: Optional[Union[str, torch.device]] = None,
281 | download_root: str = None,
282 | in_memory: bool = False,
283 | ) -> StreamingWhisper:
284 |
285 | subname = (str(gran) + '-multi') if multilingual else str(gran)
286 |
287 | from huggingface_hub import hf_hub_download
288 |
289 |     try:
290 |         ckpt_path = hf_hub_download(repo_id="MLSpeech/CarelessWhisper-Streaming", filename=_STREAMING_MODELS_HF[name][subname], repo_type="model", token=True)
291 |     except KeyError as e:
292 |         raise RuntimeError(f"No streaming checkpoint is available for size {name}, multilingual={multilingual}, chunk size {gran} ms.") from e
293 |
294 | checkpoint = torch.load(ckpt_path, weights_only=False)
295 |
296 | dims = ModelDimensions(**checkpoint["dims"])
297 |
298 | model = StreamingWhisper(dims,
299 | gran=checkpoint['cfg']['gran'],
300 | rank=checkpoint['cfg']['rank'],
301 | extra_gran_blocks=checkpoint['cfg']['extra_gran_blocks'])
302 |
303 | model.load_state_dict(checkpoint['state_dict'], strict=False)
304 |
305 | return model.to(device)
306 |
--------------------------------------------------------------------------------
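
A small, self-contained illustration of the key renaming performed in load_streaming_model_for_train(): pretrained attention projection weights are moved under a "base_layer" prefix so that they line up with the wrapped attention Linear modules of StreamingWhisper (the rank argument suggests LoRA-style adapters, though that reading is an assumption). The keys and tensors below are made up for the example:

    import torch

    whisper_dict = {
        "decoder.blocks.0.attn.query.weight": torch.zeros(2, 2),
        "decoder.blocks.0.attn.query.bias": torch.zeros(2),
        "decoder.blocks.0.mlp.0.weight": torch.zeros(2, 2),   # non-attention keys pass through unchanged
    }
    renamed = {k.replace("weight", "base_layer.weight") if "attn." in k and "weight" in k
               else k.replace("bias", "base_layer.bias") if "attn." in k and "bias" in k
               else k: v for k, v in whisper_dict.items()}
    print(sorted(renamed))
    # ['decoder.blocks.0.attn.query.base_layer.bias',
    #  'decoder.blocks.0.attn.query.base_layer.weight',
    #  'decoder.blocks.0.mlp.0.weight']
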
/careless_whisper_stream/tokenizer.py:
--------------------------------------------------------------------------------
1 | import base64
2 | import os
3 | import string
4 | from dataclasses import dataclass, field
5 | from functools import cached_property, lru_cache
6 | from typing import Dict, List, Optional, Tuple
7 |
8 | import tiktoken
9 |
10 | LANGUAGES = {
11 | "en": "english",
12 | "zh": "chinese",
13 | "de": "german",
14 | "es": "spanish",
15 | "ru": "russian",
16 | "ko": "korean",
17 | "fr": "french",
18 | "ja": "japanese",
19 | "pt": "portuguese",
20 | "tr": "turkish",
21 | "pl": "polish",
22 | "ca": "catalan",
23 | "nl": "dutch",
24 | "ar": "arabic",
25 | "sv": "swedish",
26 | "it": "italian",
27 | "id": "indonesian",
28 | "hi": "hindi",
29 | "fi": "finnish",
30 | "vi": "vietnamese",
31 | "he": "hebrew",
32 | "uk": "ukrainian",
33 | "el": "greek",
34 | "ms": "malay",
35 | "cs": "czech",
36 | "ro": "romanian",
37 | "da": "danish",
38 | "hu": "hungarian",
39 | "ta": "tamil",
40 | "no": "norwegian",
41 | "th": "thai",
42 | "ur": "urdu",
43 | "hr": "croatian",
44 | "bg": "bulgarian",
45 | "lt": "lithuanian",
46 | "la": "latin",
47 | "mi": "maori",
48 | "ml": "malayalam",
49 | "cy": "welsh",
50 | "sk": "slovak",
51 | "te": "telugu",
52 | "fa": "persian",
53 | "lv": "latvian",
54 | "bn": "bengali",
55 | "sr": "serbian",
56 | "az": "azerbaijani",
57 | "sl": "slovenian",
58 | "kn": "kannada",
59 | "et": "estonian",
60 | "mk": "macedonian",
61 | "br": "breton",
62 | "eu": "basque",
63 | "is": "icelandic",
64 | "hy": "armenian",
65 | "ne": "nepali",
66 | "mn": "mongolian",
67 | "bs": "bosnian",
68 | "kk": "kazakh",
69 | "sq": "albanian",
70 | "sw": "swahili",
71 | "gl": "galician",
72 | "mr": "marathi",
73 | "pa": "punjabi",
74 | "si": "sinhala",
75 | "km": "khmer",
76 | "sn": "shona",
77 | "yo": "yoruba",
78 | "so": "somali",
79 | "af": "afrikaans",
80 | "oc": "occitan",
81 | "ka": "georgian",
82 | "be": "belarusian",
83 | "tg": "tajik",
84 | "sd": "sindhi",
85 | "gu": "gujarati",
86 | "am": "amharic",
87 | "yi": "yiddish",
88 | "lo": "lao",
89 | "uz": "uzbek",
90 | "fo": "faroese",
91 | "ht": "haitian creole",
92 | "ps": "pashto",
93 | "tk": "turkmen",
94 | "nn": "nynorsk",
95 | "mt": "maltese",
96 | "sa": "sanskrit",
97 | "lb": "luxembourgish",
98 | "my": "myanmar",
99 | "bo": "tibetan",
100 | "tl": "tagalog",
101 | "mg": "malagasy",
102 | "as": "assamese",
103 | "tt": "tatar",
104 | "haw": "hawaiian",
105 | "ln": "lingala",
106 | "ha": "hausa",
107 | "ba": "bashkir",
108 | "jw": "javanese",
109 | "su": "sundanese",
110 | "yue": "cantonese",
111 | }
112 |
113 | # language code lookup by name, with a few language aliases
114 | TO_LANGUAGE_CODE = {
115 | **{language: code for code, language in LANGUAGES.items()},
116 | "burmese": "my",
117 | "valencian": "ca",
118 | "flemish": "nl",
119 | "haitian": "ht",
120 | "letzeburgesch": "lb",
121 | "pushto": "ps",
122 | "panjabi": "pa",
123 | "moldavian": "ro",
124 | "moldovan": "ro",
125 | "sinhalese": "si",
126 | "castilian": "es",
127 | "mandarin": "zh",
128 | }
129 |
130 |
131 | @dataclass
132 | class Tokenizer:
133 | """A thin wrapper around `tiktoken` providing quick access to special tokens"""
134 |
135 | encoding: tiktoken.Encoding
136 | num_languages: int
137 | language: Optional[str] = None
138 | task: Optional[str] = None
139 | sot_sequence: Tuple[int] = ()
140 | special_tokens: Dict[str, int] = field(default_factory=dict)
141 |
142 | def __post_init__(self):
143 | for special in self.encoding.special_tokens_set:
144 | special_token = self.encoding.encode_single_token(special)
145 | self.special_tokens[special] = special_token
146 |
147 | sot: int = self.special_tokens["<|startoftranscript|>"]
148 | translate: int = self.special_tokens["<|translate|>"]
149 | transcribe: int = self.special_tokens["<|transcribe|>"]
150 |
151 | langs = tuple(LANGUAGES.keys())[: self.num_languages]
152 | sot_sequence = [sot]
153 | if self.language is not None:
154 | sot_sequence.append(sot + 1 + langs.index(self.language))
155 | if self.task is not None:
156 | task_token: int = transcribe if self.task == "transcribe" else translate
157 | sot_sequence.append(task_token)
158 |
159 | self.sot_sequence = tuple(sot_sequence)
160 |
161 | def encode(self, text, **kwargs):
162 | return self.encoding.encode(text, **kwargs)
163 |
164 | def decode(self, token_ids: List[int], **kwargs) -> str:
165 | token_ids = [t for t in token_ids if t < self.timestamp_begin]
166 | return self.encoding.decode(token_ids, **kwargs)
167 |
168 | def decode_with_timestamps(self, token_ids: List[int], **kwargs) -> str:
169 | """
170 | Timestamp tokens are above other special tokens' id range and are ignored by `decode()`.
171 |         This method decodes the given tokens with timestamp tokens annotated, e.g. "<|1.08|>".
172 | """
173 | return self.encoding.decode(token_ids, **kwargs)
174 |
175 | @cached_property
176 | def eot(self) -> int:
177 | return self.encoding.eot_token
178 |
179 | @cached_property
180 | def transcribe(self) -> int:
181 | return self.special_tokens["<|transcribe|>"]
182 |
183 | @cached_property
184 | def translate(self) -> int:
185 | return self.special_tokens["<|translate|>"]
186 |
187 | @cached_property
188 | def sot(self) -> int:
189 | return self.special_tokens["<|startoftranscript|>"]
190 |
191 | @cached_property
192 | def sot_lm(self) -> int:
193 | return self.special_tokens["<|startoflm|>"]
194 |
195 | @cached_property
196 | def sot_prev(self) -> int:
197 | return self.special_tokens["<|startofprev|>"]
198 |
199 | @cached_property
200 | def no_speech(self) -> int:
201 | return self.special_tokens["<|nospeech|>"]
202 |
203 | @cached_property
204 | def no_timestamps(self) -> int:
205 | return self.special_tokens["<|notimestamps|>"]
206 |
207 | @cached_property
208 | def timestamp_begin(self) -> int:
209 | return self.special_tokens["<|0.00|>"]
210 |
211 | @cached_property
212 | def language_token(self) -> int:
213 | """Returns the token id corresponding to the value of the `language` field"""
214 | if self.language is None:
215 | raise ValueError("This tokenizer does not have language token configured")
216 |
217 | return self.to_language_token(self.language)
218 |
219 | def to_language_token(self, language):
220 | if token := self.special_tokens.get(f"<|{language}|>", None):
221 | return token
222 |
223 | raise KeyError(f"Language {language} not found in tokenizer.")
224 |
225 | @cached_property
226 | def all_language_tokens(self) -> Tuple[int]:
227 | result = []
228 | for token, token_id in self.special_tokens.items():
229 | if token.strip("<|>") in LANGUAGES:
230 | result.append(token_id)
231 | return tuple(result)[: self.num_languages]
232 |
233 | @cached_property
234 | def all_language_codes(self) -> Tuple[str]:
235 | return tuple(self.decode([_l]).strip("<|>") for _l in self.all_language_tokens)
236 |
237 | @cached_property
238 | def sot_sequence_including_notimestamps(self) -> Tuple[int]:
239 | return tuple(list(self.sot_sequence) + [self.no_timestamps])
240 |
241 | @cached_property
242 | def non_speech_tokens(self) -> Tuple[int]:
243 | """
244 | Returns the list of tokens to suppress in order to avoid any speaker tags or non-speech
245 | annotations, to prevent sampling texts that are not actually spoken in the audio, e.g.
246 |
247 | - ♪♪♪
248 | - ( SPEAKING FOREIGN LANGUAGE )
249 | - [DAVID] Hey there,
250 |
251 | keeping basic punctuations like commas, periods, question marks, exclamation points, etc.
252 | """
253 | symbols = list('"#()*+/:;<=>@[\\]^_`{|}~「」『』')
254 | symbols += (
255 | "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split()
256 | )
257 |
258 | # symbols that may be a single token or multiple tokens depending on the tokenizer.
259 | # In case they're multiple tokens, suppress the first token, which is safe because:
260 | # These are between U+2640 and U+267F miscellaneous symbols that are okay to suppress
261 | # in generations, and in the 3-byte UTF-8 representation they share the first two bytes.
262 | miscellaneous = set("♩♪♫♬♭♮♯")
263 | assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous)
264 |
265 | # allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
266 | result = {self.encoding.encode(" -")[0], self.encoding.encode(" '")[0]}
267 | for symbol in symbols + list(miscellaneous):
268 | for tokens in [
269 | self.encoding.encode(symbol),
270 | self.encoding.encode(" " + symbol),
271 | ]:
272 | if len(tokens) == 1 or symbol in miscellaneous:
273 | result.add(tokens[0])
274 |
275 | return tuple(sorted(result))
276 |
277 | def split_to_word_tokens(self, tokens: List[int]):
278 | if self.language in {"zh", "ja", "th", "lo", "my", "yue"}:
279 | # These languages don't typically use spaces, so it is difficult to split words
280 | # without morpheme analysis. Here, we instead split words at any
281 | # position where the tokens are decoded as valid unicode points
282 | return self.split_tokens_on_unicode(tokens)
283 |
284 | return self.split_tokens_on_spaces(tokens)
285 |
286 | def split_tokens_on_unicode(self, tokens: List[int]):
287 | decoded_full = self.decode_with_timestamps(tokens)
288 | replacement_char = "\ufffd"
289 |
290 | words = []
291 | word_tokens = []
292 | current_tokens = []
293 | unicode_offset = 0
294 |
295 | for token in tokens:
296 | current_tokens.append(token)
297 | decoded = self.decode_with_timestamps(current_tokens)
298 |
299 | if (
300 | replacement_char not in decoded
301 | or decoded_full[unicode_offset + decoded.index(replacement_char)]
302 | == replacement_char
303 | ):
304 | words.append(decoded)
305 | word_tokens.append(current_tokens)
306 | current_tokens = []
307 | unicode_offset += len(decoded)
308 |
309 | return words, word_tokens
310 |
311 | def split_tokens_on_spaces(self, tokens: List[int]):
312 | subwords, subword_tokens_list = self.split_tokens_on_unicode(tokens)
313 | words = []
314 | word_tokens = []
315 |
316 | for subword, subword_tokens in zip(subwords, subword_tokens_list):
317 | special = subword_tokens[0] >= self.eot
318 | with_space = subword.startswith(" ")
319 | punctuation = subword.strip() in string.punctuation
320 | if special or with_space or punctuation or len(words) == 0:
321 | words.append(subword)
322 | word_tokens.append(subword_tokens)
323 | else:
324 | words[-1] = words[-1] + subword
325 | word_tokens[-1].extend(subword_tokens)
326 |
327 | return words, word_tokens
328 |
329 |
330 | @lru_cache(maxsize=None)
331 | def get_encoding(name: str = "gpt2", num_languages: int = 99):
332 | vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{name}.tiktoken")
333 | ranks = {
334 | base64.b64decode(token): int(rank)
335 | for token, rank in (line.split() for line in open(vocab_path) if line)
336 | }
337 | n_vocab = len(ranks)
338 | special_tokens = {}
339 |
340 | specials = [
341 | "<|endoftext|>",
342 | "<|startoftranscript|>",
343 | *[f"<|{lang}|>" for lang in list(LANGUAGES.keys())[:num_languages]],
344 | "<|translate|>",
345 | "<|transcribe|>",
346 | "<|startoflm|>",
347 | "<|startofprev|>",
348 | "<|nospeech|>",
349 | "<|notimestamps|>",
350 | *[f"<|{i * 0.02:.2f}|>" for i in range(1501)],
351 | ]
352 |
353 | for token in specials:
354 | special_tokens[token] = n_vocab
355 | n_vocab += 1
356 |
357 | return tiktoken.Encoding(
358 | name=os.path.basename(vocab_path),
359 | explicit_n_vocab=n_vocab,
360 | pat_str=r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
361 | mergeable_ranks=ranks,
362 | special_tokens=special_tokens,
363 | )
364 |
365 |
366 | @lru_cache(maxsize=None)
367 | def get_tokenizer(
368 | multilingual: bool,
369 | *,
370 | num_languages: int = 99,
371 | language: Optional[str] = None,
372 | task: Optional[str] = None, # Literal["transcribe", "translate", None]
373 | ) -> Tokenizer:
374 | if language is not None:
375 | language = language.lower()
376 | if language not in LANGUAGES:
377 | if language in TO_LANGUAGE_CODE:
378 | language = TO_LANGUAGE_CODE[language]
379 | else:
380 | raise ValueError(f"Unsupported language: {language}")
381 |
382 | if multilingual:
383 | encoding_name = "multilingual"
384 | language = language or "en"
385 | task = task or "transcribe"
386 | else:
387 | encoding_name = "gpt2"
388 | language = None
389 | task = None
390 |
391 | encoding = get_encoding(name=encoding_name, num_languages=num_languages)
392 |
393 | return Tokenizer(
394 | encoding=encoding, num_languages=num_languages, language=language, task=task
395 | )
396 |
--------------------------------------------------------------------------------
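
A brief sketch of the tokenizer API above, assuming the vocabulary files (e.g. multilingual.tiktoken) are available under careless_whisper_stream/assets as in upstream Whisper; the exact token ids are not asserted here:

    from careless_whisper_stream.tokenizer import get_tokenizer

    tokenizer = get_tokenizer(multilingual=True, language="en", task="transcribe")
    ids = tokenizer.encode(" hello world")
    print(ids, tokenizer.decode(ids))
    print(tokenizer.sot_sequence)   # ids of <|startoftranscript|>, <|en|>, <|transcribe|>
    print(tokenizer.decode_with_timestamps([tokenizer.timestamp_begin + 54]))   # "<|1.08|>"
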
/careless_whisper_stream/audio.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | from functools import lru_cache
4 | from subprocess import CalledProcessError, run
5 | from typing import Optional, Union
6 |
7 | import wave
8 | import torch
9 | import pyaudio
10 | import numpy as np
11 | import soundfile as sf
12 | import torch.nn.functional as F
13 |
14 | from .utils import exact_div
15 |
16 | # hard-coded audio hyperparameters
17 | SAMPLE_RATE = 16000
18 | N_FFT = 400
19 | HOP_LENGTH = 160
20 | CHUNK_LENGTH = 30
21 | N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000 samples in a 30-second chunk
22 | N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH) # 3000 frames in a mel spectrogram input
23 |
24 | N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2  # the initial convolutions have an overall stride of 2
25 | FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH) # 10ms per audio frame
26 | TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN) # 20ms per audio token
27 |
28 |
29 | class MyStream:
30 | def __init__(self,
31 | ms_gran: int = 200,
32 | sample_rate: int = 16000,
33 | channels: int = 2,
34 |                  filename: Optional[str] = None,
35 |                  inp_dtype: int = pyaudio.paInt16,
36 |                  simulate_stream: bool = False,
37 |                  wav_file: Optional[str] = None,
38 | relay: bool = False,
39 | use_latency: bool = False,
40 | pad_trim: bool = True,
41 | use_remote_machine: bool = False):
42 |
43 | assert ms_gran % 20 == 0, "ms_gran must be a multiple of 20"
44 |
45 | self.ms_gran = ms_gran
46 | self.sample_rate = sample_rate
47 | self.channels = channels
48 | self.inp_dtype = inp_dtype
49 | self.relay = relay
50 | self.use_latency = use_latency
51 | self.use_remote_machine = use_remote_machine
52 |
53 | rate_fraction = ms_gran / 1000
54 | self.chunk_size = int(rate_fraction * sample_rate)
55 | self.filename = filename
56 | self.streamed_wav_file = wav_file
57 |
58 | self.simulate_stream = simulate_stream
59 | if self.simulate_stream:
60 |             assert wav_file is not None, "when simulating a stream, a wav file must be provided."
61 | if pad_trim:
62 | self.wav_array = pad_or_trim(load_audio(wav_file, sample_rate), length=N_SAMPLES+180) # wav array
63 | else:
64 | audio = load_audio(wav_file, sample_rate)
65 | self.wav_array = pad_or_trim(audio, length=audio.shape[-1]+180)
66 | print(f"{self.wav_array.shape=}")
67 |
68 | def _simulate_stream_using_wav(self):
69 | print("Streaming simulation of a wav started...")
70 |
71 | for i in range(self.wav_array.shape[-1] // self.chunk_size):
72 | if i == 0:
73 | yield self.wav_array[..., :(((i + 1) * self.chunk_size) + 40 + 320)] # 320 is extra 20 msec buffer we need!
74 | else:
75 | yield self.wav_array[..., ((i * self.chunk_size) + 40 + 320):(((i + 1) * self.chunk_size) + 40 + 320)]
76 |
77 | if self.use_latency: time.sleep(self.ms_gran / 1000) # simulating the latency between audio chunks
78 |
79 | def open_stream(self):
80 | if self.simulate_stream or self.relay or self.use_remote_machine: return
81 |
82 | self.audio = pyaudio.PyAudio()
83 | self.stream = self.audio.open(input=True, format=self.inp_dtype, channels=self.channels, rate=self.sample_rate, frames_per_buffer=self.chunk_size)
84 |
85 | def _read_from_stream(self):
86 | print("Streaming instance recording started...")
87 |
88 | while True:
89 | yield self.stream.read(self.chunk_size)
90 |
91 | def _follow_growing_wav(self):
92 | while not os.path.exists(self.streamed_wav_file):
93 | time.sleep(0.1)
94 |
95 | with sf.SoundFile(self.streamed_wav_file, mode='r') as f:
96 | while True:
97 | block = f.read(self.chunk_size)
98 | if len(block) == 0:
99 | time.sleep(self.ms_gran / 1000) # Wait for more data
100 | continue
101 | yield block
102 |
103 | def _read_raw_pcm(self):
104 | samples_per_chunk = int(self.sample_rate * (self.ms_gran / 1000))
105 | bytes_per_sample = 2 # s16le = 16 bits = 2 bytes
106 | chunk_size = samples_per_chunk * bytes_per_sample
107 |
108 | while not os.path.exists(self.streamed_wav_file):
109 | time.sleep(0.1)
110 |
111 | with open(self.streamed_wav_file, 'rb') as f:
112 | while True:
113 | chunk = f.read(chunk_size)
114 | if not chunk:
115 | time.sleep((self.ms_gran / 1000))
116 | continue
117 | yield np.frombuffer(chunk, dtype=np.int16).astype(np.float32) / 32768.0
118 |
119 | def read(self):
120 | if self.simulate_stream:
121 | return self._simulate_stream_using_wav()
122 |
123 | if self.use_remote_machine:
124 | return self._read_raw_pcm()
125 |
126 | return self._read_from_stream()
127 |
128 | def _save_recording_file(self, frames: list):
129 | print(f"Saving recorded audio file on path {self.filename}")
130 |
131 | waveFile = wave.open(self.filename, 'wb')
132 | waveFile.setnchannels(self.channels)
133 | waveFile.setsampwidth(self.audio.get_sample_size(self.inp_dtype))
134 | waveFile.setframerate(self.sample_rate)
135 | waveFile.writeframes(b''.join(frames))
136 | waveFile.close()
137 |
138 | def close_stream(self, frames: list):
139 | if self.simulate_stream: return
140 |
141 | # Stop Recording
142 | self.stream.stop_stream()
143 | self.stream.close()
144 | self.audio.terminate()
145 |
146 | print("Finished recording, stream and audio terminated.")
147 |
148 | if self.filename: self._save_recording_file(frames)
149 |
150 |
151 | def load_audio(file: str, sr: int = SAMPLE_RATE):
152 | """
153 | Open an audio file and read as mono waveform, resampling as necessary
154 |
155 | Parameters
156 | ----------
157 | file: str
158 | The audio file to open
159 |
160 | sr: int
161 | The sample rate to resample the audio if necessary
162 |
163 | Returns
164 | -------
165 | A NumPy array containing the audio waveform, in float32 dtype.
166 | """
167 |
168 | # This launches a subprocess to decode audio while down-mixing
169 | # and resampling as necessary. Requires the ffmpeg CLI in PATH.
170 | # fmt: off
171 | cmd = [
172 | "ffmpeg",
173 | "-nostdin",
174 | "-threads", "0",
175 | "-i", file,
176 | "-f", "s16le",
177 | "-ac", "1",
178 | "-acodec", "pcm_s16le",
179 | "-ar", str(sr),
180 | "-"
181 | ]
182 | # fmt: on
183 | try:
184 | out = run(cmd, capture_output=True, check=True).stdout
185 | except CalledProcessError as e:
186 | raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
187 |
188 | return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
189 |
190 |
191 | def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
192 | """
193 | Pad or trim the audio array to N_SAMPLES, as expected by the encoder.
194 | """
195 | if torch.is_tensor(array):
196 | if array.shape[axis] > length:
197 | array = array.index_select(
198 | dim=axis, index=torch.arange(length, device=array.device)
199 | )
200 |
201 | if array.shape[axis] < length:
202 | pad_widths = [(0, 0)] * array.ndim
203 | pad_widths[axis] = (0, length - array.shape[axis])
204 | array = F.pad(array, [pad for sizes in pad_widths[::-1] for pad in sizes])
205 | else:
206 | if array.shape[axis] > length:
207 | array = array.take(indices=range(length), axis=axis)
208 |
209 | if array.shape[axis] < length:
210 | pad_widths = [(0, 0)] * array.ndim
211 | pad_widths[axis] = (0, length - array.shape[axis])
212 | array = np.pad(array, pad_widths)
213 |
214 | return array
215 |
216 |
217 | @lru_cache(maxsize=None)
218 | def mel_filters(device, n_mels: int) -> torch.Tensor:
219 | """
220 | load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
221 | Allows decoupling librosa dependency; saved using:
222 |
223 | np.savez_compressed(
224 | "mel_filters.npz",
225 | mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
226 | mel_128=librosa.filters.mel(sr=16000, n_fft=400, n_mels=128),
227 | )
228 | """
229 | assert n_mels in {80, 128}, f"Unsupported n_mels: {n_mels}"
230 |
231 | filters_path = os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")
232 | with np.load(filters_path, allow_pickle=False) as f:
233 | return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)
234 |
235 |
236 | def log_mel_spectrogram(
237 | audio: Union[str, np.ndarray, torch.Tensor],
238 | n_mels: int = 80,
239 | padding: int = 0,
240 | device: Optional[Union[str, torch.device]] = None,
241 | ):
242 | """
243 |     Compute the log-Mel spectrogram of an audio waveform or audio file
244 |
245 | Parameters
246 | ----------
247 | audio: Union[str, np.ndarray, torch.Tensor], shape = (*)
248 |         The path to an audio file, or a NumPy array or Tensor containing the audio waveform at 16 kHz
249 |
250 | n_mels: int
251 |         The number of Mel-frequency filters, only 80 and 128 are supported
252 |
253 | padding: int
254 | Number of zero samples to pad to the right
255 |
256 | device: Optional[Union[str, torch.device]]
257 | If given, the audio tensor is moved to this device before STFT
258 |
259 | Returns
260 | -------
261 |     torch.Tensor, shape = (n_mels, n_frames)
262 | A Tensor that contains the Mel spectrogram
263 | """
264 | if not torch.is_tensor(audio):
265 | if isinstance(audio, str):
266 | audio = load_audio(audio)
267 | audio = torch.from_numpy(audio)
268 |
269 | if device is not None:
270 | audio = audio.to(device)
271 | if padding > 0:
272 | audio = F.pad(audio, (0, padding))
273 |
274 | window = torch.hann_window(N_FFT).to(audio.device)
275 | stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)
276 | magnitudes = stft[..., :-1].abs() ** 2
277 |
278 | filters = mel_filters(audio.device, n_mels)
279 | mel_spec = filters @ magnitudes
280 |
281 | log_spec = torch.clamp(mel_spec, min=1e-10).log10()
282 | log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
283 | log_spec = (log_spec + 4.0) / 4.0
284 | return log_spec
285 |
286 |
287 | class SpectrogramStream:
288 | def __init__(self, n_fft: int = N_FFT, hop_length: int = HOP_LENGTH, n_mels: int = 80, window: Optional[str] = "hann", pad_mode: str = "reflect"):
289 |
290 | self.n_fft = n_fft
291 | self.hop_length = hop_length
292 | self.pad_mode = pad_mode
293 | self.n_mels = n_mels
294 |
295 | self.window = torch.hann_window(n_fft)
296 | self.window_type = window
297 |
298 | self.ctx_samples = self.n_fft - self.hop_length
299 |
300 | self.reset()
301 |
302 | def reset(self):
303 | self.is_first = True
304 | self.audio_ctx = torch.tensor([])
305 | self.log_spec_max = -torch.inf
306 |
307 | def calc_mel_with_new_frame(self, audio_frame: torch.Tensor, is_last: bool = False):
308 |
309 | self.window = self.window.to(audio_frame.device)
310 |
311 | if len(audio_frame.shape) == 1:
312 | audio_frame = audio_frame.unsqueeze(0)
313 |
314 | n_batch = audio_frame.shape[0]
315 |
316 | if isinstance(self.log_spec_max, float):
317 | self.log_spec_max = torch.ones((n_batch)).to(audio_frame.device) * -torch.inf
318 |
319 | # check if we are on first frame, if so, pad using reflection
320 | if self.is_first:
321 | pad = int(self.n_fft // 2) + 1
322 | audio_input = F.pad(audio_frame, [pad, 0], self.pad_mode)
323 | self.is_first = False
324 | else: # pad with previous context
325 | audio_input = torch.cat([self.audio_ctx[..., -self.ctx_samples:], audio_frame], dim=-1)
326 |
327 | if is_last: # pad reflect last frame
328 | pad = int(self.n_fft // 4) + 1
329 | audio_input = F.pad(audio_input, [pad, 0], self.pad_mode)
330 |
331 | self.audio_ctx = audio_frame # now audio ctx is the last frame
332 |
333 | stft = torch.stft(audio_input, self.n_fft, self.hop_length, window=self.window, return_complex=True, center=False)
334 | magnitudes = stft.abs() ** 2
335 | filters = mel_filters(audio_frame.device, self.n_mels)
336 | mel_spec = filters @ magnitudes
337 |
338 | log_spec = torch.clamp(mel_spec, min=1e-10).log10() # from shape (b, n_mels, audio_frames)
339 | self.log_spec_max = torch.maximum(log_spec.view(n_batch, -1).max(dim=-1).values, self.log_spec_max).to(log_spec.device)
340 |
341 | log_spec = torch.maximum(log_spec.view(n_batch, -1).permute(1, 0), self.log_spec_max - 8.0).permute(1, 0).view(n_batch, self.n_mels, -1)
342 | log_spec = (log_spec + 4.0) / 4.0
343 | return log_spec
344 |
345 | def _simulate_streaming_log_spec(self, audio: torch.Tensor, ms_gran: int = 300, total_frames: int = 3000, get_gt: bool = False):
346 | self.reset()
347 |
348 | samples_gran = HOP_LENGTH * (ms_gran // 10)
349 | sub_mel_frames = int(total_frames / ms_gran) * 10
350 | # print(samples_gran, sub_mel_frames)
351 | pred_mel = torch.cat([self.calc_mel_with_new_frame(audio[..., (i * samples_gran) + (40 * int(i != 0)): ((i + 1) * samples_gran) + 40], is_last=(i == sub_mel_frames - 1)) for i in range(sub_mel_frames)], dim=-1)
352 |
353 | if get_gt:
354 | gt_mel = log_mel_spectrogram(audio)
355 | return pred_mel, gt_mel
356 |
357 | return pred_mel
358 |
--------------------------------------------------------------------------------
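
A short sketch comparing the chunk-by-chunk mel computation of SpectrogramStream with the offline log_mel_spectrogram(), using the simulation helper above; the relative path to the bundled tests/jfk.wav is an assumption about where the script is run from:

    import torch
    from careless_whisper_stream.audio import (
        N_SAMPLES, SpectrogramStream, load_audio, pad_or_trim,
    )

    audio = torch.from_numpy(pad_or_trim(load_audio("tests/jfk.wav"), length=N_SAMPLES + 180))
    stream = SpectrogramStream(n_mels=80)
    pred_mel, gt_mel = stream._simulate_streaming_log_spec(audio, ms_gran=300, get_gt=True)
    print(pred_mel.shape, gt_mel.shape)   # streamed mel built from 300 ms chunks vs. the offline mel
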
/careless_whisper_stream/timing.py:
--------------------------------------------------------------------------------
1 | import itertools
2 | import subprocess
3 | import warnings
4 | from dataclasses import dataclass
5 | from typing import TYPE_CHECKING, List
6 |
7 | import numba
8 | import numpy as np
9 | import torch
10 | import torch.nn.functional as F
11 |
12 | from .audio import HOP_LENGTH, SAMPLE_RATE, TOKENS_PER_SECOND
13 | from .tokenizer import Tokenizer
14 |
15 | if TYPE_CHECKING:
16 | from .model import Whisper
17 |
18 |
19 | def median_filter(x: torch.Tensor, filter_width: int):
20 | """Apply a median filter of width `filter_width` along the last dimension of `x`"""
21 | pad_width = filter_width // 2
22 | if x.shape[-1] <= pad_width:
23 | # F.pad requires the padding width to be smaller than the input dimension
24 | return x
25 |
26 | if (ndim := x.ndim) <= 2:
27 | # `F.pad` does not support 1D or 2D inputs for reflect padding but supports 3D and 4D
28 | x = x[None, None, :]
29 |
30 | assert (
31 | filter_width > 0 and filter_width % 2 == 1
32 | ), "`filter_width` should be an odd number"
33 |
34 | result = None
35 | x = F.pad(x, (filter_width // 2, filter_width // 2, 0, 0), mode="reflect")
36 | if x.is_cuda:
37 | try:
38 | from .triton_ops import median_filter_cuda
39 |
40 | result = median_filter_cuda(x, filter_width)
41 | except (RuntimeError, subprocess.CalledProcessError):
42 | warnings.warn(
43 | "Failed to launch Triton kernels, likely due to missing CUDA toolkit; "
44 | "falling back to a slower median kernel implementation..."
45 | )
46 |
47 | if result is None:
48 | # sort() is faster than torch.median (https://github.com/pytorch/pytorch/issues/51450)
49 | result = x.unfold(-1, filter_width, 1).sort()[0][..., filter_width // 2]
50 |
51 | if ndim <= 2:
52 | result = result[0, 0]
53 |
54 | return result
55 |
56 |
57 | @numba.jit(nopython=True)
58 | def backtrace(trace: np.ndarray):
59 | i = trace.shape[0] - 1
60 | j = trace.shape[1] - 1
61 | trace[0, :] = 2
62 | trace[:, 0] = 1
63 |
64 | result = []
65 | while i > 0 or j > 0:
66 | result.append((i - 1, j - 1))
67 |
68 | if trace[i, j] == 0:
69 | i -= 1
70 | j -= 1
71 | elif trace[i, j] == 1:
72 | i -= 1
73 | elif trace[i, j] == 2:
74 | j -= 1
75 | else:
76 | raise ValueError("Unexpected trace[i, j]")
77 |
78 | result = np.array(result)
79 | return result[::-1, :].T
80 |
81 |
82 | @numba.jit(nopython=True, parallel=True)
83 | def dtw_cpu(x: np.ndarray):
84 | N, M = x.shape
85 | cost = np.ones((N + 1, M + 1), dtype=np.float32) * np.inf
86 | trace = -np.ones((N + 1, M + 1), dtype=np.float32)
87 |
88 | cost[0, 0] = 0
89 | for j in range(1, M + 1):
90 | for i in range(1, N + 1):
91 | c0 = cost[i - 1, j - 1]
92 | c1 = cost[i - 1, j]
93 | c2 = cost[i, j - 1]
94 |
95 | if c0 < c1 and c0 < c2:
96 | c, t = c0, 0
97 | elif c1 < c0 and c1 < c2:
98 | c, t = c1, 1
99 | else:
100 | c, t = c2, 2
101 |
102 | cost[i, j] = x[i - 1, j - 1] + c
103 | trace[i, j] = t
104 |
105 | return backtrace(trace)
106 |
107 |
108 | def dtw_cuda(x, BLOCK_SIZE=1024):
109 | from .triton_ops import dtw_kernel
110 |
111 | M, N = x.shape
112 | assert M < BLOCK_SIZE, f"M should be smaller than {BLOCK_SIZE=}"
113 |
114 | x_skew = (
115 | F.pad(x, (0, M + 1), value=np.inf).flatten()[: M * (N + M)].reshape(M, N + M)
116 | )
117 | x_skew = x_skew.T.contiguous()
118 | cost = torch.ones(N + M + 2, M + 2) * np.inf
119 | cost[0, 0] = 0
120 | cost = cost.cuda()
121 | trace = torch.zeros_like(cost, dtype=torch.int32)
122 |
123 | dtw_kernel[(1,)](
124 | cost,
125 | trace,
126 | x_skew,
127 | x_skew.stride(0),
128 | cost.stride(0),
129 | trace.stride(0),
130 | N,
131 | M,
132 | BLOCK_SIZE=BLOCK_SIZE,
133 | )
134 |
135 | trace = trace.T.flatten()[: (M + 1) * (M + N + 3)].reshape(M + 1, M + N + 3)[
136 | :, : N + 1
137 | ]
138 | return backtrace(trace.cpu().numpy())
139 |
140 |
141 | def dtw(x: torch.Tensor) -> np.ndarray:
142 | if x.is_cuda:
143 | try:
144 | return dtw_cuda(x)
145 | except (RuntimeError, subprocess.CalledProcessError):
146 | warnings.warn(
147 | "Failed to launch Triton kernels, likely due to missing CUDA toolkit; "
148 | "falling back to a slower DTW implementation..."
149 | )
150 |
151 | return dtw_cpu(x.double().cpu().numpy())
152 |
153 |
154 | @dataclass
155 | class WordTiming:
156 | word: str
157 | tokens: List[int]
158 | start: float
159 | end: float
160 | probability: float
161 |
162 |
163 | def find_alignment(
164 | model: "Whisper",
165 | tokenizer: Tokenizer,
166 | text_tokens: List[int],
167 | mel: torch.Tensor,
168 | num_frames: int,
169 | *,
170 | medfilt_width: int = 7,
171 | qk_scale: float = 1.0,
172 | ) -> List[WordTiming]:
173 | if len(text_tokens) == 0:
174 | return []
175 |
176 | tokens = torch.tensor(
177 | [
178 | *tokenizer.sot_sequence,
179 | tokenizer.no_timestamps,
180 | *text_tokens,
181 | tokenizer.eot,
182 | ]
183 | ).to(model.device)
184 |
185 | # install hooks on the cross attention layers to retrieve the attention weights
186 | QKs = [None] * model.dims.n_text_layer
187 | hooks = [
188 | block.cross_attn.register_forward_hook(
189 | lambda _, ins, outs, index=i: QKs.__setitem__(index, outs[-1][0])
190 | )
191 | for i, block in enumerate(model.decoder.blocks)
192 | ]
193 |
194 | with torch.no_grad():
195 | logits = model(mel.unsqueeze(0), tokens.unsqueeze(0))[0]
196 | sampled_logits = logits[len(tokenizer.sot_sequence) :, : tokenizer.eot]
197 | token_probs = sampled_logits.softmax(dim=-1)
198 | text_token_probs = token_probs[np.arange(len(text_tokens)), text_tokens]
199 | text_token_probs = text_token_probs.tolist()
200 |
201 | for hook in hooks:
202 | hook.remove()
203 |
204 | # heads * tokens * frames
205 | weights = torch.stack([QKs[_l][_h] for _l, _h in model.alignment_heads.indices().T])
206 | weights = weights[:, :, : num_frames // 2]
207 | weights = (weights * qk_scale).softmax(dim=-1)
208 | std, mean = torch.std_mean(weights, dim=-2, keepdim=True, unbiased=False)
209 | weights = (weights - mean) / std
210 | weights = median_filter(weights, medfilt_width)
211 |
212 | matrix = weights.mean(axis=0)
213 | matrix = matrix[len(tokenizer.sot_sequence) : -1]
214 | text_indices, time_indices = dtw(-matrix)
215 |
216 | words, word_tokens = tokenizer.split_to_word_tokens(text_tokens + [tokenizer.eot])
217 | if len(word_tokens) <= 1:
218 | # return on eot only
219 | # >>> np.pad([], (1, 0))
220 | # array([0.])
221 |         # This would crash the jump_times lookup below with a float index, e.g.
222 | # IndexError: arrays used as indices must be of integer (or boolean) type
223 | return []
224 | word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens[:-1]]), (1, 0))
225 |
226 | jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool)
227 | jump_times = time_indices[jumps] / TOKENS_PER_SECOND
228 | start_times = jump_times[word_boundaries[:-1]]
229 | end_times = jump_times[word_boundaries[1:]]
230 | word_probabilities = [
231 | np.mean(text_token_probs[i:j])
232 | for i, j in zip(word_boundaries[:-1], word_boundaries[1:])
233 | ]
234 |
235 | return [
236 | WordTiming(word, tokens, start, end, probability)
237 | for word, tokens, start, end, probability in zip(
238 | words, word_tokens, start_times, end_times, word_probabilities
239 | )
240 | ]
241 |
242 |
243 | def merge_punctuations(alignment: List[WordTiming], prepended: str, appended: str):
244 | # merge prepended punctuations
245 | i = len(alignment) - 2
246 | j = len(alignment) - 1
247 | while i >= 0:
248 | previous = alignment[i]
249 | following = alignment[j]
250 | if previous.word.startswith(" ") and previous.word.strip() in prepended:
251 | # prepend it to the following word
252 | following.word = previous.word + following.word
253 | following.tokens = previous.tokens + following.tokens
254 | previous.word = ""
255 | previous.tokens = []
256 | else:
257 | j = i
258 | i -= 1
259 |
260 | # merge appended punctuations
261 | i = 0
262 | j = 1
263 | while j < len(alignment):
264 | previous = alignment[i]
265 | following = alignment[j]
266 | if not previous.word.endswith(" ") and following.word in appended:
267 | # append it to the previous word
268 | previous.word = previous.word + following.word
269 | previous.tokens = previous.tokens + following.tokens
270 | following.word = ""
271 | following.tokens = []
272 | else:
273 | i = j
274 | j += 1
275 |
276 |
277 | def add_word_timestamps(
278 | *,
279 | segments: List[dict],
280 | model: "Whisper",
281 | tokenizer: Tokenizer,
282 | mel: torch.Tensor,
283 | num_frames: int,
284 | prepend_punctuations: str = "\"'“¿([{-",
285 | append_punctuations: str = "\"'.。,,!!??::”)]}、",
286 | last_speech_timestamp: float,
287 | **kwargs,
288 | ):
289 | if len(segments) == 0:
290 | return
291 |
292 | text_tokens_per_segment = [
293 | [token for token in segment["tokens"] if token < tokenizer.eot]
294 | for segment in segments
295 | ]
296 |
297 | text_tokens = list(itertools.chain.from_iterable(text_tokens_per_segment))
298 | alignment = find_alignment(model, tokenizer, text_tokens, mel, num_frames, **kwargs)
299 | word_durations = np.array([t.end - t.start for t in alignment])
300 | word_durations = word_durations[word_durations.nonzero()]
301 | median_duration = np.median(word_durations) if len(word_durations) > 0 else 0.0
302 | median_duration = min(0.7, float(median_duration))
303 | max_duration = median_duration * 2
304 |
305 | # hack: truncate long words at sentence boundaries.
306 | # a better segmentation algorithm based on VAD should be able to replace this.
307 | if len(word_durations) > 0:
308 | sentence_end_marks = ".。!!??"
309 | # ensure words at sentence boundaries are not longer than twice the median word duration.
310 | for i in range(1, len(alignment)):
311 | if alignment[i].end - alignment[i].start > max_duration:
312 | if alignment[i].word in sentence_end_marks:
313 | alignment[i].end = alignment[i].start + max_duration
314 | elif alignment[i - 1].word in sentence_end_marks:
315 | alignment[i].start = alignment[i].end - max_duration
316 |
317 | merge_punctuations(alignment, prepend_punctuations, append_punctuations)
318 |
319 | time_offset = segments[0]["seek"] * HOP_LENGTH / SAMPLE_RATE
320 | word_index = 0
321 |
322 | for segment, text_tokens in zip(segments, text_tokens_per_segment):
323 | saved_tokens = 0
324 | words = []
325 |
326 | while word_index < len(alignment) and saved_tokens < len(text_tokens):
327 | timing = alignment[word_index]
328 |
329 | if timing.word:
330 | words.append(
331 | dict(
332 | word=timing.word,
333 | start=round(time_offset + timing.start, 2),
334 | end=round(time_offset + timing.end, 2),
335 | probability=timing.probability,
336 | )
337 | )
338 |
339 | saved_tokens += len(timing.tokens)
340 | word_index += 1
341 |
342 | # hack: truncate long words at segment boundaries.
343 | # a better segmentation algorithm based on VAD should be able to replace this.
344 | if len(words) > 0:
345 |             # ensure the first and second words after a pause are not longer than
346 | # twice the median word duration.
347 | if words[0]["end"] - last_speech_timestamp > median_duration * 4 and (
348 | words[0]["end"] - words[0]["start"] > max_duration
349 | or (
350 | len(words) > 1
351 | and words[1]["end"] - words[0]["start"] > max_duration * 2
352 | )
353 | ):
354 | if (
355 | len(words) > 1
356 | and words[1]["end"] - words[1]["start"] > max_duration
357 | ):
358 | boundary = max(words[1]["end"] / 2, words[1]["end"] - max_duration)
359 | words[0]["end"] = words[1]["start"] = boundary
360 | words[0]["start"] = max(0, words[0]["end"] - max_duration)
361 |
362 | # prefer the segment-level start timestamp if the first word is too long.
363 | if (
364 | segment["start"] < words[0]["end"]
365 | and segment["start"] - 0.5 > words[0]["start"]
366 | ):
367 | words[0]["start"] = max(
368 | 0, min(words[0]["end"] - median_duration, segment["start"])
369 | )
370 | else:
371 | segment["start"] = words[0]["start"]
372 |
373 | # prefer the segment-level end timestamp if the last word is too long.
374 | if (
375 | segment["end"] > words[-1]["start"]
376 | and segment["end"] + 0.5 < words[-1]["end"]
377 | ):
378 | words[-1]["end"] = max(
379 | words[-1]["start"] + median_duration, segment["end"]
380 | )
381 | else:
382 | segment["end"] = words[-1]["end"]
383 |
384 | last_speech_timestamp = segment["end"]
385 |
386 | segment["words"] = words
387 |
--------------------------------------------------------------------------------
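For orientation, here is a minimal sketch of how the DTW alignment above behaves on a toy cost matrix: `dtw`/`dtw_cpu` return two monotonically non-decreasing index arrays that trace the cheapest path from the first (token, frame) cell to the last, which `find_alignment` then converts into per-word jump times. The import path `careless_whisper_stream.timing` is an assumption based on the repository layout, and the numbers are illustrative only.

    # Toy example: 3 text tokens aligned against 5 audio frames.
    # Assumes careless_whisper_stream (plus numba/torch) is installed and importable.
    import torch
    from careless_whisper_stream.timing import dtw

    cost = torch.tensor([
        [0.0, 1.0, 9.0, 9.0, 9.0],   # token 0 is cheap on frames 0-1
        [9.0, 9.0, 0.0, 9.0, 9.0],   # token 1 is cheap on frame 2
        [9.0, 9.0, 9.0, 1.0, 0.0],   # token 2 is cheap on frames 3-4
    ])

    text_idx, time_idx = dtw(cost)
    print(text_idx)  # [0 0 1 2 2]
    print(time_idx)  # [0 1 2 3 4]

In `find_alignment` the matrix passed to `dtw` is the negated, median-filtered cross-attention weight matrix, so the cheapest path follows the strongest attention.
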
/training_code/whisper_module.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append("./")
3 | import torch
4 | import random
5 | import evaluate
6 | import careless_whisper_stream
7 | import careless_whisper_stream.tokenizer as whisper_tokenizer
8 |
9 | from torch import nn, Tensor
10 | from torch.optim.adamw import AdamW
11 | from training_code.utils import Config
12 | from careless_whisper_stream import StreamingWhisper
13 | from careless_whisper_stream.audio import HOP_LENGTH
14 | from pytorch_lightning import LightningModule
15 | from torch.optim.lr_scheduler import LinearLR, ReduceLROnPlateau
16 | from training_code.datasets_classes import TIMIT, WAVsDataset, AlignedTextGridDataset
17 | from training_code.collators import WhisperDataCollatorWithPadding, LoRAWhisperDataCollatorWithPadding
18 |
19 | class WhisperCustomModel(LightningModule):
20 | def __init__(self, cfg:Config, model_name="tiny", lang="en", train_dataset: str = None, eval_dataset: str = None, task="transcribe") -> None:
21 | super().__init__()
22 | self.save_hyperparameters()
23 | self.task = task
24 | self.lang = lang
25 | self.model = careless_whisper_stream.load_model(model_name)
26 |         self.tokenizer: whisper_tokenizer.Tokenizer = whisper_tokenizer.get_tokenizer(True, language=lang)
27 |
28 | self.loss_fn = nn.CrossEntropyLoss(ignore_index=-100)
29 | self.metrics_wer = evaluate.load("wer")
30 | self.metrics_cer = evaluate.load("cer")
31 |
32 | self.params = self.model
33 |
34 | self.cfg = cfg
35 | self.__train_dataset = train_dataset
36 | self.__eval_dataset = eval_dataset
37 |
38 | def forward(self, x):
39 | return self.model(x)
40 |
41 | def calc_wer_val(self, out: Tensor, labels: Tensor):
42 | out[out == -100] = self.tokenizer.eot
43 | labels[labels == -100] = self.tokenizer.eot
44 |
45 | o_list, l_list = [], []
46 | for o, l in zip(out, labels):
47 | o = torch.argmax(o, dim=1)
48 | o_list.append(self.tokenizer.decode(o))
49 | l_list.append(self.tokenizer.decode(l))
50 |
51 | wer = self.metrics_wer.compute(references=l_list, predictions=o_list)
52 | return wer
53 |
54 | def training_step(self, batch, batch_id):
55 | input_ids = batch["input_ids"]
56 | labels = batch["labels"].long()
57 | dec_input_ids = batch["dec_input_ids"].long()
58 |
59 | with torch.no_grad():
60 | audio_features = self.model.encoder(input_ids)
61 |
62 | out, _, _ = self.model.decoder(dec_input_ids, audio_features)
63 |
64 | loss = self.loss_fn(out.view(-1, out.size(-1)), labels.view(-1))
65 | self.log("train/loss", loss, on_step=True, prog_bar=True, logger=True)
66 | return loss
67 |
68 | def validation_step(self, batch, batch_id):
69 | input_ids = batch["input_ids"]
70 | labels = batch["labels"].long()
71 | dec_input_ids = batch["dec_input_ids"].long()
72 |
73 | audio_features = self.model.encoder(input_ids)
74 | out, _, _ = self.model.decoder(dec_input_ids, audio_features)
75 |
76 | loss = self.loss_fn(out.view(-1, out.size(-1)), labels.view(-1))
77 | wer = self.calc_wer_val(out, labels)
78 |
79 | self.log("val/loss", loss, on_step=True, prog_bar=True, logger=True, on_epoch=True)
80 | self.log("val/wer", wer, on_step=True, prog_bar=True, logger=True, on_epoch=True)
81 |
82 | return {
83 | "wer": wer,
84 | "loss": loss
85 | }
86 |
87 | def configure_optimizers(self):
88 | model = self.params
89 | no_decay = ["bias", "LayerNorm.weight"]
90 | optimizer_grouped_parameters = [
91 | {
92 | "params": [p for n, p in model.named_parameters()
93 | if not any(nd in n for nd in no_decay)],
94 | "weight_decay": self.cfg.weight_decay,
95 | },
96 | {
97 | "params": [p for n, p in model.named_parameters()
98 | if any(nd in n for nd in no_decay)],
99 | "weight_decay": 0.0,
100 | },
101 | ]
102 | optimizer = AdamW(optimizer_grouped_parameters,
103 | lr=self.cfg.learning_rate,
104 | eps=self.cfg.adam_epsilon)
105 | self.optimizer = optimizer
106 |
107 | scheduler = LinearLR(
108 | self.optimizer, start_factor=0.5, end_factor=0.8, total_iters=self.t_total
109 | )
110 |
111 | self.scheduler = scheduler
112 |
113 | return [optimizer], [{"scheduler": scheduler, "interval": "step", "frequency": 1}]
114 |
115 | def setup(self, stage=None):
116 | if stage == 'fit' or stage is None:
117 | self.t_total = (
118 | (len(self.__train_dataset) // (self.cfg.batch_size))
119 | // self.cfg.gradient_accumulation_steps
120 | * float(self.cfg.num_train_epochs)
121 | )
122 |
123 | def get_dataset(self, ds_path: str, split: str):
124 | return TIMIT(ds_path, self.tokenizer)
125 |
126 | def train_dataloader(self):
127 | dataset = self.get_dataset(self.__train_dataset, "train")
128 | return torch.utils.data.DataLoader(dataset,
129 | batch_size=self.cfg.batch_size,
130 | drop_last=True, shuffle=True, num_workers=self.cfg.num_worker,
131 | collate_fn=WhisperDataCollatorWithPadding()
132 | )
133 |
134 | def val_dataloader(self):
135 | dataset = self.get_dataset(self.__eval_dataset, "val")
136 | return torch.utils.data.DataLoader(dataset,
137 | batch_size=self.cfg.batch_size,
138 | num_workers=self.cfg.num_worker,
139 | collate_fn=WhisperDataCollatorWithPadding()
140 | )
141 |
142 |
143 | class LoRAStreamedWhisper(WhisperCustomModel):
144 | def __init__(self, cfg: Config, model_name="tiny", lang="en", train_dataset: str = None, eval_dataset: str = None, task="transcribe", rank=8, enc_emb_gran=15, enc_context=1, sim_stream=False, beam_size=None, use_kv_cache=False, use_ca_kv_cache=False, get_times=False, eval_script=False) -> None:
145 | super().__init__(cfg, model_name, lang, train_dataset, eval_dataset, task)
146 |
147 | self.automatic_optimization = not cfg.streaming_train
148 |
149 | # if model_name != "large-v2" and not eval_script:
150 | print(f"enc_emb_gran: {enc_emb_gran}")
151 | self.model: StreamingWhisper = careless_whisper_stream.load_streaming_model_for_train(model_name,
152 | advisor_ckpt_path=None,
153 | advisor_type=None,
154 | rank=rank,
155 | gran=enc_emb_gran,
156 | extra_gran_blocks=enc_context,
157 | )
158 |
159 | for n, p in self.model.named_parameters():
160 | if "lora" not in n:
161 | p.requires_grad = False
162 |
163 | self.model_name = model_name
164 | self.rank = rank
165 | self.enc_emb_gran = enc_emb_gran
166 | self.enc_context = enc_context
167 | self.simulate_stream = sim_stream
168 | self.full_stream = cfg.streaming_train
169 | self.beam_size = beam_size
170 | self.language = lang
171 | self.use_kv_cache = use_kv_cache
172 | self.use_ca_kv_cache = use_ca_kv_cache
173 | self.get_times = get_times
174 | self.eval_script = eval_script
175 |
176 | # for stream mode train
177 | self.num_frames = self.model.dims.n_audio_ctx // self.enc_emb_gran # 1500 // enc_emb_gran
178 | self.mel_samples = self.enc_emb_gran * 2 * HOP_LENGTH
179 |
180 | self.params = None
181 |
182 | self.last_out = None
183 | self.__train_dataset = train_dataset
184 | self.__eval_dataset = eval_dataset
185 |
186 | def _calc_labels(self, labels: Tensor, endpoints: Tensor, index: int):
187 | t_seconds = (index + 1) * self.enc_emb_gran * 0.02
188 |
189 | # take only relevant labels into account
190 | mask = (endpoints <= t_seconds) & (endpoints != -100)
191 | eot_indices = mask.int().argmin(dim=-1)
192 | clone_labels = labels.clone()
193 |
194 | # ignore irrelevant labels
195 | clone_labels[~mask] = -100
196 |
197 | # mark eot labels
198 | clone_labels[range(labels.shape[0]), eot_indices] = self.tokenizer.eot
199 | return clone_labels
200 |
201 | def _get_sample_points(self, endpoints: Tensor):
202 | # base case
203 | if self.cfg.streaming_fraction == 1:
204 | return range(self.enc_context, self.num_frames)
205 |
206 | biggest_endpoint = endpoints.max().item()
207 |
208 | # determine last index.
209 | num_frames = min(int(((biggest_endpoint / 0.02) // self.enc_emb_gran) + int(1 / (self.enc_emb_gran * 0.02)) + 1), self.num_frames) # adding 1 sec of silence
210 | new_range = range(self.enc_context, num_frames)
211 |
212 | sample_points = random.sample(new_range, k=int(len(new_range) * self.cfg.streaming_fraction) + 1)
213 |
214 | if self.cfg.streaming_random:
215 | return sample_points
216 |
217 | return sorted(sample_points)
218 |
219 | def _forward_step_stream(self, batch, step):
220 | input_ids = batch["input_ids"]
221 | labels = batch["labels"].long()
222 | dec_input_ids = batch["dec_input_ids"].long()
223 | endpoints = batch["endpoints"]
224 |
225 | if step == "train":
226 | optimizer = self.optimizers()
227 |
228 | # forward
229 | sample_points = self._get_sample_points(endpoints)
230 | for i in sample_points:
231 | audio_features = self.model.encoder(input_ids[..., :(i + 1) * (self.enc_emb_gran * 2)], index=[0, (i + 1) * self.enc_emb_gran], mask=True)
232 | out = self.model.decoder(dec_input_ids, audio_features, dump_type="None")
233 |
234 | if step == "train":
235 | optimizer.zero_grad()
236 |
237 | # loss calc
238 | frame_labels = self._calc_labels(labels, endpoints, i)
239 | loss = self.loss_fn(out.view(-1, out.size(-1)), frame_labels.view(-1))
240 |
241 | # optimizer step if relevant.
242 | if step == "train":
243 | loss.backward()
244 |                 optimizer.step()  # the optimizer step could be moved out of the loop for faster training
245 |
246 | if step == "train":
247 | return loss
248 |
249 | return out, loss
250 |
251 | def _forward_step(self, batch, step):
252 | input_ids = batch["input_ids"]
253 | labels = batch["labels"].long()
254 | dec_input_ids = batch["dec_input_ids"].long()
255 |
256 | # self.model.encoder.reset()
257 |         audio_features = self.model.encoder(input_ids, mask=True)  # the block-causal mask simulates streaming, which is much faster than true chunk-by-chunk training
258 | out = self.model.decoder(dec_input_ids, audio_features, dump_type="None")
259 |
260 | # loss calc
261 | loss = self.loss_fn(out.view(-1, out.size(-1)), labels.view(-1))
262 |
263 | if step == "train":
264 | return loss
265 |
266 | return out, loss
267 |
268 | def training_step(self, batch, batch_id):
269 | if self.full_stream:
270 | loss = self._forward_step_stream(batch, "train")
271 | else:
272 | loss = self._forward_step(batch, "train")
273 |
274 | self.log("train/loss", loss, on_step=True, prog_bar=True, logger=True, sync_dist=True)
275 |
276 | return loss
277 |
278 | def validation_step(self, batch, batch_id):
279 | out, loss = self._forward_step(batch, "val")
280 | wer = self.calc_wer_val(out, batch["labels"])
281 |
282 | self.log("val/loss", loss, on_step=True, prog_bar=True, logger=True, sync_dist=True)
283 | self.log("val/wer", wer, on_step=True, prog_bar=True, logger=True, sync_dist=True)
284 |
285 | return {
286 | "wer": wer,
287 | "loss": loss
288 | }
289 |
290 | def predict_step(self, batch, batch_id):
291 | wavs = batch["wav_path"]
292 | text = batch["text"]
293 |
294 | results, times = self.model.transcribe(wav_file=wavs[0],
295 | simulate_stream=True,
296 | beam_size=self.beam_size,
297 | language=self.language,
298 | use_ca_kv_cache=self.use_ca_kv_cache,
299 | use_sa_kv_cache=self.use_kv_cache,
300 | get_times=self.get_times)
301 |
302 | return [res.text for res in results], text, wavs[0], times
303 |
304 | def on_train_epoch_start(self):
305 | random.seed(self.current_epoch + self.cfg.seed)
306 |
307 | def on_train_epoch_end(self):
308 | model_sched = self.lr_schedulers()
309 | model_sched.step(self.trainer.callback_metrics["val/loss"])
310 |
311 | def get_dataset(self, ds_path, split):
312 | print(f"Stream simulation mode: {self.simulate_stream}")
313 |
314 | if self.full_stream:
315 | return AlignedTextGridDataset(ds_path=ds_path, get_streamed_mel=True, gran=self.enc_emb_gran, extra_gran_blocks=self.enc_context, n_mels=self.model.dims.n_mels, multilingual=self.cfg.multilingual)
316 |
317 | return WAVsDataset(ds_path=ds_path, get_streamed_mel=self.simulate_stream)
318 |
319 | def configure_optimizers(self):
320 | model = self.model
321 | optimizer_grouped_parameters = [
322 | {
323 | "params": [p for n, p in model.named_parameters() if p.requires_grad],
324 | "weight_decay": self.cfg.weight_decay,
325 | "lr": self.cfg.learning_rate
326 | },
327 | {
328 | "params": [p for n, p in model.named_parameters() if not p.requires_grad],
329 | "weight_decay": 0.0,
330 | "lr": 0
331 | },
332 | ]
333 | optimizer = AdamW(optimizer_grouped_parameters,
334 | eps=self.cfg.adam_epsilon)
335 | self.optimizer = optimizer
336 |
337 | scheduler = ReduceLROnPlateau(
338 | self.optimizer, 'min', patience=2, factor=0.5
339 | )
340 |
341 | self.scheduler = scheduler
342 |
343 | return [optimizer], [{"scheduler": scheduler, "monitor": "val/loss"}]
344 |
345 | def train_dataloader(self):
346 | dataset = self.get_dataset(self.__train_dataset, "train")
347 | return torch.utils.data.DataLoader(dataset,
348 | batch_size=self.cfg.batch_size,
349 | drop_last=True, shuffle=True, num_workers=self.cfg.num_worker, pin_memory=True,
350 | collate_fn=LoRAWhisperDataCollatorWithPadding())
351 |
352 | def val_dataloader(self):
353 | dataset = self.get_dataset(self.__eval_dataset, "val")
354 | return torch.utils.data.DataLoader(dataset,
355 | batch_size=self.cfg.batch_size,
356 | num_workers=self.cfg.num_worker, pin_memory=True,
357 | collate_fn=LoRAWhisperDataCollatorWithPadding())
358 |
359 | def on_save_checkpoint(self, checkpoint):
360 | checkpoint["dims"] = self.model.dims.__dict__
361 |
--------------------------------------------------------------------------------
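To make the streaming supervision above concrete: at encoder step i, `_calc_labels` keeps only the tokens whose word endpoint falls inside the audio heard so far ((i + 1) * gran * 20 ms) and places an EOT at the first position that is not yet audible, so each step trains the model to stop exactly where the audio stops. The sketch below mirrors that masking on a toy batch; the EOT id and granularity are illustrative placeholders, not values read from the real tokenizer or config.

    # Standalone sketch of the per-step label masking used in streaming training.
    # EOT and GRAN are assumed placeholder values for illustration.
    import torch

    EOT = 50257   # assumed end-of-text id
    GRAN = 15     # encoder chunk granularity in frames (1 frame = 20 ms)

    def calc_stream_labels(labels: torch.Tensor, endpoints: torch.Tensor, index: int) -> torch.Tensor:
        t_seconds = (index + 1) * GRAN * 0.02              # audio duration visible at this step
        mask = (endpoints <= t_seconds) & (endpoints != -100)
        eot_indices = mask.int().argmin(dim=-1)            # first token not yet fully heard
        out = labels.clone()
        out[~mask] = -100                                  # ignored by CrossEntropyLoss(ignore_index=-100)
        out[range(labels.shape[0]), eot_indices] = EOT     # teach the model to emit EOT here
        return out

    labels    = torch.tensor([[11, 22, 33, 44]])
    endpoints = torch.tensor([[0.20, 0.45, 0.90, 1.40]])   # word end times in seconds
    print(calc_stream_labels(labels, endpoints, index=1))
    # -> [[11, 22, 50257, -100]]  (step 1 covers 0.6 s of audio)
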
/careless_whisper_stream/streaming_model.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from torch import Tensor
7 | from typing import Optional
8 | from .model import AudioEncoder, TextDecoder, Whisper
9 |
10 | from .streaming_decoding import DecodingTask, DecodingOptions, DecodingResult
11 | from .streaming_transcribe import transcribe as transcribe_function
12 | from .decoding import decode as non_causal_decode_function
13 | from .audio import SpectrogramStream
14 |
15 | from dataclasses import replace
16 |
17 | from pytorch_lightning import LightningModule
18 |
19 |
20 | class LoraLayer(LightningModule):
21 | def __init__(self, input_dim, output_dim, rank=8, alpha=None, *args, **kwargs):
22 | super().__init__(*args, **kwargs)
23 |
24 | self.lora_A = nn.Parameter(torch.zeros(input_dim, rank))
25 | self.lora_B = nn.Parameter(torch.zeros(rank, output_dim))
26 |
27 | self.alpha = rank if alpha is None else alpha
28 | self.rank = rank
29 | self.scale = self.alpha / self.rank
30 |
31 | self._init_weights()
32 |
33 | def _init_weights(self):
34 | nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
35 | nn.init.zeros_(self.lora_B)
36 |
37 | def forward(self, x):
38 | return x @ (self.lora_A @ self.lora_B) * self.scale
39 |
40 |
41 | class LoraLinearLayer(LightningModule):
42 |     def __init__(self, base_layer: nn.Linear, rank: int = 8, bias: bool = True, *args, **kwargs):  # note: bias is currently unused
43 | super().__init__(*args, **kwargs)
44 |
45 | self.base_layer = base_layer
46 |
47 | self.lora_layer = LoraLayer(base_layer.in_features, base_layer.out_features, rank=rank)
48 | self.aggregate_lora = True
49 |
50 | def turn_on_lora(self):
51 | self.aggregate_lora = True
52 |
53 | def turn_off_lora(self):
54 | self.aggregate_lora = False
55 |
56 | def forward(self, x: Tensor):
57 | if not self.aggregate_lora:
58 | return self.base_layer(x)
59 |
60 | return self.base_layer(x) + self.lora_layer(x)
61 |
62 |
63 | class LoRAMultiHeadAttention(LightningModule):
64 | def __init__(self, n_head, query, key, value, out, rank, *args, **kwargs):
65 | super().__init__(*args, **kwargs)
66 |
67 | self.n_head = n_head
68 | self.query = LoraLinearLayer(query, rank)
69 | self.key = LoraLinearLayer(key, rank, bias=False)
70 | self.value = LoraLinearLayer(value, rank)
71 | self.out = LoraLinearLayer(out, rank)
72 |
73 | def forward(
74 | self,
75 | x: Tensor,
76 | xa: Tensor = None,
77 | mask: Tensor = None,
78 | kv_cache: dict = None,
79 | *args, **kwargs
80 | ):
81 | q = self.query(x)
82 |
83 | if kv_cache is None or xa is None or self.key not in kv_cache:
84 | k = self.key(x if xa is None else xa)
85 | v = self.value(x if xa is None else xa)
86 |
87 | else:
88 | k = kv_cache[self.key]
89 | v = kv_cache[self.value]
90 |
91 | wv, qk = self.qkv_attention(q, k, v, mask, kv_cache)
92 |
93 | return self.out(wv), qk
94 |
95 | def qkv_attention(
96 | self, q: Tensor, k: Tensor, v: Tensor, mask: Tensor = None, kv_cache: dict = None
97 | ):
98 | n_batch, n_ctx, n_state = q.shape
99 | _, k_ctx, _ = k.shape
100 | scale = (n_state // self.n_head) ** -0.25
101 | q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
102 | k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale
103 | v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
104 | qk = q @ k
105 |
106 | # apply causal mask
107 | if mask is not None:
108 |
109 | # kv_cache for beam search decoding case
110 | if kv_cache is not None and "beam_indices" in kv_cache.keys():
111 | for i in range(n_batch):
112 | qk[i] = qk[i] + mask[kv_cache["beam_indices"][1][i * n_ctx]:kv_cache["beam_indices"][1][i * n_ctx] + n_ctx, :k_ctx]
113 |
114 | # For training, encoder/decoder causal masks
115 | elif k_ctx == n_ctx:
116 | qk = qk + mask[:n_ctx, :n_ctx]
117 |
118 | # kv_cache in the greedy decoding case
119 | else:
120 | qk = qk + mask[k_ctx - n_ctx:k_ctx, :k_ctx]
121 |
122 | qk = qk.float()
123 |
124 | w = F.softmax(qk, dim=-1).to(q.dtype)
125 |
126 | return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()
127 |
128 |
129 | class StreamingAudioEncoder(AudioEncoder):
130 |
131 | def __init__(self, n_mels, n_ctx, n_state, n_head, n_layer, cache_gran, gran, rank, extra_gran_blocks):
132 | super().__init__(n_mels, n_ctx, n_state, n_head, n_layer)
133 |
134 | self.gran = gran
135 | self.extra_gran_blocks = extra_gran_blocks
136 |
137 | for block in self.blocks:
138 | block.attn = LoRAMultiHeadAttention(self.n_head,
139 | block.attn.query,
140 | block.attn.key,
141 | block.attn.value,
142 | block.attn.out,
143 | rank)
144 |
145 | self.use_stream = False
146 |
147 | # mask for training
148 | matrix_size = n_ctx
149 | block_size = gran
150 | extra_blocks = extra_gran_blocks
151 | mask = torch.full((matrix_size, matrix_size), float('-inf'))
152 |
153 | for i in range(0, matrix_size, block_size):
154 | if (i // block_size) <= extra_blocks:
155 | zero_cols = (block_size * (extra_blocks + 1))
156 | else:
157 | zero_cols = (block_size * (extra_blocks + 1)) + ((i // block_size) - extra_blocks) * block_size
158 |
159 | mask[i:i + block_size, :zero_cols] = 1
160 |
161 | self.register_buffer("mask", mask, persistent=False)
162 |
163 | def _use_stream(self, use_stream: bool):
164 | self.use_stream = use_stream
165 |
166 | def forward(self, x: Tensor, index: list = [0, 1500], kv_cache = None, mask = True):
167 | """
168 |         Streaming-aware forward: in stream mode, process only the newest chunk using the self-attention kv_cache; in training, process the full input with a block-causal mask to simulate streaming.
169 | """
170 | x = F.gelu(self.conv1(x))
171 | x = F.gelu(self.conv2(x))
172 | x = x.permute(0, 2, 1)
173 |
174 | if self.use_stream:
175 | offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
176 | x = x[:, offset:offset + self.gran + (int(offset == 0) * (self.extra_gran_blocks * self.gran))] # offset
177 | x = (x + self.positional_embedding[offset:offset + self.gran + (int(offset == 0) * (self.extra_gran_blocks * self.gran))]).to(x.dtype)
178 | else: # use during training
179 | x = x[:, index[0]:index[1]] # offset
180 | x = (x + self.positional_embedding[index[0]:index[1]]).to(x.dtype)
181 |
182 | for block in self.blocks:
183 |             chosen_mask = mask[..., :index[1], :index[1]] if isinstance(mask, Tensor) else (self.mask if (mask is not None and self.use_stream) else None)
184 | x = block(x, mask=chosen_mask, kv_cache=kv_cache)
185 |
186 | x = self.ln_post(x)
187 |
188 | return x
189 |
190 |
191 | class StreamingTextDecoder(TextDecoder):
192 | def __init__(self, n_vocab, n_ctx, n_state, n_head, n_layer, rank):
193 | super().__init__(n_vocab, n_ctx, n_state, n_head, n_layer)
194 |
195 | self.n_ctx = n_ctx
196 | self.n_state = n_state
197 |
198 | for block in self.blocks:
199 | block.attn = LoRAMultiHeadAttention(n_head,
200 | block.attn.query,
201 | block.attn.key,
202 | block.attn.value,
203 | block.attn.out,
204 | rank)
205 |
206 | block.cross_attn = LoRAMultiHeadAttention(n_head,
207 | block.cross_attn.query,
208 | block.cross_attn.key,
209 | block.cross_attn.value,
210 | block.cross_attn.out,
211 | rank)
212 |
213 | def forward(self, x: Tensor, xa: Tensor, kv_cache: dict = None, dump_type: str = None, **kwargs):
214 | """
215 | x : torch.LongTensor, shape = (batch_size, <= n_ctx)
216 | the text tokens
217 | xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)
218 | the encoded audio features to be attended on
219 | """
220 | if kv_cache is not None and "beam_indices" in kv_cache.keys():
221 | x = self.token_embedding(x) + self.positional_embedding.unsqueeze(0).expand(x.shape[0], self.positional_embedding.shape[0], self.positional_embedding.shape[1])[kv_cache["beam_indices"][0], kv_cache["beam_indices"][1]].view(x.shape[0], -1, self.n_state)
222 | else:
223 | offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
224 | x = self.token_embedding(x) + self.positional_embedding[offset : offset + x.shape[-1]]
225 |
226 | x = x.to(xa.dtype)
227 |
228 | for block in self.blocks:
229 | x = block(x, xa, mask=self.mask, kv_cache=kv_cache)
230 |
231 | x = self.ln(x)
232 | logits = (
233 | x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
234 | ).float()
235 |
236 | return logits
237 |
238 |
239 | class StreamingWhisper(Whisper):
240 | def __init__(self, dims, cache_gran: bool = True, gran: int = 16, rank: int = 0, extra_gran_blocks: int = 0):
241 | super().__init__(dims)
242 |
243 | self.cache_gran = cache_gran
244 | self.gran = gran
245 | self.rank = rank
246 | self.extra_gran_blocks = extra_gran_blocks
247 |
248 | print(f"Running a streaming whisper model, using chunk size: {gran * 20}[msec] and {extra_gran_blocks} extra chunks for initialization.")
249 |
250 |         # Swap in the streaming (block-causal, LoRA) encoder and the LoRA-augmented decoder
251 | self.encoder = StreamingAudioEncoder(
252 | self.dims.n_mels,
253 | self.dims.n_audio_ctx,
254 | self.dims.n_audio_state,
255 | self.dims.n_audio_head,
256 | self.dims.n_audio_layer,
257 | cache_gran=cache_gran,
258 | gran=gran,
259 | rank=rank,
260 | extra_gran_blocks=extra_gran_blocks
261 | )
262 |
263 | self.decoder = StreamingTextDecoder(
264 | self.dims.n_vocab,
265 | self.dims.n_text_ctx,
266 | self.dims.n_text_state,
267 | self.dims.n_text_head,
268 | self.dims.n_text_layer,
269 | rank=rank
270 | )
271 |
272 | # Advisor params - Dropped.
273 | self.advisor_type = None
274 | self.n_advisor_class = 0
275 | self.decoder_advisor = None
276 |
277 | self.decoding_task = None
278 | self.spec_streamer = SpectrogramStream()
279 |
280 | self.use_ca_cache_hook = True # relevant only when ca_kv_cache is installed
281 |
282 | def reset(self, use_stream: bool = True, clean_task: bool = True):
283 | self.encoder._use_stream(use_stream)
284 | del self.decoding_task # trigger clean encoder kv caching
285 | self.decoding_task = None
286 | self.spec_streamer.reset()
287 |
288 | @torch.no_grad()
289 | def decode(self, mel: Tensor, options: DecodingOptions = DecodingOptions(), use_frames: bool = False, **kwargs) -> DecodingResult:
290 | if kwargs: options = replace(options, **kwargs)
291 |
292 |         if use_frames:  # "mel" actually holds raw audio frames; compute the log-mel spectrogram incrementally
293 | mel = self.spec_streamer.calc_mel_with_new_frame(mel).squeeze(0)
294 |
295 | if self.encoder.gran != options.gran:
296 |             print(f"Encoder gran and options gran differ; forcing options.gran to the encoder's value: {self.encoder.gran}")
297 | options.gran = self.encoder.gran
298 |
299 | if not self.decoding_task:
300 | self.decoding_task = DecodingTask(self, options)
301 |
302 | return self.decoding_task.run(mel.unsqueeze(0))
303 |
304 | def _turn_off_lora(self):
305 | for _, layer in self.encoder.named_modules():
306 | if isinstance(layer, LoraLinearLayer):
307 | layer.turn_off_lora()
308 |
309 | def _turn_on_lora(self):
310 | for _, layer in self.encoder.named_modules():
311 | if isinstance(layer, LoraLinearLayer):
312 | layer.turn_on_lora()
313 |
314 | def _cancel_streaming_mode(self):
315 | self._turn_off_lora()
316 | self.encoder._use_stream(False)
317 |
318 | def _revert_streaming_mode(self):
319 | self._turn_on_lora()
320 | self.encoder._use_stream(True)
321 |
322 | @torch.no_grad()
323 | def non_causal_decode(self, mel: Tensor, options: DecodingOptions = DecodingOptions(), **kwargs) -> DecodingResult:
324 | self._cancel_streaming_mode()
325 | results = non_causal_decode_function(self, mel, options, **kwargs)
326 | self._revert_streaming_mode()
327 | return results
328 |
329 | def remove_encoder_kv_cache_hooks(self):
330 | for hook in self.encoder._forward_hooks.values():
331 | hook.remove()
332 |
333 | def install_encoder_kv_cache_hooks(self, cache = None):
334 | cache = {**cache} if cache is not None else {}
335 | hooks = []
336 |
337 | def save_to_cache(module, _, output):
338 | if module not in cache or output.shape[1] > self.dims.n_audio_ctx:
339 |                 # save as-is for the first chunk (no cache entry yet)
340 | cache[module] = output
341 | else:
342 | cache[module] = torch.cat([cache[module], output], dim=1).detach()
343 | return cache[module]
344 |
345 | def install_hooks(layer: nn.Module):
346 | if isinstance(layer, LoRAMultiHeadAttention):
347 | hooks.append(layer.key.register_forward_hook(save_to_cache))
348 | hooks.append(layer.value.register_forward_hook(save_to_cache))
349 |
350 | self.encoder.apply(install_hooks)
351 | return cache, hooks
352 |
353 | def install_decoder_kv_cache_hooks(self, cache = None):
354 | cache = {**cache} if cache is not None else {}
355 | hooks = []
356 |
357 | def save_to_cache(module, _, output):
358 | if module not in cache or output.shape[1] > self.dims.n_text_ctx:
359 | cache[module] = output
360 | else:
361 | if "beam_indices" not in cache.keys() or all([index == (cache[module].shape[1]) for index in cache["beam_indices"][1]]):
362 | cache[module] = torch.cat([cache[module], output], dim=1).detach()
363 | else:
364 | for _, (beam, index, output_index) in enumerate(zip(*cache["beam_indices"])):
365 | if index < cache[module].shape[1]:
366 | cache[module][beam, index] = output[beam, output_index]
367 | else:
368 | cache[module] = torch.cat([cache[module], output[:, (output_index):]], dim=1).detach()
369 |
370 | return cache[module]
371 |
372 | for name, module in self.decoder.named_modules():
373 | if isinstance(module, LoRAMultiHeadAttention) and "attn" in name and "cross" not in name:
374 | hooks.append(module.key.register_forward_hook(save_to_cache))
375 | hooks.append(module.value.register_forward_hook(save_to_cache))
376 |
377 | return cache, hooks
378 |
379 | def install_cross_attn_kv_cache_hooks(self, cache=None):
380 | cache = {**cache} if cache is not None else {}
381 | hooks = []
382 |
383 | def save_to_cache(module, _, output):
384 | if not self.use_ca_cache_hook:
385 | return cache[module]
386 |
387 | if module not in cache or output.shape[1] > self.dims.n_audio_ctx:
388 |                 # save as-is when there is no cached entry for this cross-attention projection yet
389 | cache[module] = output
390 | else:
391 | cache[module] = torch.cat([cache[module], output], dim=1).detach()
392 |
393 | return cache[module]
394 |
395 | def check_if_calculation_is_needed(module, _):
396 | if not self.use_ca_cache_hook:
397 | return cache[module]
398 |
399 | for name, module in self.decoder.named_modules():
400 | if isinstance(module, LoRAMultiHeadAttention) and "cross_attn" in name:
401 | hooks.append(module.key.register_forward_hook(save_to_cache))
402 | hooks.append(module.key.register_forward_pre_hook(check_if_calculation_is_needed))
403 | hooks.append(module.value.register_forward_hook(save_to_cache))
404 | hooks.append(module.value.register_forward_pre_hook(check_if_calculation_is_needed))
405 |
406 | return cache, hooks
407 |
408 | # For non-causal decoding compatibility
409 | def install_kv_cache_hooks(self, cache: Optional[dict] = None):
410 | """
411 | The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value
412 | tensors calculated for the previous positions. This method returns a dictionary that stores
413 | all caches, and the necessary hooks for the key and value projection modules that save the
414 | intermediate tensors to be reused during later calculations.
415 |
416 | Returns
417 | -------
418 | cache : Dict[nn.Module, torch.Tensor]
419 |             A dictionary object mapping the key/value projection modules to their caches
420 | hooks : List[RemovableHandle]
421 |             List of PyTorch RemovableHandle objects for removing the hooks when they are no longer needed
422 | """
423 | cache = {**cache} if cache is not None else {}
424 | hooks = []
425 |
426 | def save_to_cache(module, _, output):
427 | if module not in cache or output.shape[1] > self.dims.n_text_ctx:
428 | # save as-is, for the first token or cross attention
429 | cache[module] = output
430 | else:
431 | cache[module] = torch.cat([cache[module], output], dim=1).detach()
432 | return cache[module]
433 |
434 | def install_hooks(layer: nn.Module):
435 | if isinstance(layer, LoRAMultiHeadAttention):
436 | hooks.append(layer.key.register_forward_hook(save_to_cache))
437 | hooks.append(layer.value.register_forward_hook(save_to_cache))
438 |
439 | self.decoder.apply(install_hooks)
440 | return cache, hooks
441 |
442 |     # transcribe is implemented in streaming_transcribe.py and bound here as a method
443 | transcribe = transcribe_function
444 |
--------------------------------------------------------------------------------
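For readers less familiar with LoRA, here is a minimal sketch of what `LoraLinearLayer` above does: the output is the frozen base projection plus a low-rank update `x @ (lora_A @ lora_B) * (alpha / rank)`. Because `lora_B` is zero-initialized, the wrapped layer starts out numerically identical to the base layer, and `turn_off_lora()` restores the plain projection at any time (which is what `_cancel_streaming_mode` uses to fall back to the original encoder for non-causal decoding). The import path is an assumption based on the repository layout.

    # Minimal usage sketch; assumes the package is importable as careless_whisper_stream.
    import torch
    import torch.nn as nn
    from careless_whisper_stream.streaming_model import LoraLinearLayer

    base = nn.Linear(8, 8)
    wrapped = LoraLinearLayer(base, rank=4)

    x = torch.randn(2, 8)
    # lora_B is zero-initialized, so right after construction the adapter is a no-op.
    assert torch.allclose(wrapped(x), base(x))

    # Toggling LoRA off/on switches between the plain Whisper projection and the adapted one.
    wrapped.turn_off_lora()
    assert torch.allclose(wrapped(x), base(x))
    wrapped.turn_on_lora()
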
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | 1. This work is built upon or includes components of "Whisper" code licensed under the MIT License as follows:
3 |
4 | MIT License
5 |
6 | Copyright (c) 2022 OpenAI
7 |
8 | Permission is hereby granted, free of charge, to any person obtaining a copy
9 | of this software and associated documentation files (the "Software"), to deal
10 | in the Software without restriction, including without limitation the rights
11 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12 | copies of the Software, and to permit persons to whom the Software is
13 | furnished to do so, subject to the following conditions:
14 |
15 | The above copyright notice and this permission notice shall be included in all
16 | copies or substantial portions of the Software.
17 |
18 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24 | SOFTWARE.
25 |
26 | 2. The modifications/work detailed in:
27 |
28 | work:
29 | training_code/collators.py
30 | training_code/datasets_classes.py
31 | training_code/ds_dict.py
32 | training_code/train.py
33 | training_code/utils.py
34 | training_code/whisper_module.py
35 |
36 | modifications/additions:
37 | careless_whisper_stream/__init__.py: function load_streaming_model, function load_streaming_model_for_train
38 | careless_whisper_stream/audio.py: class MyStream, class SpectrogramStream
39 | careless_whisper_stream/streaming_decoding.py: class StreamingDecoder, class BeamStreamingDecoder
40 | careless_whisper_stream/streaming_model.py: class StreamingAudioEncoder (kv_cache additions in forward), class StreamingTextDecoder (kv_cache for beam search in streaming), function install_decoder_kv_cache_hooks.
41 | careless_whisper_stream/streaming_transcribe.py
42 |
43 |
44 | is licensed under a Creative Commons Attribution-NonCommercial 4.0 International License as follows:
45 |
46 | Attribution-NonCommercial 4.0 International
47 |
48 | =======================================================================
49 |
50 | Creative Commons Corporation ("Creative Commons") is not a law firm and
51 | does not provide legal services or legal advice. Distribution of
52 | Creative Commons public licenses does not create a lawyer-client or
53 | other relationship. Creative Commons makes its licenses and related
54 | information available on an "as-is" basis. Creative Commons gives no
55 | warranties regarding its licenses, any material licensed under their
56 | terms and conditions, or any related information. Creative Commons
57 | disclaims all liability for damages resulting from their use to the
58 | fullest extent possible.
59 |
60 | Using Creative Commons Public Licenses
61 |
62 | Creative Commons public licenses provide a standard set of terms and
63 | conditions that creators and other rights holders may use to share
64 | original works of authorship and other material subject to copyright
65 | and certain other rights specified in the public license below. The
66 | following considerations are for informational purposes only, are not
67 | exhaustive, and do not form part of our licenses.
68 |
69 | Considerations for licensors: Our public licenses are
70 | intended for use by those authorized to give the public
71 | permission to use material in ways otherwise restricted by
72 | copyright and certain other rights. Our licenses are
73 | irrevocable. Licensors should read and understand the terms
74 | and conditions of the license they choose before applying it.
75 | Licensors should also secure all rights necessary before
76 | applying our licenses so that the public can reuse the
77 | material as expected. Licensors should clearly mark any
78 | material not subject to the license. This includes other CC-
79 | licensed material, or material used under an exception or
80 | limitation to copyright. More considerations for licensors:
81 | wiki.creativecommons.org/Considerations_for_licensors
82 |
83 | Considerations for the public: By using one of our public
84 | licenses, a licensor grants the public permission to use the
85 | licensed material under specified terms and conditions. If
86 | the licensor's permission is not necessary for any reason--for
87 | example, because of any applicable exception or limitation to
88 | copyright--then that use is not regulated by the license. Our
89 | licenses grant only permissions under copyright and certain
90 | other rights that a licensor has authority to grant. Use of
91 | the licensed material may still be restricted for other
92 | reasons, including because others have copyright or other
93 | rights in the material. A licensor may make special requests,
94 | such as asking that all changes be marked or described.
95 | Although not required by our licenses, you are encouraged to
96 | respect those requests where reasonable. More considerations
97 | for the public:
98 | wiki.creativecommons.org/Considerations_for_licensees
99 |
100 | =======================================================================
101 |
102 | Creative Commons Attribution-NonCommercial 4.0 International Public
103 | License
104 |
105 | By exercising the Licensed Rights (defined below), You accept and agree
106 | to be bound by the terms and conditions of this Creative Commons
107 | Attribution-NonCommercial 4.0 International Public License ("Public
108 | License"). To the extent this Public License may be interpreted as a
109 | contract, You are granted the Licensed Rights in consideration of Your
110 | acceptance of these terms and conditions, and the Licensor grants You
111 | such rights in consideration of benefits the Licensor receives from
112 | making the Licensed Material available under these terms and
113 | conditions.
114 |
115 |
116 | Section 1 -- Definitions.
117 |
118 | a. Adapted Material means material subject to Copyright and Similar
119 | Rights that is derived from or based upon the Licensed Material
120 | and in which the Licensed Material is translated, altered,
121 | arranged, transformed, or otherwise modified in a manner requiring
122 | permission under the Copyright and Similar Rights held by the
123 | Licensor. For purposes of this Public License, where the Licensed
124 | Material is a musical work, performance, or sound recording,
125 | Adapted Material is always produced where the Licensed Material is
126 | synched in timed relation with a moving image.
127 |
128 | b. Adapter's License means the license You apply to Your Copyright
129 | and Similar Rights in Your contributions to Adapted Material in
130 | accordance with the terms and conditions of this Public License.
131 |
132 | c. Copyright and Similar Rights means copyright and/or similar rights
133 | closely related to copyright including, without limitation,
134 | performance, broadcast, sound recording, and Sui Generis Database
135 | Rights, without regard to how the rights are labeled or
136 | categorized. For purposes of this Public License, the rights
137 | specified in Section 2(b)(1)-(2) are not Copyright and Similar
138 | Rights.
139 | d. Effective Technological Measures means those measures that, in the
140 | absence of proper authority, may not be circumvented under laws
141 | fulfilling obligations under Article 11 of the WIPO Copyright
142 | Treaty adopted on December 20, 1996, and/or similar international
143 | agreements.
144 |
145 | e. Exceptions and Limitations means fair use, fair dealing, and/or
146 | any other exception or limitation to Copyright and Similar Rights
147 | that applies to Your use of the Licensed Material.
148 |
149 | f. Licensed Material means the artistic or literary work, database,
150 | or other material to which the Licensor applied this Public
151 | License.
152 |
153 | g. Licensed Rights means the rights granted to You subject to the
154 | terms and conditions of this Public License, which are limited to
155 | all Copyright and Similar Rights that apply to Your use of the
156 | Licensed Material and that the Licensor has authority to license.
157 |
158 | h. Licensor means the individual(s) or entity(ies) granting rights
159 | under this Public License.
160 |
161 | i. NonCommercial means not primarily intended for or directed towards
162 | commercial advantage or monetary compensation. For purposes of
163 | this Public License, the exchange of the Licensed Material for
164 | other material subject to Copyright and Similar Rights by digital
165 | file-sharing or similar means is NonCommercial provided there is
166 | no payment of monetary compensation in connection with the
167 | exchange.
168 |
169 | j. Share means to provide material to the public by any means or
170 | process that requires permission under the Licensed Rights, such
171 | as reproduction, public display, public performance, distribution,
172 | dissemination, communication, or importation, and to make material
173 | available to the public including in ways that members of the
174 | public may access the material from a place and at a time
175 | individually chosen by them.
176 |
177 | k. Sui Generis Database Rights means rights other than copyright
178 | resulting from Directive 96/9/EC of the European Parliament and of
179 | the Council of 11 March 1996 on the legal protection of databases,
180 | as amended and/or succeeded, as well as other essentially
181 | equivalent rights anywhere in the world.
182 |
183 | l. You means the individual or entity exercising the Licensed Rights
184 | under this Public License. Your has a corresponding meaning.
185 |
186 |
187 | Section 2 -- Scope.
188 |
189 | a. License grant.
190 |
191 | 1. Subject to the terms and conditions of this Public License,
192 | the Licensor hereby grants You a worldwide, royalty-free,
193 | non-sublicensable, non-exclusive, irrevocable license to
194 | exercise the Licensed Rights in the Licensed Material to:
195 |
196 | a. reproduce and Share the Licensed Material, in whole or
197 | in part, for NonCommercial purposes only; and
198 |
199 | b. produce, reproduce, and Share Adapted Material for
200 | NonCommercial purposes only.
201 |
202 | 2. Exceptions and Limitations. For the avoidance of doubt, where
203 | Exceptions and Limitations apply to Your use, this Public
204 | License does not apply, and You do not need to comply with
205 | its terms and conditions.
206 |
207 | 3. Term. The term of this Public License is specified in Section
208 | 6(a).
209 |
210 | 4. Media and formats; technical modifications allowed. The
211 | Licensor authorizes You to exercise the Licensed Rights in
212 | all media and formats whether now known or hereafter created,
213 | and to make technical modifications necessary to do so. The
214 | Licensor waives and/or agrees not to assert any right or
215 | authority to forbid You from making technical modifications
216 | necessary to exercise the Licensed Rights, including
217 | technical modifications necessary to circumvent Effective
218 | Technological Measures. For purposes of this Public License,
219 | simply making modifications authorized by this Section 2(a)
220 | (4) never produces Adapted Material.
221 |
222 | 5. Downstream recipients.
223 |
224 | a. Offer from the Licensor -- Licensed Material. Every
225 | recipient of the Licensed Material automatically
226 | receives an offer from the Licensor to exercise the
227 | Licensed Rights under the terms and conditions of this
228 | Public License.
229 |
230 | b. No downstream restrictions. You may not offer or impose
231 | any additional or different terms or conditions on, or
232 | apply any Effective Technological Measures to, the
233 | Licensed Material if doing so restricts exercise of the
234 | Licensed Rights by any recipient of the Licensed
235 | Material.
236 |
237 | 6. No endorsement. Nothing in this Public License constitutes or
238 | may be construed as permission to assert or imply that You
239 | are, or that Your use of the Licensed Material is, connected
240 | with, or sponsored, endorsed, or granted official status by,
241 | the Licensor or others designated to receive attribution as
242 | provided in Section 3(a)(1)(A)(i).
243 |
244 | b. Other rights.
245 |
246 | 1. Moral rights, such as the right of integrity, are not
247 | licensed under this Public License, nor are publicity,
248 | privacy, and/or other similar personality rights; however, to
249 | the extent possible, the Licensor waives and/or agrees not to
250 | assert any such rights held by the Licensor to the limited
251 | extent necessary to allow You to exercise the Licensed
252 | Rights, but not otherwise.
253 |
254 | 2. Patent and trademark rights are not licensed under this
255 | Public License.
256 |
257 | 3. To the extent possible, the Licensor waives any right to
258 | collect royalties from You for the exercise of the Licensed
259 | Rights, whether directly or through a collecting society
260 | under any voluntary or waivable statutory or compulsory
261 | licensing scheme. In all other cases the Licensor expressly
262 | reserves any right to collect such royalties, including when
263 | the Licensed Material is used other than for NonCommercial
264 | purposes.
265 |
266 |
267 | Section 3 -- License Conditions.
268 |
269 | Your exercise of the Licensed Rights is expressly made subject to the
270 | following conditions.
271 |
272 | a. Attribution.
273 |
274 | 1. If You Share the Licensed Material (including in modified
275 | form), You must:
276 |
277 | a. retain the following if it is supplied by the Licensor
278 | with the Licensed Material:
279 |
280 | i. identification of the creator(s) of the Licensed
281 | Material and any others designated to receive
282 | attribution, in any reasonable manner requested by
283 | the Licensor (including by pseudonym if
284 | designated);
285 |
286 | ii. a copyright notice;
287 |
288 | iii. a notice that refers to this Public License;
289 |
290 | iv. a notice that refers to the disclaimer of
291 | warranties;
292 |
293 | v. a URI or hyperlink to the Licensed Material to the
294 | extent reasonably practicable;
295 |
296 | b. indicate if You modified the Licensed Material and
297 | retain an indication of any previous modifications; and
298 |
299 | c. indicate the Licensed Material is licensed under this
300 | Public License, and include the text of, or the URI or
301 | hyperlink to, this Public License.
302 |
303 | 2. You may satisfy the conditions in Section 3(a)(1) in any
304 | reasonable manner based on the medium, means, and context in
305 | which You Share the Licensed Material. For example, it may be
306 | reasonable to satisfy the conditions by providing a URI or
307 | hyperlink to a resource that includes the required
308 | information.
309 |
310 | 3. If requested by the Licensor, You must remove any of the
311 | information required by Section 3(a)(1)(A) to the extent
312 | reasonably practicable.
313 |
314 | 4. If You Share Adapted Material You produce, the Adapter's
315 | License You apply must not prevent recipients of the Adapted
316 | Material from complying with this Public License.
317 |
318 |
319 | Section 4 -- Sui Generis Database Rights.
320 |
321 | Where the Licensed Rights include Sui Generis Database Rights that
322 | apply to Your use of the Licensed Material:
323 |
324 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right
325 | to extract, reuse, reproduce, and Share all or a substantial
326 | portion of the contents of the database for NonCommercial purposes
327 | only;
328 |
329 | b. if You include all or a substantial portion of the database
330 | contents in a database in which You have Sui Generis Database
331 | Rights, then the database in which You have Sui Generis Database
332 | Rights (but not its individual contents) is Adapted Material; and
333 |
334 | c. You must comply with the conditions in Section 3(a) if You Share
335 | all or a substantial portion of the contents of the database.
336 |
337 | For the avoidance of doubt, this Section 4 supplements and does not
338 | replace Your obligations under this Public License where the Licensed
339 | Rights include other Copyright and Similar Rights.
340 |
341 |
342 | Section 5 -- Disclaimer of Warranties and Limitation of Liability.
343 |
344 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
345 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
346 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
347 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
348 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
349 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
350 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
351 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
352 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
353 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
354 |
355 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
356 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
357 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
358 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
359 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
360 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
361 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
362 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
363 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
364 |
365 | c. The disclaimer of warranties and limitation of liability provided
366 | above shall be interpreted in a manner that, to the extent
367 | possible, most closely approximates an absolute disclaimer and
368 | waiver of all liability.
369 |
370 |
371 | Section 6 -- Term and Termination.
372 |
373 | a. This Public License applies for the term of the Copyright and
374 | Similar Rights licensed here. However, if You fail to comply with
375 | this Public License, then Your rights under this Public License
376 | terminate automatically.
377 |
378 | b. Where Your right to use the Licensed Material has terminated under
379 | Section 6(a), it reinstates:
380 |
381 | 1. automatically as of the date the violation is cured, provided
382 | it is cured within 30 days of Your discovery of the
383 | violation; or
384 |
385 | 2. upon express reinstatement by the Licensor.
386 |
387 | For the avoidance of doubt, this Section 6(b) does not affect any
388 | right the Licensor may have to seek remedies for Your violations
389 | of this Public License.
390 |
391 | c. For the avoidance of doubt, the Licensor may also offer the
392 | Licensed Material under separate terms or conditions or stop
393 | distributing the Licensed Material at any time; however, doing so
394 | will not terminate this Public License.
395 |
396 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
397 | License.
398 |
399 |
400 | Section 7 -- Other Terms and Conditions.
401 |
402 | a. The Licensor shall not be bound by any additional or different
403 | terms or conditions communicated by You unless expressly agreed.
404 |
405 | b. Any arrangements, understandings, or agreements regarding the
406 | Licensed Material not stated herein are separate from and
407 | independent of the terms and conditions of this Public License.
408 |
409 |
410 | Section 8 -- Interpretation.
411 |
412 | a. For the avoidance of doubt, this Public License does not, and
413 | shall not be interpreted to, reduce, limit, restrict, or impose
414 | conditions on any use of the Licensed Material that could lawfully
415 | be made without permission under this Public License.
416 |
417 | b. To the extent possible, if any provision of this Public License is
418 | deemed unenforceable, it shall be automatically reformed to the
419 | minimum extent necessary to make it enforceable. If the provision
420 | cannot be reformed, it shall be severed from this Public License
421 | without affecting the enforceability of the remaining terms and
422 | conditions.
423 |
424 | c. No term or condition of this Public License will be waived and no
425 | failure to comply consented to unless expressly agreed to by the
426 | Licensor.
427 |
428 | d. Nothing in this Public License constitutes or may be interpreted
429 | as a limitation upon, or waiver of, any privileges and immunities
430 | that apply to the Licensor or You, including from the legal
431 | processes of any jurisdiction or authority.
432 |
433 | =======================================================================
434 |
435 | Creative Commons is not a party to its public
436 | licenses. Notwithstanding, Creative Commons may elect to apply one of
437 | its public licenses to material it publishes and in those instances
438 | will be considered the "Licensor." The text of the Creative Commons
439 | public licenses is dedicated to the public domain under the CC0 Public
440 | Domain Dedication. Except for the limited purpose of indicating that
441 | material is shared under a Creative Commons public license or as
442 | otherwise permitted by the Creative Commons policies published at
443 | creativecommons.org/policies, Creative Commons does not authorize the
444 | use of the trademark "Creative Commons" or any other trademark or logo
445 | of Creative Commons without its prior written consent including,
446 | without limitation, in connection with any unauthorized modifications
447 | to any of its public licenses or any other arrangements,
448 | understandings, or agreements concerning use of licensed material. For
449 | the avoidance of doubt, this paragraph does not form part of the
450 | public licenses.
451 |
452 | Creative Commons may be contacted at creativecommons.org.
--------------------------------------------------------------------------------
/careless_whisper_stream/normalizers/english.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import re
4 | from fractions import Fraction
5 | from typing import Iterator, List, Match, Optional, Union
6 |
7 | from more_itertools import windowed
8 |
9 | from .basic import remove_symbols_and_diacritics
10 |
11 |
12 | class EnglishNumberNormalizer:
13 | """
14 |     Convert any spelled-out numbers into Arabic numerals, while:
15 |
16 |     - removing any commas
17 |     - keeping suffixes such as `1960s`, `274th`, `32nd`, etc.
18 |     - spelling out currency symbols after the number, e.g. `$20 million` -> `20000000 dollars`
19 |     - spelling out `one` and `ones`
20 |     - interpreting successive single-digit numbers as nominal, e.g. `one oh one` -> `101`
21 | """
22 |
23 | def __init__(self):
24 | super().__init__()
25 |
26 | self.zeros = {"o", "oh", "zero"}
27 | self.ones = {
28 | name: i
29 | for i, name in enumerate(
30 | [
31 | "one",
32 | "two",
33 | "three",
34 | "four",
35 | "five",
36 | "six",
37 | "seven",
38 | "eight",
39 | "nine",
40 | "ten",
41 | "eleven",
42 | "twelve",
43 | "thirteen",
44 | "fourteen",
45 | "fifteen",
46 | "sixteen",
47 | "seventeen",
48 | "eighteen",
49 | "nineteen",
50 | ],
51 | start=1,
52 | )
53 | }
54 | self.ones_plural = {
55 | "sixes" if name == "six" else name + "s": (value, "s")
56 | for name, value in self.ones.items()
57 | }
58 | self.ones_ordinal = {
59 | "zeroth": (0, "th"),
60 | "first": (1, "st"),
61 | "second": (2, "nd"),
62 | "third": (3, "rd"),
63 | "fifth": (5, "th"),
64 | "twelfth": (12, "th"),
65 | **{
66 | name + ("h" if name.endswith("t") else "th"): (value, "th")
67 | for name, value in self.ones.items()
68 | if value > 3 and value != 5 and value != 12
69 | },
70 | }
71 | self.ones_suffixed = {**self.ones_plural, **self.ones_ordinal}
72 |
73 | self.tens = {
74 | "twenty": 20,
75 | "thirty": 30,
76 | "forty": 40,
77 | "fifty": 50,
78 | "sixty": 60,
79 | "seventy": 70,
80 | "eighty": 80,
81 | "ninety": 90,
82 | }
83 | self.tens_plural = {
84 | name.replace("y", "ies"): (value, "s") for name, value in self.tens.items()
85 | }
86 | self.tens_ordinal = {
87 | name.replace("y", "ieth"): (value, "th")
88 | for name, value in self.tens.items()
89 | }
90 | self.tens_suffixed = {**self.tens_plural, **self.tens_ordinal}
91 |
92 | self.multipliers = {
93 | "hundred": 100,
94 | "thousand": 1_000,
95 | "million": 1_000_000,
96 | "billion": 1_000_000_000,
97 | "trillion": 1_000_000_000_000,
98 | "quadrillion": 1_000_000_000_000_000,
99 | "quintillion": 1_000_000_000_000_000_000,
100 | "sextillion": 1_000_000_000_000_000_000_000,
101 | "septillion": 1_000_000_000_000_000_000_000_000,
102 | "octillion": 1_000_000_000_000_000_000_000_000_000,
103 | "nonillion": 1_000_000_000_000_000_000_000_000_000_000,
104 | "decillion": 1_000_000_000_000_000_000_000_000_000_000_000,
105 | }
106 | self.multipliers_plural = {
107 | name + "s": (value, "s") for name, value in self.multipliers.items()
108 | }
109 | self.multipliers_ordinal = {
110 | name + "th": (value, "th") for name, value in self.multipliers.items()
111 | }
112 | self.multipliers_suffixed = {
113 | **self.multipliers_plural,
114 | **self.multipliers_ordinal,
115 | }
116 | self.decimals = {*self.ones, *self.tens, *self.zeros}
117 |
118 | self.preceding_prefixers = {
119 | "minus": "-",
120 | "negative": "-",
121 | "plus": "+",
122 | "positive": "+",
123 | }
124 | self.following_prefixers = {
125 | "pound": "£",
126 | "pounds": "£",
127 | "euro": "€",
128 | "euros": "€",
129 | "dollar": "$",
130 | "dollars": "$",
131 | "cent": "¢",
132 | "cents": "¢",
133 | }
134 | self.prefixes = set(
135 | list(self.preceding_prefixers.values())
136 | + list(self.following_prefixers.values())
137 | )
138 | self.suffixers = {
139 | "per": {"cent": "%"},
140 | "percent": "%",
141 | }
142 | self.specials = {"and", "double", "triple", "point"}
143 |
144 | self.words = set(
145 | [
146 | key
147 | for mapping in [
148 | self.zeros,
149 | self.ones,
150 | self.ones_suffixed,
151 | self.tens,
152 | self.tens_suffixed,
153 | self.multipliers,
154 | self.multipliers_suffixed,
155 | self.preceding_prefixers,
156 | self.following_prefixers,
157 | self.suffixers,
158 | self.specials,
159 | ]
160 | for key in mapping
161 | ]
162 | )
163 | self.literal_words = {"one", "ones"}
164 |
165 | def process_words(self, words: List[str]) -> Iterator[str]:
166 | prefix: Optional[str] = None
167 | value: Optional[Union[str, int]] = None
168 | skip = False
169 |
170 | def to_fraction(s: str):
171 | try:
172 | return Fraction(s)
173 | except ValueError:
174 | return None
175 |
176 | def output(result: Union[str, int]):
177 | nonlocal prefix, value
178 | result = str(result)
179 | if prefix is not None:
180 | result = prefix + result
181 | value = None
182 | prefix = None
183 | return result
184 |
185 | if len(words) == 0:
186 | return
187 |
188 | for prev, current, next in windowed([None] + words + [None], 3):
189 | if skip:
190 | skip = False
191 | continue
192 |
193 | next_is_numeric = next is not None and re.match(r"^\d+(\.\d+)?$", next)
194 | has_prefix = current[0] in self.prefixes
195 | current_without_prefix = current[1:] if has_prefix else current
196 | if re.match(r"^\d+(\.\d+)?$", current_without_prefix):
197 |                 # Arabic numerals (potentially with signs and fractions)
198 | f = to_fraction(current_without_prefix)
199 | assert f is not None
200 | if value is not None:
201 | if isinstance(value, str) and value.endswith("."):
202 | # concatenate decimals / ip address components
203 | value = str(value) + str(current)
204 | continue
205 | else:
206 | yield output(value)
207 |
208 | prefix = current[0] if has_prefix else prefix
209 | if f.denominator == 1:
210 | value = f.numerator # store integers as int
211 | else:
212 | value = current_without_prefix
213 | elif current not in self.words:
214 | # non-numeric words
215 | if value is not None:
216 | yield output(value)
217 | yield output(current)
218 | elif current in self.zeros:
219 | value = str(value or "") + "0"
220 | elif current in self.ones:
221 | ones = self.ones[current]
222 |
223 | if value is None:
224 | value = ones
225 | elif isinstance(value, str) or prev in self.ones:
226 | if (
227 | prev in self.tens and ones < 10
228 | ): # replace the last zero with the digit
229 | assert value[-1] == "0"
230 | value = value[:-1] + str(ones)
231 | else:
232 | value = str(value) + str(ones)
233 | elif ones < 10:
234 | if value % 10 == 0:
235 | value += ones
236 | else:
237 | value = str(value) + str(ones)
238 | else: # eleven to nineteen
239 | if value % 100 == 0:
240 | value += ones
241 | else:
242 | value = str(value) + str(ones)
243 | elif current in self.ones_suffixed:
244 | # ordinal or cardinal; yield the number right away
245 | ones, suffix = self.ones_suffixed[current]
246 | if value is None:
247 | yield output(str(ones) + suffix)
248 | elif isinstance(value, str) or prev in self.ones:
249 | if prev in self.tens and ones < 10:
250 | assert value[-1] == "0"
251 | yield output(value[:-1] + str(ones) + suffix)
252 | else:
253 | yield output(str(value) + str(ones) + suffix)
254 | elif ones < 10:
255 | if value % 10 == 0:
256 | yield output(str(value + ones) + suffix)
257 | else:
258 | yield output(str(value) + str(ones) + suffix)
259 | else: # eleven to nineteen
260 | if value % 100 == 0:
261 | yield output(str(value + ones) + suffix)
262 | else:
263 | yield output(str(value) + str(ones) + suffix)
264 | value = None
265 | elif current in self.tens:
266 | tens = self.tens[current]
267 | if value is None:
268 | value = tens
269 | elif isinstance(value, str):
270 | value = str(value) + str(tens)
271 | else:
272 | if value % 100 == 0:
273 | value += tens
274 | else:
275 | value = str(value) + str(tens)
276 | elif current in self.tens_suffixed:
277 | # ordinal or cardinal; yield the number right away
278 | tens, suffix = self.tens_suffixed[current]
279 | if value is None:
280 | yield output(str(tens) + suffix)
281 | elif isinstance(value, str):
282 | yield output(str(value) + str(tens) + suffix)
283 | else:
284 | if value % 100 == 0:
285 | yield output(str(value + tens) + suffix)
286 | else:
287 | yield output(str(value) + str(tens) + suffix)
288 | elif current in self.multipliers:
289 | multiplier = self.multipliers[current]
290 | if value is None:
291 | value = multiplier
292 | elif isinstance(value, str) or value == 0:
293 | f = to_fraction(value)
294 | p = f * multiplier if f is not None else None
295 | if f is not None and p.denominator == 1:
296 | value = p.numerator
297 | else:
298 | yield output(value)
299 | value = multiplier
300 | else:
301 | before = value // 1000 * 1000
302 | residual = value % 1000
303 | value = before + residual * multiplier
304 | elif current in self.multipliers_suffixed:
305 | multiplier, suffix = self.multipliers_suffixed[current]
306 | if value is None:
307 | yield output(str(multiplier) + suffix)
308 | elif isinstance(value, str):
309 | f = to_fraction(value)
310 | p = f * multiplier if f is not None else None
311 | if f is not None and p.denominator == 1:
312 | yield output(str(p.numerator) + suffix)
313 | else:
314 | yield output(value)
315 | yield output(str(multiplier) + suffix)
316 | else: # int
317 | before = value // 1000 * 1000
318 | residual = value % 1000
319 | value = before + residual * multiplier
320 | yield output(str(value) + suffix)
321 | value = None
322 | elif current in self.preceding_prefixers:
323 | # apply prefix (positive, minus, etc.) if it precedes a number
324 | if value is not None:
325 | yield output(value)
326 |
327 | if next in self.words or next_is_numeric:
328 | prefix = self.preceding_prefixers[current]
329 | else:
330 | yield output(current)
331 | elif current in self.following_prefixers:
332 | # apply prefix (dollars, cents, etc.) only after a number
333 | if value is not None:
334 | prefix = self.following_prefixers[current]
335 | yield output(value)
336 | else:
337 | yield output(current)
338 | elif current in self.suffixers:
339 | # apply suffix symbols (percent -> '%')
340 | if value is not None:
341 | suffix = self.suffixers[current]
342 | if isinstance(suffix, dict):
343 | if next in suffix:
344 | yield output(str(value) + suffix[next])
345 | skip = True
346 | else:
347 | yield output(value)
348 | yield output(current)
349 | else:
350 | yield output(str(value) + suffix)
351 | else:
352 | yield output(current)
353 | elif current in self.specials:
354 | if next not in self.words and not next_is_numeric:
355 | # apply special handling only if the next word can be numeric
356 | if value is not None:
357 | yield output(value)
358 | yield output(current)
359 | elif current == "and":
360 | # ignore "and" after hundreds, thousands, etc.
361 | if prev not in self.multipliers:
362 | if value is not None:
363 | yield output(value)
364 | yield output(current)
365 | elif current == "double" or current == "triple":
366 | if next in self.ones or next in self.zeros:
367 | repeats = 2 if current == "double" else 3
368 | ones = self.ones.get(next, 0)
369 | value = str(value or "") + str(ones) * repeats
370 | skip = True
371 | else:
372 | if value is not None:
373 | yield output(value)
374 | yield output(current)
375 | elif current == "point":
376 | if next in self.decimals or next_is_numeric:
377 | value = str(value or "") + "."
378 | else:
379 | # should all have been covered at this point
380 | raise ValueError(f"Unexpected token: {current}")
381 | else:
382 | # all should have been covered at this point
383 | raise ValueError(f"Unexpected token: {current}")
384 |
385 | if value is not None:
386 | yield output(value)
387 |
388 | def preprocess(self, s: str):
389 | # replace " and a half" with " point five"
390 | results = []
391 |
392 | segments = re.split(r"\band\s+a\s+half\b", s)
393 | for i, segment in enumerate(segments):
394 | if len(segment.strip()) == 0:
395 | continue
396 | if i == len(segments) - 1:
397 | results.append(segment)
398 | else:
399 | results.append(segment)
400 | last_word = segment.rsplit(maxsplit=2)[-1]
401 | if last_word in self.decimals or last_word in self.multipliers:
402 | results.append("point five")
403 | else:
404 | results.append("and a half")
405 |
406 | s = " ".join(results)
407 |
408 | # put a space at number/letter boundary
409 | s = re.sub(r"([a-z])([0-9])", r"\1 \2", s)
410 | s = re.sub(r"([0-9])([a-z])", r"\1 \2", s)
411 |
412 |         # but remove the space before ordinal/plural suffixes (st, nd, rd, th, s) so they attach to the number
413 | s = re.sub(r"([0-9])\s+(st|nd|rd|th|s)\b", r"\1\2", s)
414 |
415 | return s
416 |
417 | def postprocess(self, s: str):
418 | def combine_cents(m: Match):
419 | try:
420 | currency = m.group(1)
421 | integer = m.group(2)
422 | cents = int(m.group(3))
423 | return f"{currency}{integer}.{cents:02d}"
424 | except ValueError:
425 | return m.string
426 |
427 | def extract_cents(m: Match):
428 | try:
429 | return f"¢{int(m.group(1))}"
430 | except ValueError:
431 | return m.string
432 |
433 | # apply currency postprocessing; "$2 and ¢7" -> "$2.07"
434 | s = re.sub(r"([€£$])([0-9]+) (?:and )?¢([0-9]{1,2})\b", combine_cents, s)
435 | s = re.sub(r"[€£$]0.([0-9]{1,2})\b", extract_cents, s)
436 |
437 | # write "one(s)" instead of "1(s)", just for the readability
438 | s = re.sub(r"\b1(s?)\b", r"one\1", s)
439 |
440 | return s
441 |
442 | def __call__(self, s: str):
443 | s = self.preprocess(s)
444 | s = " ".join(word for word in self.process_words(s.split()) if word is not None)
445 | s = self.postprocess(s)
446 |
447 | return s
448 |
449 |
450 | class EnglishSpellingNormalizer:
451 | """
452 | Applies British-American spelling mappings as listed in [1].
453 |
454 | [1] https://www.tysto.com/uk-us-spelling-list.html
455 | """
456 |
457 | def __init__(self):
458 | mapping_path = os.path.join(os.path.dirname(__file__), "english.json")
459 | self.mapping = json.load(open(mapping_path))
460 |
461 | def __call__(self, s: str):
462 | return " ".join(self.mapping.get(word, word) for word in s.split())
463 |
464 |
465 | class EnglishTextNormalizer:
466 | def __init__(self):
467 | self.ignore_patterns = r"\b(hmm|mm|mhm|mmm|uh|um)\b"
468 | self.replacers = {
469 | # common contractions
470 | r"\bwon't\b": "will not",
471 | r"\bcan't\b": "can not",
472 | r"\blet's\b": "let us",
473 | r"\bain't\b": "aint",
474 | r"\by'all\b": "you all",
475 | r"\bwanna\b": "want to",
476 | r"\bgotta\b": "got to",
477 | r"\bgonna\b": "going to",
478 | r"\bi'ma\b": "i am going to",
479 | r"\bimma\b": "i am going to",
480 | r"\bwoulda\b": "would have",
481 | r"\bcoulda\b": "could have",
482 | r"\bshoulda\b": "should have",
483 | r"\bma'am\b": "madam",
484 |             # abbreviated titles/prefixes
485 | r"\bmr\b": "mister ",
486 | r"\bmrs\b": "missus ",
487 | r"\bst\b": "saint ",
488 | r"\bdr\b": "doctor ",
489 | r"\bprof\b": "professor ",
490 | r"\bcapt\b": "captain ",
491 | r"\bgov\b": "governor ",
492 | r"\bald\b": "alderman ",
493 | r"\bgen\b": "general ",
494 | r"\bsen\b": "senator ",
495 | r"\brep\b": "representative ",
496 | r"\bpres\b": "president ",
497 | r"\brev\b": "reverend ",
498 | r"\bhon\b": "honorable ",
499 | r"\basst\b": "assistant ",
500 | r"\bassoc\b": "associate ",
501 | r"\blt\b": "lieutenant ",
502 | r"\bcol\b": "colonel ",
503 | r"\bjr\b": "junior ",
504 | r"\bsr\b": "senior ",
505 | r"\besq\b": "esquire ",
506 |             # perfect tenses; ideally this would cover any past participle, but that is harder
507 | r"'d been\b": " had been",
508 | r"'s been\b": " has been",
509 | r"'d gone\b": " had gone",
510 | r"'s gone\b": " has gone",
511 | r"'d done\b": " had done", # "'s done" is ambiguous
512 | r"'s got\b": " has got",
513 | # general contractions
514 | r"n't\b": " not",
515 | r"'re\b": " are",
516 | r"'s\b": " is",
517 | r"'d\b": " would",
518 | r"'ll\b": " will",
519 | r"'t\b": " not",
520 | r"'ve\b": " have",
521 | r"'m\b": " am",
522 | }
523 | self.standardize_numbers = EnglishNumberNormalizer()
524 | self.standardize_spellings = EnglishSpellingNormalizer()
525 |
526 | def __call__(self, s: str):
527 | s = s.lower()
528 |
529 | s = re.sub(r"[<\[][^>\]]*[>\]]", "", s) # remove words between brackets
530 | s = re.sub(r"\(([^)]+?)\)", "", s) # remove words between parenthesis
531 | s = re.sub(self.ignore_patterns, "", s)
532 | s = re.sub(r"\s+'", "'", s) # when there's a space before an apostrophe
533 |
534 | for pattern, replacement in self.replacers.items():
535 | s = re.sub(pattern, replacement, s)
536 |
537 | s = re.sub(r"(\d),(\d)", r"\1\2", s) # remove commas between digits
538 | s = re.sub(r"\.([^0-9]|$)", r" \1", s) # remove periods not followed by numbers
539 | s = remove_symbols_and_diacritics(s, keep=".%$¢€£") # keep numeric symbols
540 |
541 | s = self.standardize_numbers(s)
542 | s = self.standardize_spellings(s)
543 |
544 | # now remove prefix/suffix symbols that are not preceded/followed by numbers
545 | s = re.sub(r"[.$¢€£]([^0-9])", r" \1", s)
546 | s = re.sub(r"([^0-9])%", r"\1 ", s)
547 |
548 | s = re.sub(r"\s+", " ", s) # replace any successive whitespaces with a space
549 |
550 | return s
551 |
--------------------------------------------------------------------------------
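A minimal usage sketch for the normalizers above, assuming the careless_whisper_stream package and its dependencies are importable from the repository root and that the english.json table loaded by EnglishSpellingNormalizer is available inside the normalizers package; the comments restate behaviour described in the docstrings and inline comments above rather than verified output.

import sys
sys.path.append("./")  # assumption: run from the repository root

from careless_whisper_stream.normalizers import EnglishTextNormalizer

normalizer = EnglishTextNormalizer()

# Filler words matched by ignore_patterns ("um", "uh", ...) are dropped,
# abbreviated titles such as "mr" are expanded, and spelled-out numbers are
# converted to digits by EnglishNumberNormalizer (its docstring cites
# "one oh one" -> "101" and the kept suffix form "1960s").
print(normalizer("Um, Mr. Smith moved into flat one oh one in the nineteen sixties."))

# Currency words become symbols placed before the number (following_prefixers),
# and postprocess() merges dollars and cents, per its comment "$2 and ¢7" -> "$2.07".
print(normalizer("It cost two dollars and seven cents."))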