├── .gitignore ├── LICENSE ├── MANIFEST.in ├── README.md ├── examples.py ├── requirements.txt ├── setup.py ├── tests ├── __init__.py ├── test_examples.py ├── test_slicemap.py ├── test_text_normalization.py └── test_whisper_asr.py └── transcription_diff ├── __init__.py ├── find_lang_match.py ├── number_normalization.py ├── slice_map.py ├── text_diff.py ├── text_normalization.py └── whisper_asr.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | *.pyc 3 | .vscode 4 | .env.dev 5 | 6 | # Byte-compiled / optimized / DLL files 7 | __pycache__/ 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Environments 39 | .env 40 | .venv 41 | env/ 42 | venv/ 43 | ENV/ 44 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Corentin Jemine 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | prune tests 2 | include LICENSE 3 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # transcription-diff 2 | A small python library to find differences between audio and transcriptions 3 | 4 | Example (audio as mp4 to allow an embed): 5 | 6 | https://github.com/CorentinJ/transcription-diff/assets/12038136/41fda0a8-92bb-46fe-a7b7-903ccfed3463 7 | 8 | ```python 9 | from transcription_diff.text_diff import transcription_diff, render_text_diff 10 | 11 | diff = transcription_diff("You can go pretty far in life if you're a perfect sphere in a vacuum", "sphere.mp4") 12 | print(render_text_diff(diff)) 13 | ``` 14 | 15 | ```diff 16 | ! Well 17 | You can go pretty far in life 18 | ! when 19 | + if 20 | you're a perfect sphere in a vacuum 21 | ``` 22 | 23 | ### Mechanism 24 | - The library relies on [openai-whisper](https://github.com/openai/whisper) to perform Automatic Speech Recognition unguided by the transcription 25 | - It then compares the expected transcription to the output of Whisper, ignoring superfluous characters 26 | - It returns the output in a simple structure, keeping the original text format of the transcription 27 | 28 | ### Limitations 29 | - Only a single hypothesis is considered for the ASR output, leaving the possibility of missing a hypothesis that would satisfy the expected transcription 30 | - The ASR output is not in the phoneme space, making homophones prone to false positives 31 | - Rare words unknown to Whisper need to be explicitly passed to the function, and have no guarantee of being properly recognized by Whisper 32 | - Currently only annotates up to 30 seconds of audio per sample 33 | 34 | ## Installation 35 | `pip install 
transcription-diff` 36 | 37 | ## Short term TODOs 38 | - [ ] Phoneme-level comparison 39 | - [ ] User handling of model cache 40 | - [ ] Support for audios longer than 30s 41 | 42 | ## Long shot TODOs 43 | - [ ] More robust support for non-English languages 44 | - [ ] Inverse normalization support for fewer false positives 45 | -------------------------------------------------------------------------------- /examples.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import sys 3 | 4 | import librosa 5 | import torch 6 | 7 | from transcription_diff.text_diff import transcription_diff, render_text_diff 8 | 9 | 10 | logging.basicConfig(level="INFO", stream=sys.stdout) 11 | DEVICE = "cuda" if torch.cuda.is_available() else "cpu" 12 | 13 | 14 | audio_fpath = librosa.example("libri2") 15 | wav, sr = librosa.core.load(audio_fpath) 16 | 17 | # We'll keep only a short audio to keep the demo concise 18 | # N.B.: without this step we could feed the audio file path directly to transcription_diff() 19 | cut_range = 2.5, 9.5 20 | wav = wav[int(sr * cut_range[0]):int(sr * cut_range[1])] 21 | correct_transcription = \ 22 | "It befell in the month of May, Queen Guenever called her knights of the Table Round and gave them warning." 
23 | 24 | # # You can listen to the audio using this package, or by playing the file at 25 | # import sounddevice as sd 26 | # sd.play(wav, sr, blocking=True) 27 | 28 | 29 | # Running with all default parameters 30 | diff = transcription_diff(correct_transcription, wav, sr, device=DEVICE) 31 | print(render_text_diff(diff)) 32 | 33 | # Providing hints to custom words to whisper has a chance to make it transcribe that word 34 | diff = transcription_diff(correct_transcription, wav, sr, custom_words=["Guenever"], device=DEVICE) 35 | print(render_text_diff(diff)) 36 | 37 | # Increase the model size generally increases ASR accuracy 38 | diff = transcription_diff( 39 | correct_transcription, wav, sr, custom_words=["Guenever"], whisper_model_size=3, device=DEVICE 40 | ) 41 | print(render_text_diff(diff)) 42 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | minineedle>=3.0.0 2 | numpy>=1.18.0 3 | openai-whisper>=20230918 4 | librosa>=0.9.0 5 | langcodes>=3.0.0 6 | colorama>=0.4.3 7 | torch>=2.0.0 8 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import re 2 | from pathlib import Path 3 | 4 | from setuptools import setup, find_packages 5 | 6 | 7 | setup( 8 | name="transcription-diff", 9 | version=re.search(r"__version__\s+=\s+\"(.*)\"", Path("transcription_diff/__init__.py").read_text()).group(1), 10 | description="Speech to transcription comparison", 11 | long_description="A small python library to find differences between audio and transcriptions\n" 12 | "https://github.com/CorentinJ/transcription-diff/", 13 | author="Corentin Jemine", 14 | author_email="corentin.jemine@gmail.com", 15 | packages=find_packages(), 16 | platforms="any", 17 | python_requires=">=3.7", 18 | 
install_requires=Path("requirements.txt").read_text("utf-8").splitlines(), 19 | tests_require=["pytest>=7.0.0"], 20 | long_description_content_type="text/markdown", 21 | url="https://github.com/CorentinJ/transcription-diff", 22 | classifiers=[ 23 | "Development Status :: 4 - Beta", 24 | "License :: OSI Approved :: MIT License", 25 | "Operating System :: OS Independent", 26 | "Programming Language :: Python :: 3 :: Only", 27 | "Topic :: Multimedia :: Sound/Audio :: Speech", 28 | ], 29 | ) 30 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CorentinJ/transcription-diff/868f3c12f38b9446474ba79f570eb267ac48213f/tests/__init__.py -------------------------------------------------------------------------------- /tests/test_examples.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import sys 3 | from pathlib import Path 4 | 5 | 6 | _ROOT_DIR = Path(__file__).parent.parent 7 | _EXAMPLES_FPATH = _ROOT_DIR / "examples.py" 8 | 9 | 10 | def test_run_examples(): 11 | """ 12 | Runs the examples script and verifies it does not crash or print anything to stderr. 
13 | """ 14 | result = subprocess.run([sys.executable, _EXAMPLES_FPATH], capture_output=True, text=True) 15 | assert result.returncode == 0, result.stdout 16 | assert result.stderr == "", result.stdout 17 | -------------------------------------------------------------------------------- /tests/test_slicemap.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import numpy as np 4 | 5 | from transcription_diff.slice_map import SliceMap 6 | 7 | 8 | class TestSliceMap(unittest.TestCase): 9 | def test_constructor(self): 10 | # Slices cannot have negative indices and cannot index beyond the size of Y 11 | with self.assertRaises(AssertionError): 12 | SliceMap([slice(0, -1)], 10) 13 | with self.assertRaises(AssertionError): 14 | SliceMap([slice(0, 1), slice(1, 2)], 1) 15 | 16 | # The slice starts cannot decrease, the same goes for the slice stops. You can however have consecutive 17 | # overlapping slices, e.g. [slice(0, 2), slice(0, 2)] 18 | with self.assertRaises(AssertionError): 19 | SliceMap([slice(1, 1), slice(0, 2)], 2) 20 | with self.assertRaises(AssertionError): 21 | SliceMap([slice(0, 2), slice(0, 1)], 2) 22 | SliceMap([slice(0, 2), slice(0, 2)], 2) 23 | 24 | # Slices can be empty (stop <= start) 25 | SliceMap([slice(2, 1)], 2) 26 | 27 | # Map to nothing 28 | SliceMap([slice(0, 0), slice(0, 0)], 0) 29 | 30 | # Map from nothing 31 | SliceMap([], 2) 32 | 33 | def test_getitem(self): 34 | # X[i] mapping to Y[smap[i]] 35 | m = SliceMap([slice(0, 1), slice(1, 2)], 2) 36 | self.assertEqual(m[1], slice(1, 2)) 37 | 38 | # X[i:j] mapping to Y[smap[i].start:smap[j - 1].stop] 39 | m = SliceMap([slice(0, 1), slice(0, 2), slice(3, 3)], 3) 40 | self.assertEqual(m[:1], slice(0, 1)) 41 | self.assertEqual(m[:2], slice(0, 2)) 42 | self.assertEqual(m[1:2], slice(0, 2)) 43 | self.assertEqual(m[1:3], slice(0, 3)) 44 | 45 | # Other tests 46 | m = SliceMap([slice(0, 1), slice(1, 1), slice(1, 2), slice(3, 4)], 5) 47 | seq = 
"abcd" 48 | 49 | # Step != 1 is not supported 50 | with self.assertRaises(AssertionError): 51 | m[1:2:2] 52 | 53 | # Index with an int is equivalent to slice of size 1 54 | for i in range(m.source_len): 55 | self.assertEqual(m[i], m[i:i + 1]) 56 | 57 | # Index that maps to nothing 58 | self.assertEqual(seq[m[1]], "") 59 | 60 | # Slicing includes gaps 61 | self.assertEqual(seq[m[2:]], "bcd") 62 | 63 | # Empty slices 64 | for i in range(m.source_len + 1): 65 | self.assertEqual(seq[m[i:i]], "") 66 | 67 | # Slices beyond the size of the map 68 | self.assertEqual(seq[m[10:20]], "") 69 | 70 | def test_project(self): 71 | # Length must match 72 | m = SliceMap([slice(0, 1), slice(1, 2), slice(2, 3)], 3) 73 | with self.assertRaises(AssertionError): 74 | m.project(list("ab")) 75 | with self.assertRaises(AssertionError): 76 | m.project(list("abcd")) 77 | 78 | # Empty 79 | m = SliceMap.empty() 80 | self.assertEqual(m.project([]), []) 81 | 82 | # Map to nothing 83 | m = SliceMap([slice(0, 0), slice(0, 0)], 0) 84 | self.assertEqual(m.project(list("ab")), []) 85 | 86 | # Map from nothing 87 | m = SliceMap([], 2) 88 | self.assertEqual(m.project([], "?"), list("??")) 89 | 90 | # Identity 91 | m = SliceMap([slice(0, 1), slice(1, 2), slice(2, 3)], 3) 92 | self.assertEqual(m.project(list("abc")), list("abc")) 93 | 94 | # Spanning multiple indices 95 | m = SliceMap([slice(0, 1), slice(1, 3), slice(3, 4)], 4) 96 | self.assertEqual(m.project(list("abc")), list("abbc")) 97 | 98 | # Gap in the target space 99 | m = SliceMap([slice(0, 1), slice(1, 1), slice(1, 2)], 2) 100 | self.assertEqual(m.project(list("abc")), list("ac")) 101 | 102 | # Gap in the source space 103 | m = SliceMap([slice(0, 1), slice(2, 3)], 3) 104 | self.assertEqual(m.project(list("ab"), "?"), list("a?b")) 105 | 106 | # Overlap 107 | m = SliceMap([slice(0, 1), slice(0, 1)], 1) 108 | self.assertEqual(m.project(list("ab")), list("b")) 109 | 110 | # Overlap that spans multiple characters 111 | m = SliceMap([slice(0, 2), 
slice(0, 2)], 2) 112 | self.assertEqual(m.project(list("ab")), list("bb")) 113 | 114 | # Overlap with step 115 | m = SliceMap([slice(0, 2), slice(1, 3)], 3) 116 | self.assertEqual(m.project(list("ab")), list("abb")) 117 | 118 | # Composition of the above 119 | m = SliceMap([slice(0, 1), slice(1, 1), slice(1, 2), slice(3, 5), slice(4, 6)], 7) 120 | self.assertEqual(m.project(list("abcde"), "?"), list("ac?dee?")) 121 | 122 | def test_inverse(self): 123 | # Empty 124 | self.assertEqual(SliceMap.empty(), SliceMap.empty().inverse()) 125 | 126 | # Map to nothing 127 | m = SliceMap([slice(0, 0), slice(0, 0)], 0) 128 | self.assertEqual(m, m.inverse().inverse()) 129 | 130 | # Map from nothing 131 | m = SliceMap([], 2) 132 | self.assertEqual(m, m.inverse().inverse()) 133 | 134 | # Identity 135 | m = SliceMap([slice(0, 1), slice(1, 2), slice(2, 3)], 3) 136 | self.assertEqual(m.inverse().project(list("abc")), list("abc")) 137 | self.assertEqual(m, m.inverse().inverse()) 138 | 139 | # Spanning multiple indices 140 | m = SliceMap([slice(0, 1), slice(1, 3), slice(3, 4)], 4) 141 | self.assertEqual(m.inverse().project(list("abbc")), list("abc")) 142 | self.assertEqual(m, m.inverse().inverse()) 143 | 144 | # Gap in the target space 145 | m = SliceMap([slice(0, 1), slice(1, 1), slice(1, 2)], 2) 146 | self.assertEqual(m.inverse().project(list("ac"), "?"), list("a?c")) 147 | self.assertEqual(m, m.inverse().inverse()) 148 | 149 | # Gap in the source space 150 | m = SliceMap([slice(0, 1), slice(2, 3)], 3) 151 | self.assertEqual(m.inverse().project(list("a?b")), list("ab")) 152 | self.assertEqual(m, m.inverse().inverse()) 153 | 154 | # Composition of the above 155 | m = SliceMap([slice(0, 1), slice(1, 1), slice(1, 2), slice(3, 4)], 5) 156 | self.assertEqual(m.inverse().project(list("ac?d?"), "?"), list("a?cd")) 157 | self.assertEqual(m, m.inverse().inverse()) 158 | 159 | # Overlap 160 | m = SliceMap([slice(0, 1), slice(0, 1)], 1) 161 | self.assertEqual(m.inverse().project(list("a")), 
list("aa")) 162 | self.assertEqual(m.inverse(), SliceMap([slice(0, 2)], 2)) 163 | self.assertEqual(m, m.inverse().inverse()) 164 | 165 | # Overlap that spans multiple characters 166 | m = SliceMap([slice(0, 2), slice(0, 2)], 2) 167 | self.assertEqual(m.inverse().project(list("ab")), list("bb")) 168 | self.assertEqual(m.inverse(), SliceMap([slice(0, 2), slice(0, 2)], 2)) 169 | self.assertEqual(m, m.inverse().inverse()) 170 | 171 | # Overlap with step 172 | m = SliceMap([slice(0, 2), slice(1, 3)], 3) 173 | self.assertEqual(m.inverse().project(list("abc")), list("bc")) 174 | self.assertEqual(m.inverse(), SliceMap([slice(0, 1), slice(0, 2), slice(1, 2)], 2)) 175 | self.assertEqual(m, m.inverse().inverse()) 176 | 177 | # Composition of the above 178 | m = SliceMap([slice(0, 1), slice(1, 1), slice(1, 2), slice(3, 5), slice(4, 6)], 7) 179 | self.assertEqual(m.inverse().project(list("abcdefg"), "?"), list("a?bef")) 180 | self.assertEqual(m.inverse(), SliceMap( 181 | [slice(0, 1), slice(2, 3), slice(3, 3), slice(3, 4), slice(3, 5), slice(4, 5), slice(5, 5)], 5 182 | )) 183 | self.assertEqual(m, m.inverse().inverse()) 184 | 185 | def test_compose(self): 186 | # Empty 187 | self.assertEqual(SliceMap.empty() * SliceMap.empty(), SliceMap.empty()) 188 | 189 | # Map to nothing 190 | m1 = SliceMap([slice(0, 1), slice(1, 2)], 2) 191 | m2 = SliceMap([slice(0, 0), slice(0, 0)], 0) 192 | self.assertEqual(m1 * m2, m2) 193 | 194 | # Map from nothing 195 | m1 = SliceMap([], 2) 196 | m2 = SliceMap([slice(0, 1), slice(1, 2)], 2) 197 | self.assertEqual(m1 * m2, m1) 198 | 199 | # Identity 200 | m = SliceMap([slice(0, 1), slice(1, 2), slice(2, 3)], 3) 201 | self.assertEqual(m * m, m) 202 | 203 | # Gap in the source space, gap in the source space 204 | m1 = SliceMap([slice(0, 1), slice(2, 3)], 3) 205 | m2 = SliceMap([slice(1, 2), slice(2, 3), slice(3, 4)], 4) 206 | self.assertEqual((m1 * m2).project(list("ab"), "?"), list("?a?b")) 207 | 208 | # Gap in the source space, gap in the target space 
209 | m1 = SliceMap([slice(0, 1), slice(2, 3)], 3) 210 | m2 = SliceMap([slice(0, 1), slice(1, 1), slice(1, 2)], 2) 211 | self.assertEqual((m1 * m2).project(list("ab"), "?"), list("ab")) 212 | 213 | # Gap in the target space, gap in the target space 214 | m1 = SliceMap([slice(0, 1), slice(1, 1), slice(1, 2)], 2) 215 | m2 = SliceMap([slice(0, 0), slice(0, 1)], 1) 216 | self.assertEqual((m1 * m2).project(list("abc"), "?"), list("c")) 217 | 218 | # Gap in the target space, gap in the source space 219 | m1 = SliceMap([slice(0, 1), slice(1, 1), slice(1, 2)], 2) 220 | m2 = SliceMap([slice(0, 1), slice(2, 3)], 3) 221 | self.assertEqual((m1 * m2).project(list("abc"), "?"), list("a?c")) 222 | 223 | # Gap in the source space, overlap 224 | m1 = SliceMap([slice(0, 1), slice(2, 3)], 3) 225 | m2 = SliceMap([slice(0, 2), slice(0, 2), slice(0, 2)], 2) 226 | self.assertEqual((m1 * m2).project(list("ac"), "?"), list("cc")) 227 | 228 | # Gap in the target space, overlap 229 | m1 = SliceMap([slice(0, 1), slice(1, 1), slice(1, 2)], 2) 230 | m2 = SliceMap([slice(0, 2), slice(0, 2)], 2) 231 | self.assertEqual((m1 * m2).project(list("abc"), "?"), list("cc")) 232 | 233 | # Gap at the start and end of the source space 234 | m = SliceMap([slice(0, 1), slice(1, 1)], 1) 235 | self.assertEqual(m * SliceMap.lerp(1, 1), m) 236 | m = SliceMap([slice(0, 0), slice(0, 1)], 1) 237 | self.assertEqual(m * SliceMap.lerp(1, 1), m) 238 | 239 | # Overlap, gap in the source space 240 | m1 = SliceMap([slice(0, 2), slice(0, 2), slice(0, 2)], 2) 241 | m2 = SliceMap([slice(0, 1), slice(2, 3)], 3) 242 | self.assertEqual((m1 * m2).project(list("abc"), "?"), list("ccc")) 243 | 244 | # Overlap, gap in the target space 245 | m1 = SliceMap([slice(0, 2), slice(0, 2), slice(0, 2)], 3) 246 | m2 = SliceMap([slice(0, 1), slice(1, 1), slice(1, 2)], 2) 247 | self.assertEqual((m1 * m2).project(list("abc"), "?"), list("c?")) 248 | 249 | # Overlap, overlap 250 | m1 = SliceMap([slice(0, 2), slice(0, 2)], 2) 251 | m2 = 
SliceMap([slice(1, 3), slice(1, 3)], 3) 252 | self.assertEqual(m1 * m2, m2) 253 | 254 | def test_concat(self): 255 | # Empty 256 | self.assertEqual(SliceMap.empty() + SliceMap.empty(), SliceMap.empty()) 257 | 258 | # Map to nothing 259 | m1 = SliceMap([slice(0, 1), slice(1, 2)], 2) 260 | m2 = SliceMap([slice(0, 0), slice(0, 0)], 0) 261 | self.assertEqual(m1 + m2, SliceMap([slice(0, 1), slice(1, 2), slice(2, 2), slice(2, 2)], 2)) 262 | self.assertEqual(m2 + m1, SliceMap([slice(0, 0), slice(0, 0), slice(0, 1), slice(1, 2)], 2)) 263 | 264 | # Map from nothing 265 | m1 = SliceMap([], 2) 266 | m2 = SliceMap([slice(0, 1), slice(1, 2)], 2) 267 | self.assertEqual(m1 + m2, SliceMap([slice(2, 3), slice(3, 4)], 4)) 268 | self.assertEqual(m2 + m1, SliceMap([slice(0, 1), slice(1, 2)], 4)) 269 | 270 | # Identity 271 | m1 = SliceMap([slice(0, 1), slice(1, 2)], 2) 272 | m2 = SliceMap([slice(0, 1), slice(1, 2), slice(2, 3), slice(3, 4)], 4) 273 | self.assertEqual(m1 + m1, m2) 274 | 275 | def test_lerp(self): 276 | # Empty 277 | self.assertEqual(SliceMap.lerp(0, 0), SliceMap.empty()) 278 | 279 | # Map to nothing 280 | self.assertEqual(SliceMap.lerp(1, 0), SliceMap([slice(0, 0)], 0)) 281 | 282 | # Map from nothing 283 | self.assertEqual(SliceMap.lerp(0, 1), SliceMap([], 1)) 284 | 285 | # Identity 286 | m = SliceMap([slice(0, 1), slice(1, 1), slice(1, 2), slice(3, 5), slice(4, 6)], 7) 287 | self.assertEqual(m * SliceMap.lerp(7, 7), m) 288 | 289 | # Ensure the spread is even 290 | for i in range(1, 20): 291 | for j in range(1, 20): 292 | m = SliceMap.lerp(i, j) 293 | idx = np.arange(i) 294 | counts = np.zeros(j, dtype=int) 295 | for k in range(i): 296 | for l in idx[m[k]]: 297 | counts[l] += 1 298 | self.assertGreaterEqual(counts.min() + 1, counts.max()) 299 | -------------------------------------------------------------------------------- /tests/test_text_normalization.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from 
transcription_diff.slice_map import SliceMap 4 | from transcription_diff.text_normalization import normalize_text, expand_abbreviations, standardize_characters 5 | 6 | 7 | def test_unchanged_text(): 8 | raw_text = "this text is already normalized" 9 | norm_text, raw2norm = normalize_text(raw_text, "en-us") 10 | assert raw_text == norm_text 11 | assert raw2norm == SliceMap.identity(len(raw_text)) 12 | 13 | 14 | def test_edge_cases(): 15 | norm_text, raw2norm = normalize_text("", "en-us") 16 | assert norm_text == "" 17 | assert raw2norm == SliceMap.empty() 18 | 19 | norm_text, raw2norm = normalize_text(" ", "en-us") 20 | assert norm_text == " " 21 | assert raw2norm == SliceMap.identity(1) 22 | 23 | norm_text, raw2norm = normalize_text(".", "en-us") 24 | assert norm_text == "" 25 | assert raw2norm == SliceMap.lerp(1, 0) 26 | 27 | norm_text, raw2norm = normalize_text(" " * 3, "en-us") 28 | assert norm_text == " " 29 | assert raw2norm == SliceMap.lerp(3, 1) 30 | 31 | norm_text, raw2norm = normalize_text(". . .", "en-us") 32 | assert norm_text == " " 33 | assert raw2norm == SliceMap([slice(0, 0), slice(0, 1), slice(0, 1), slice(0, 1), slice(1, 1)], 1) 34 | 35 | 36 | def test_abbreviation_expansion(): 37 | norm_text, raw2norm = expand_abbreviations("Hi there dr. House") 38 | assert norm_text == "Hi there doctor House" 39 | assert raw2norm == SliceMap.identity(9) + SliceMap.lerp(3, 6) + SliceMap.identity(6) 40 | 41 | norm_text, raw2norm = expand_abbreviations("Hey, jr.! Are you coming jr.?") 42 | assert norm_text == "Hey, junior! Are you coming junior?" 43 | assert raw2norm == SliceMap.identity(5) + SliceMap.lerp(3, 6) + SliceMap.identity(17) + \ 44 | SliceMap.lerp(3, 6) + SliceMap.identity(1) 45 | 46 | norm_text, raw2norm = expand_abbreviations("So it goes oct., nov., dec.... Wait, what's after oct.?") 47 | assert norm_text == "So it goes october, november, december... Wait, what's after october?" 
48 | assert raw2norm == \ 49 | SliceMap.identity(11) + SliceMap.lerp(4, 7) + \ 50 | SliceMap.identity(2) + SliceMap.lerp(4, 8) + \ 51 | SliceMap.identity(2) + SliceMap.lerp(4, 8) + \ 52 | SliceMap.identity(23) + SliceMap.lerp(4, 7) + \ 53 | SliceMap.identity(1) 54 | 55 | 56 | @pytest.mark.parametrize( 57 | "text_in, text_out", 58 | [ 59 | ("Hello world!", "Hello world!"), 60 | ("é", "é"), 61 | ("👀", "👀"), 62 | 63 | ("ℍ", "H"), 64 | ("①", "1"), 65 | ("¼", "1⁄4"), 66 | ] 67 | ) 68 | def test_character_standardization(text_in: str, text_out: str): 69 | actual_text_out = "".join(part for part, _ in standardize_characters(text_in)) 70 | assert actual_text_out == text_out 71 | -------------------------------------------------------------------------------- /tests/test_whisper_asr.py: -------------------------------------------------------------------------------- 1 | import librosa 2 | import numpy as np 3 | import pytest 4 | 5 | from transcription_diff.whisper_asr import whisper_asr 6 | 7 | 8 | @pytest.mark.parametrize("audio_lang", ["en-gb", "fr-CA", None]) 9 | @pytest.mark.parametrize("custom_words", [[], ["butterfly"]]) 10 | @pytest.mark.parametrize("whisper_model_params", [ 11 | dict(whisper_model_size=1, device="cpu"), 12 | dict(whisper_model_size=2, device="cuda"), 13 | ]) 14 | def test_whisper_asr_args(audio_lang, custom_words, whisper_model_params): 15 | # Single in-memory input 16 | sample_rate = 32000 17 | wav = np.random.randn(sample_rate * 4) 18 | asr_out, audio_lang_out = whisper_asr( 19 | wav, sample_rate, audio_lang=audio_lang, **whisper_model_params, custom_words=custom_words 20 | ) 21 | assert audio_lang is None or audio_lang_out == audio_lang[:2] 22 | 23 | # Multiple in-memory input 24 | sample_rate = 22500 25 | wavs = [np.random.randn(sample_rate * 4) for _ in range(3)] 26 | asr_out, audio_lang_out = whisper_asr( 27 | wavs, sample_rate, audio_lang=audio_lang, **whisper_model_params, custom_words=custom_words 28 | ) 29 | assert audio_lang is None or 
audio_lang_out == audio_lang[:2] 30 | 31 | # One file on disk 32 | asr_out, audio_lang_out = whisper_asr( 33 | librosa.example("libri1"), audio_lang=audio_lang, **whisper_model_params, custom_words=custom_words 34 | ) 35 | assert audio_lang is None or audio_lang_out == audio_lang[:2] 36 | 37 | # Multiple files on disk 38 | asr_out, audio_lang_out = whisper_asr( 39 | [librosa.example("libri1"), librosa.example("libri2"), librosa.example("libri3")], 40 | audio_lang=audio_lang, **whisper_model_params, custom_words=custom_words 41 | ) 42 | assert audio_lang is None or audio_lang_out == audio_lang[:2] 43 | -------------------------------------------------------------------------------- /transcription_diff/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.1.3" 2 | -------------------------------------------------------------------------------- /transcription_diff/find_lang_match.py: -------------------------------------------------------------------------------- 1 | from typing import Union, List 2 | 3 | from langcodes import Language 4 | 5 | 6 | def find_lang_match( 7 | req_lang: Union[str, Language], avail_langs: Union[List[str], List[Language]], territory_match=False, 8 | ) -> List[int]: 9 | """ 10 | Find the best match for a requested language in a list of available languages. 11 | 12 | This method uses the langcode library to deal with the many ways of specifiying languages and the many variations 13 | they can have. See https://pypi.org/project/langcodes/ for more information. 14 | 15 | :param req_lang: the language requested, as a language code or Language instance 16 | :param avail_langs: the list of available languages, as language codes or Language instances 17 | :param territory_match: whether to also match the territory (~= accent) of the language. 18 | - If has a territory specified and this argument is True, only entries that specifically match the 19 | requested territory will be considered. 
20 | - In any other case, the territory will be ignored and only the language will be considered. 21 | :return: a list of indices of the qualifying matches in , possibly empty. All matches are considered 22 | equally good. 23 | """ 24 | if isinstance(req_lang, str): 25 | req_lang = Language.get(req_lang) 26 | if isinstance(avail_langs[0], str): 27 | avail_langs = [Language.get(lang) for lang in avail_langs] 28 | 29 | # Filter languages that don't match the requested language 30 | match_idx = [i for i, avail_lang in enumerate(avail_langs) if avail_lang.language == req_lang.language] 31 | 32 | # Also filter by territory if applicable 33 | if territory_match and req_lang.territory is not None: 34 | match_idx = [i for i in match_idx if avail_langs[i].territory == req_lang.territory] 35 | 36 | return match_idx 37 | -------------------------------------------------------------------------------- /transcription_diff/number_normalization.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | from transcription_diff.slice_map import SliceMap 4 | 5 | 6 | # This is a set of trivial functions for number normalization. This module does not cover all cases but it's simple 7 | # enough and supports texts mappings. 8 | 9 | 10 | _comma_number_re = re.compile(r'(\(?[A-Z]{2,3})?([\$|£|¥|€|#|\(]*[0-9][0-9\,\.]+[0-9])([^\s]+)?') 11 | _decimal_number_re = re.compile(r'(number\s)?([0-9]+\.[0-9]+)(\.|,|\?|!)?') 12 | _hash_number_re = re.compile(r'(#)([0-9]+(?:\.[0-9]+)?)(\.|,|\?|!)?') 13 | 14 | # currencies 15 | _pounds_re = re.compile(r'(\(?£)([0-9\.]*[0-9]+)(\.|,|\?|\!)?') 16 | _yen_re = re.compile(r'(\(?¥)([0-9]+)(\.|,|\?|\!)?') 17 | _euro_re = re.compile(r'(\(?€)([0-9\.]*[0-9]+)(\.|,|\?|\!)?') 18 | _dollars_re = re.compile(r'(\(?\$)([0-9,]*\.?[0-9]+)([\.|,|\?|\!|\)]+)?') 19 | 20 | # currency with abbreviated unit (e.g. 
B, K, M) 21 | _curr_abbrev_re = re.compile(r'(\(?[$£¥€])([0-9]*\.?[0-9]+)([BKMT]| [BMbmTtr]+illion)([\.|,|\?|\!|\)]+)?') 22 | 23 | # units 24 | _ml_re = re.compile(r'([0-9\.]*[0-9]+)(ml)(\.|,|\?|!)?') 25 | _cl_re = re.compile(r'([0-9\.]*[0-9]+)(cl)(\.|,|\?|!)?') 26 | _g_re = re.compile(r'([0-9\.]*[0-9]+)(g)(\.|,|\?|!)?') 27 | _l_re = re.compile(r'([0-9\.]*[0-9]+)(l)(\.|,|\?|!)?') 28 | _m_re = re.compile(r'([0-9\.]*[0-9]+)(m)(\.|,|\?|!)?') 29 | _kg_re = re.compile(r'([0-9\.]*[0-9]+)(kg)(\.|,|\?|!)?') 30 | _mm_re = re.compile(r'([0-9\.]*[0-9]+)(mm)(\.|,|\?|!)?') 31 | _cm_re = re.compile(r'([0-9\.]*[0-9]+)(cm)(\.|,|\?|!)?') 32 | _km_re = re.compile(r'([0-9\.]*[0-9]+)(km)(\.|,|\?|!)?') 33 | _in_re = re.compile(r'([0-9\.]*[0-9]+)(in)(\.|,|\?|!)?') 34 | _ft_re = re.compile(r'([0-9\.]*[0-9]+)(ft)(\.|,|\?|!)?') 35 | _yd_re = re.compile(r'([0-9\.]*[0-9]+)(yd[s]?)(\.|,|\?|!)?') 36 | _s_re = re.compile(r'([0-9\.]*[0-9]+)(s[ecs]*)(\.|,|\?|!)?') 37 | 38 | _ordinal_re = re.compile(r'([0-9]+)(st|nd|rd|th)') 39 | _number_re = re.compile(r'([0-9]+)(\.|,|\?|!)?') 40 | _year_re = re.compile(r'([Ff]rom|[Aa]fter|[Bb]efore|[Bb]y|[Uu]ntil)(\s)(? 
2: 171 | out = match + ' dollars' # Unexpected format 172 | dollars = int(parts[0]) if parts[0] else 0 173 | cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0 174 | if dollars and cents: 175 | dollar_unit = 'dollar' if dollars == 1 else 'dollars' 176 | cent_unit = 'cent' if cents == 1 else 'cents' 177 | out = '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit) 178 | elif dollars: 179 | dollar_unit = 'dollar' if dollars == 1 else 'dollars' 180 | out = '%s %s' % (dollars, dollar_unit) 181 | elif cents: 182 | cent_unit = 'cent' if cents == 1 else 'cents' 183 | out = '%s %s' % (cents, cent_unit) 184 | else: 185 | out = 'zero dollars' 186 | 187 | # append any following punctuation 188 | if len(m) == 3: 189 | out = out + m[2] 190 | 191 | # compute the SliceMap between raw and normalised 192 | r = [''.join(m), out] 193 | for i, t in enumerate(mapping): 194 | if t[0] == r[0]: 195 | mapping[i] = (r[1], SliceMap.lerp(t[1].source_len, len(r[1]))) 196 | 197 | text_out = ''.join([t[0] for t in mapping]) 198 | 199 | return text_out, mapping 200 | 201 | 202 | def _expand_other_currency(text, mapping, regex, one, many): 203 | match = re.findall(regex, text) 204 | for m in match: 205 | parts = m[1].split(".") 206 | curr = one if int(parts[0]) == 1 else many 207 | try: 208 | out = parts[0] + " " + curr + " " + parts[1] 209 | except IndexError: 210 | out = parts[0] + " " + curr 211 | r = [''.join(m), out] 212 | 213 | # compute the SliceMap 214 | for i, t in enumerate(mapping): 215 | if t[0] == r[0]: 216 | mapping[i] = (r[1], SliceMap.lerp(t[1].source_len, len(r[1]))) 217 | 218 | text_out = ''.join([t[0] for t in mapping]) 219 | 220 | return text_out, mapping 221 | 222 | 223 | def _expand_abbreviated_currency_unit(text, mapping): 224 | match = re.findall(_curr_abbrev_re, text) 225 | # fill this with duplicate words if required 226 | to_remove = [] 227 | 228 | for m in match: 229 | curr, val, unit, punc = m 230 | 231 | # remove leading paranthesis from currency 232 | 
def _expand_abbreviated_currency_unit(text, mapping):
    """
    Expands abbreviated currency amounts such as "$5B" or "($3.2 billion)" into words
    (e.g. "5 billion dollars") and updates the raw-to-clean mapping accordingly.

    :param text: the raw text, scanned with _curr_abbrev_re
    :param mapping: list of (word, SliceMap) entries, one per whitespace-split token of the raw text
    :return: the updated text and mapping
    """
    match = re.findall(_curr_abbrev_re, text)
    # fill this with duplicate words if required: indices of mapping entries that get merged
    # into a preceding entry (multi-token matches) and must be dropped at the end
    to_remove = []

    for m in match:
        # Capture groups: currency symbol, numeric value, magnitude ("B"/"K"/"M"/"T" or " billion" etc.),
        # optional trailing punctuation
        curr, val, unit, punc = m

        # remove leading paranthesis from currency, e.g. "($" -> "$"
        curr = curr.strip("(")

        # split decimal and expand post-decimal digits one by one, e.g. "3.25" -> "3 point 2 5"
        val_parts = val.split(".")
        if len(val_parts) > 1:
            val_out = val_parts[0] + ' ' + 'point ' + ' '.join(val_parts[1])
        else:
            val_out = val_parts[0]

        # reorder elements into "<value> <magnitude> <currency>", e.g. "5 billion dollars".
        # unit_dict translates single-letter magnitudes; spelled-out magnitudes fall back to
        # being kept as-is (KeyError branch). unit_dict/curr_dict are defined above — their
        # exact contents are not visible here.
        try:
            out = ' '.join([val_out, unit_dict[unit], curr_dict[curr]])
        except KeyError:
            out = ' '.join([val_out, unit, curr_dict[curr]])
        out = out + punc

        ## create raw-clean mapping
        r = [''.join(m), out]

        for i, t in enumerate(mapping):
            # Single-token case: the whole match is one mapping entry
            if t[0] == r[0]:
                mapping[i] = (r[1], SliceMap.lerp(t[1].source_len, len(r[1])))

            # deal with mapping across multiple words
            try:
                # check for where the digit is followed by e.g ' billion':
                # the match then spans this token, the whitespace entry and the next token
                join_text = ''.join([mapping[i][0], mapping[i+1][0], mapping[i+2][0]])

                if join_text == r[0]:
                    # generate SliceMap of shape SliceMap(raw_length, clean_length)
                    # NOTE(review): this rebinds the outer loop variable `m`; harmless since `r`
                    # was computed from it beforehand, but worth renaming.
                    m = SliceMap.lerp(len(r[0]), len(r[1]))
                    mapping[i] = (out, m)

                    # since upcoming words have been added to the raw-clean mapping,
                    # prepare for the upcoming words to be removed from the mapping array
                    to_remove.append(i+1)
                    to_remove.append(i+2)
            except IndexError:
                # Fewer than two entries follow this one — no multi-token match possible here
                continue

    # remove the duplicate words from mapping array
    new_mapping = [v for i, v in enumerate(mapping) if i not in to_remove]
    text_out = ''.join([t[0] for t in new_mapping])
    return text_out, new_mapping
def _expand_year(text, mapping):
    """
    Expands 4-digit years that follow a temporal preposition ("from", "after", "before", "by",
    "until") into spoken form, e.g. "1984" -> "nineteen eighty four", updating the raw-to-clean
    mapping in place.

    NOTE(review): _year_re is defined above but its pattern is truncated in this view; the groups
    are assumed to be (preposition, whitespace, century pair, decade pair, trailing punctuation) —
    confirm against the regex.
    """
    m = re.findall(_year_re, text)
    for result in m:
        # result[1] is the captured whitespace between preposition and year, unused here
        prep, mill_cent, dec_year, post = str(result[0]), str(result[2]), str(result[3]), str(result[4])
        if mill_cent == "20" and dec_year in _sub_ten_nums:
            # e.g. 2005: keep the digits intact (presumably left for the later number pass,
            # which reads it as "two thousand five")
            year_out = mill_cent + dec_year
        elif dec_year in _sub_ten_nums:
            # e.g. 1905 -> "19 oh 5"; the century digits stay numeric for later expansion
            year = dec_year[-1]
            year_out = mill_cent + " " + "oh" + " " + year
        else:
            # e.g. 1984 -> "nineteen" + " " + "eighty four"
            year_out = _number_to_words(int(mill_cent)) + " " + _number_to_words(int(dec_year))
        year_out = year_out + post

        # compute the SliceMap between raw and normalised years
        r = [mill_cent + dec_year + post, year_out]
        for i, t in enumerate(mapping):
            # mapping alternates word / whitespace entries (it is built with re.split keeping
            # separators), so index i-2 is the preceding word: require it to be the preposition so
            # an unrelated identical number elsewhere in the text is not rewritten.
            # NOTE(review): for i < 2 the index i-2 wraps around to the end of the list — confirm
            # this cannot produce a false match.
            if t[0] == r[0] and mapping[i-2][0] == prep:
                mapping[i] = (r[1], SliceMap.lerp(t[1].source_len, len(r[1])))

    text_out = ''.join([t[0] for t in mapping])

    return text_out, mapping
def _standard_number_to_words(n, digit_group):
    """
    Recursively spells out <n> in English words. <digit_group> indexes into _digit_groups (defined
    above) to pick the scale word ("thousand", "million", ...) for the current three-digit group.
    """
    pieces = []
    if n >= 1000:
        # Spell the higher-order groups first, then keep only the remainder below one thousand
        higher, n = divmod(n, 1000)
        pieces.append(_standard_number_to_words(higher, digit_group + 1))

    hundreds, remainder = divmod(n, 100)
    if hundreds:
        pieces.append('%s hundred' % _units[hundreds])
    if remainder >= len(_units):
        # Needs two words, e.g. "forty two"
        tens_digit, unit_digit = divmod(remainder, 10)
        pieces.append(_tens[tens_digit])
        pieces.append(_units[unit_digit])
    else:
        # A single unit word covers it, e.g. "seventeen"
        pieces.append(_units[remainder])
    if n > 0:
        pieces.append(_digit_groups[digit_group])
    return ' '.join(piece for piece in pieces if piece)


def _number_to_words(n):
    """
    Spells out the non-negative integer <n> in English words. Numbers of 19+ digits are returned as
    digits, and round numbers below 3000 such as 1200 are read as "twelve hundred".
    """
    # Handle special cases first, then go to the standard case
    if n >= 1000000000000000000:
        return str(n)  # Too large, just return the digits
    if n == 0:
        return 'zero'
    if n % 100 == 0 and n % 1000 != 0 and n < 3000:
        # e.g. 1200 -> "twelve hundred" rather than "one thousand two hundred"
        return _standard_number_to_words(n // 100, 0) + ' hundred'
    return _standard_number_to_words(n, 0)
def _expand_ordinal(text, mapping):
    """
    Expands ordinals such as "3rd" into words ("third"), updating the raw-to-clean mapping in place.
    """
    for groups in re.findall(_ordinal_re, text):
        spelled = _number_to_words(int(groups[0]))
        # Swap the word ending for its irregular ordinal form ("one" -> "first", ...),
        # defaulting to a plain "th" suffix when no irregular ending applies
        for ending, ordinal_ending in _ordinal_suffixes:
            if spelled.endswith(ending):
                expanded = spelled[:-len(ending)] + ordinal_ending
                break
        else:
            expanded = spelled + 'th'
        raw = ''.join(groups)

        # Refresh the SliceMap of every mapping entry equal to the raw match
        for idx, (word, word_map) in enumerate(mapping):
            if word == raw:
                mapping[idx] = (expanded, SliceMap.lerp(word_map.source_len, len(expanded)))

    return ''.join(word for word, _ in mapping), mapping
class SliceMap:
    def __init__(self, smap: Union[List[slice], np.ndarray], target_len: int):
        """
        A slice map smap is a list of slices that maps from X to Y with
        - X[i] mapping to Y[smap[i]]
        - X[i:j] mapping to Y[smap[i].start:smap[j - 1].stop]

        Informally, an item in X can correspond to 0 or more consecutive items in Y. A slice of one or more items in X
        will map to the slice spanning from the leftmost corresponding item in Y to the rightmost corresponding item.

        :param smap: the list of slices. The following must hold:
            - len(smap) must be equal to the size of the X
            - slices cannot have negative indices and cannot index beyond the size of Y
            - the slice starts cannot decrease, the same goes for the slice stops. You can however have consecutive
              overlapping slices, e.g. [slice(0, 2), slice(0, 2)].
            Note that slices can be empty (stop <= start).
            The slices can also be passed as an (X, 2) shaped integer array. The second dimension holds slice starts
            and ends, respectively.
        :param target_len: the size of Y
        """
        self.source_len = len(smap)
        self.target_len = target_len

        # Convert slices to an array
        if not isinstance(smap, np.ndarray):
            self._map = np.empty((self.source_len, 2), dtype=np.int64)
            for i, sli in enumerate(smap):
                self._map[i] = [sli.start, sli.stop]
        else:
            self._map = smap.astype(np.int64, copy=True)

        assert np.all((0 <= self._map) & (self._map <= target_len)), "Slice starts/stops out of bounds"
        assert np.all(self._map[1:] >= self._map[:-1]), "Slice starts/stops must be increasing"

    def __getitem__(self, item: Union[int, slice]) -> slice:
        """
        Indexes the position in X with either an integer or a slice (step != 1 is not supported). Returns the
        corresponding slice in Y.
        """
        if np.issubdtype(type(item), np.integer):
            item = slice(item, item + 1)
        else:
            assert item.step in [None, 1], "Only steps of 1 are supported"

        view = self._map[item]
        if len(view):
            # We return a slice that spans from the lowest to the highest target indices
            return slice(view[0][0], view[-1][1])
        else:
            # We return an empty slice, it is computed so as to stay consistent with our axioms.
            # Bugfix: this used to call np.clip(0, item.start, self.source_len), but np.clip's
            # signature is (a, a_min, a_max), so a start beyond source_len was left unclamped and
            # self._map[pos - 1] below raised an IndexError (e.g. SliceMap.identity(3)[5:7]).
            # Clamp the requested start into [0, source_len] explicitly instead.
            start_idx = item.start if item.start is not None else 0
            pos = min(max(start_idx, 0), self.source_len)
            start = self._map[pos][0] if pos < self.source_len else self.target_len
            stop = self._map[pos - 1][1] if pos > 0 else 0
            stop = max(start, stop)
            return slice(start, stop)

    def __len__(self):
        """
        Returns the size of X
        """
        return self.source_len

    def __bool__(self):
        """
        To ensure we still get a True value when the mapping is empty
        """
        return True

    def __iter__(self):
        """
        Iterates over slices, returning pairs (start, stop)
        """
        yield from ((int(start), int(end)) for start, end in self._map)

    def project(self, data: Union[np.ndarray, List], default=None) -> Union[np.ndarray, List]:
        """
        Projects data in the source space to the target space.
        A default value will be returned in place of gaps in the target space.
        In case of overlaps, the rightmost item will take priority.

        :param data: a list of arbitrary objects or a numpy array. It must be that len(data) == source_len
        :param default: the value to give to entries that nothing maps to. This value must be specified in the case
        of numpy arrays
        :return: the projected data in Y as a list or numpy array
        """
        assert len(data) == self.source_len, "The data to project must have the same length as the mapping."
        is_numpy = isinstance(data, np.ndarray)
        assert not (is_numpy and default is None), "The default value must be specified for numpy arrays."

        if is_numpy:
            projected = np.full_like(data, default, shape=self.target_len)
        else:
            projected = [default] * self.target_len

        for source_idx, (target_start, target_end) in enumerate(self._map):
            if is_numpy:
                projected[target_start:target_end] = data[source_idx]
            else:
                projected[target_start:target_end] = [data[source_idx]] * (target_end - target_start)

        return projected

    def inverse(self) -> 'SliceMap':
        """
        With self mapping from X to Y, returns the inverse Y to X mapping.
        This operation is bijective, including in the presence of gaps or overlaps.
        """
        # Find the Points Of Interest: the indices where the mapping's starts or stops increase.
        # A sentinel row closes the final slices.
        bounded_map = np.concatenate((self._map, [[self.target_len, self.target_len]]))
        changes = np.diff(bounded_map, axis=0, prepend=0)
        (start_pois,), (stop_pois,) = changes[:, 1].nonzero(), changes[:, 0].nonzero()

        # Each increase of a stop by k means k more target items begin mapping back at that source index
        n_repeats = np.diff(bounded_map[start_pois, 1], prepend=0)
        inv_map_starts = np.repeat(start_pois, n_repeats)

        n_repeats = np.diff(bounded_map[stop_pois, 0], prepend=0)
        inv_map_stops = np.repeat(stop_pois, n_repeats)

        inv_map = np.stack([inv_map_starts, inv_map_stops], axis=1)

        return SliceMap(inv_map, self.source_len)

    def compose(self, other: 'SliceMap') -> 'SliceMap':
        """
        With self mapping from X to Y and other mapping from Y to Z, returns the composed X to Z mapping.
        """
        assert self.target_len == other.source_len, \
            f"Cannot compose {self.source_len}x{self.target_len} map with {other.source_len}x{other.target_len} map."

        smap = np.empty((self.source_len, 2), dtype=np.int64)
        for i in range(len(self)):
            sli = other[self[i]]
            smap[i] = [sli.start, sli.stop]

        return SliceMap(smap, other.target_len)

    def __mul__(self, other):
        """
        Multiplication is shorthand for compose
        """
        return self.compose(other)

    def concat(self, other: 'SliceMap') -> 'SliceMap':
        """
        With self mapping from Xi to Yi and other mapping from Xj to Yj, returns the concatenated mapping
        from cat(Xi, Xj) to cat(Yi, Yj).
        """
        new_map = np.concatenate((self._map, other._map + self.target_len))
        return SliceMap(new_map, self.target_len + other.target_len)

    def __add__(self, other):
        """
        Addition is shorthand for concatenation
        """
        return self.concat(other)

    def __eq__(self, other: 'SliceMap'):
        if other is None:
            return False
        return \
            self.source_len == other.source_len and \
            self.target_len == other.target_len and \
            np.array_equal(self._map, other._map)

    @staticmethod
    def from_1to1_map(oto_map: Iterable[int], target_len: int):
        """
        Creates a slicemap where each index i corresponds to the slice oto_map[i]:oto_map[i] + 1
        """
        return SliceMap([slice(p, p + 1) for p in oto_map], target_len)

    @staticmethod
    def from_ranges(ranges: Iterable[int]):
        """
        Creates a slicemap where index i maps to a slice of size ranges[i], laid out consecutively.
        This is the non-cumulative version of a monotonic mapping:
        - SliceMap.from_ranges(r) is equivalent to SliceMap.from_monotonic_map(np.cumsum(r))
        """
        smap = []
        target_pos = 0
        for r in ranges:
            smap.append(slice(target_pos, target_pos + r))
            target_pos += r
        return SliceMap(smap, target_pos)

    @staticmethod
    def lerp(source_len: int, target_len: int):
        """
        Creates a map that linearly interpolates from X to Y, e.g. for source_len=6 and target_len=12, the slice
        2:3 in X maps to 4:6 in Y.
        """
        # Build the downscaling map, then invert it if we were asked for an upscaling one
        low = min(source_len, target_len)
        high = max(source_len, target_len)
        idx = np.linspace(0, low, high, endpoint=False, dtype=np.int64)
        smap = np.stack([idx, np.minimum(idx + 1, low)], axis=1)
        smap = SliceMap(smap, low)

        return smap if target_len == low else smap.inverse()

    @staticmethod
    def full(source_len: int, target_len: int):
        """
        Creates a map where each element in the source space maps to the entirety of the target space.
        """
        smap = np.zeros((source_len, 2), dtype=np.int64)
        smap[:, 1] = target_len
        return SliceMap(smap, target_len)

    @staticmethod
    def empty() -> 'SliceMap':
        return SliceMap([], 0)

    @staticmethod
    def identity(length: int) -> 'SliceMap':
        return SliceMap.slice(0, length, length)

    @overload
    def slice(start: int, end: int, target_len: int) -> 'SliceMap': ...
    @overload
    def slice(sli: slice, target_len: int) -> 'SliceMap': ...
    @staticmethod
    def slice(*args) -> 'SliceMap':
        """
        Convenience method. Creates a map where all elements map to a slice of a target space.
        - <start> is where the slice begins in the target space
        - <end> is where the slice ends in the target space
        - <target_len> is the size of the target space
        This method is the inverse of eye()
        """
        if len(args) == 2:
            start, end, target_len = args[0].start, args[0].stop, args[1]
        else:
            start, end, target_len = args
        assert 0 <= start <= end <= target_len, f"Invalid slice: {start}:{end} in {target_len}"
        return SliceMap(
            np.stack([np.arange(start, end), np.arange(start, end) + 1], axis=1),
            target_len
        )

    @overload
    def eye(start: int, end: int, length: int) -> 'SliceMap': ...
    @overload
    def eye(sli: slice, length: int) -> 'SliceMap': ...
    @staticmethod
    def eye(*args) -> 'SliceMap':
        """
        Convenience method. Creates a map where
        - the first <start> elements map to nothing
        - the elements between <start> and <end> map to the identity
        - the elements after <end> (up to <length>) map to nothing
        This method is the inverse of slice()
        """
        if len(args) == 2:
            start, end, length = args[0].start, args[0].stop, args[1]
        else:
            start, end, length = args
        return SliceMap.full(start, 0) + SliceMap.identity(end - start) + SliceMap.full(length - end, 0)

    @staticmethod
    def compose_by_name(mapping_name: str, **mappings: 'SliceMap'):
        """
        Composes mappings together based on their names. Each SliceMap passed as argument must have
        the name structure <source_name>2<target_name>.

        For example, calling the function as:
            SliceMap.compose_by_name('a2c', a2b=a2b, b2c=b2c)
        will return the composition a2c = a2b * b2c.

        An AssertionError will be raised if <source_name> or <target_name> are not found in the names of the
        passed mappings.

        Mappings that are not used in the composition may be passed. They will be ignored.
        """
        assert all(k.count("2") == 1 for k in list(mappings) + [mapping_name]), \
            f"All mappings must conform to the name convention <source>2<target>, got {list(mappings)}"
        source_name, target_name = mapping_name.split("2")
        source_names, target_names = zip(*[map_name.split("2") for map_name in mappings])
        assert source_name in source_names, f"Source name {source_name} not found in {source_names}"
        assert target_name in target_names, f"Target name {target_name} not found in {target_names}"

        # Walk the chain of mappings from the source dimension until we land on the target one
        dim_name = source_name
        composed_map = None
        seen_idx = set()
        while dim_name != target_name:
            map_idx = source_names.index(dim_name)
            assert map_idx not in seen_idx, f"Cycle detected: {list(mappings)}"
            seen_idx.add(map_idx)
            map_to_compose = mappings[f"{dim_name}2{target_names[map_idx]}"]
            composed_map = composed_map * map_to_compose if composed_map else map_to_compose
            dim_name = target_names[map_idx]

        return composed_map

    def __repr__(self):
        return f"<{self.source_len}x{self.target_len} map: {[tuple(sli) for sli in self._map]}>"
def clean_text_diff(ref_text: str, compared: str) -> List[TextDiffRegion]:
    """
    Word-aligns two normalized texts and returns the diff as a list of regions, each holding the
    corresponding pieces of both texts and whether they match.

    :param ref_text: the normalized reference text
    :param compared: the normalized text to compare against the reference
    :return: regions in reading order; consecutive regions alternate between matching and
    non-matching (guaranteed by the compression step below)
    """
    # Global word-level alignment (Needleman-Wunsch, via minineedle)
    alignment = needle.NeedlemanWunsch(ref_text.split(" "), compared.split(" "))
    alignment.align()

    # Arrange: one region per aligned pair. Non-str entries (presumably minineedle gap markers
    # for insertions/deletions) become empty text on that side.
    regions = []
    for ref_word, compared_word in zip(*alignment.get_aligned_sequences()):
        regions.append(TextDiffRegion(
            ref_word if isinstance(ref_word, str) else "",
            compared_word if isinstance(compared_word, str) else "",
            pronunciation_match=(ref_word == compared_word)
        ))

    # Re-add the spaces between words, and prefer to add them on identical regions rather than non-identical ones
    for text_attr in ("reference_text", "compared_text"):
        last_word_region = None
        for region in regions:
            # Skip regions that carry no text on this side (alignment gaps)
            if not getattr(region, text_attr):
                continue
            if last_word_region:
                if last_word_region.pronunciation_match:
                    setattr(last_word_region, text_attr, getattr(last_word_region, text_attr) + " ")
                else:
                    setattr(region, text_attr, " " + getattr(region, text_attr))
            last_word_region = region

    # Compress: merge neighbouring regions with the same match status
    new_regions = []
    for region in regions:
        if new_regions and (new_regions[-1].pronunciation_match == region.pronunciation_match):
            new_regions[-1].reference_text += region.reference_text
            new_regions[-1].compared_text += region.compared_text
        else:
            new_regions.append(region)

    return new_regions


def text_diff(
    reference_texts: Iterable[str], compared_texts: Iterable[str], lang_id: str
) -> List[List[TextDiffRegion]]:
    """
    Computes pronunciation-level diffs between pairs of texts, with the resulting regions expressed
    in the original (unnormalized) texts.

    :param reference_texts: the ground-truth texts
    :param compared_texts: the texts to compare against the references, in the same order
    :param lang_id: IETF language tag used to select the normalization rules
    :return: one list of TextDiffRegion per (reference, compared) pair
    """
    raw_refs, raw_comps = list(reference_texts), list(compared_texts)

    # Normalize text down to characters that influence pronunciation only
    clean_refs, raw2clean_refs = zip(*[normalize_text(raw_ref, lang_id) for raw_ref in raw_refs])
    clean_comps, raw2clean_comps = zip(*[normalize_text(raw_comp, lang_id) for raw_comp in raw_comps])

    # Align clean texts and isolate errors
    text_diffs = [clean_text_diff(clean_ref, clean_comp) for clean_ref, clean_comp in zip(clean_refs, clean_comps)]

    # Bring the regions up to the unnormalized text space
    for raw_ref, raw2clean_ref, raw_comp, raw2clean_comp, clean_diff in zip(
        raw_refs, raw2clean_refs, raw_comps, raw2clean_comps, text_diffs
    ):
        clean2raw_ref = raw2clean_ref.inverse()
        clean2raw_comp = raw2clean_comp.inverse()

        clean_ref_pos, clean_comp_pos = 0, 0
        raw_ref_pos, raw_comp_pos = 0, 0
        for region in clean_diff:
            # Use slicemaps to figure out which parts of the unnormalized text this region corresponds to
            clean_ref_sli = slice(clean_ref_pos, clean_ref_pos + len(region.reference_text))
            clean_comp_sli = slice(clean_comp_pos, clean_comp_pos + len(region.compared_text))
            if region is not clean_diff[-1]:
                # max() keeps the raw position from moving backwards on empty projections
                raw_ref_sli = slice(raw_ref_pos, max(clean2raw_ref[clean_ref_sli].stop, raw_ref_pos))
                raw_comp_sli = slice(raw_comp_pos, max(clean2raw_comp[clean_comp_sli].stop, raw_comp_pos))
            else:
                # Ensure we span the entirety of the unnormalized text, slicemaps are not guaranteed to be surjective
                # Typical example: a final punctuation that is erased in text normalization.
                raw_ref_sli = slice(raw_ref_pos, len(raw_ref))
                raw_comp_sli = slice(raw_comp_pos, len(raw_comp))

            # Modify the region in place with the unnormalized text
            region.reference_text = raw_ref[raw_ref_sli]
            region.compared_text = raw_comp[raw_comp_sli]

            # Update the positions
            clean_ref_pos = clean_ref_sli.stop
            clean_comp_pos = clean_comp_sli.stop
            raw_ref_pos = raw_ref_sli.stop
            raw_comp_pos = raw_comp_sli.stop

    return text_diffs
def render_text_diff(text_diff: List[TextDiffRegion], with_colors=True) -> str:
    """
    Formats a diff as a single display string. Matching regions are rendered as-is; mismatching
    regions are rendered as "(<compared>|<reference>)", optionally wrapped in ANSI colors
    (compared text in red, reference text in green).
    """
    parts = []
    for region in text_diff:
        if region.pronunciation_match:
            parts.append(region.reference_text)
            continue
        # Only resolve the color codes when they are requested
        red, green, reset = (colors.RED, colors.GREEN, colors.RESET) if with_colors else ("", "", "")
        parts.append(
            "(" + red + region.compared_text + reset + "|" + green + region.reference_text + reset + ")"
        )
    return "".join(parts)
def expand_abbreviations(text: str):
    """
    Replaces abbreviations ("mr.", "dr.", ...) with their expanded spelling, returning the new text
    and the SliceMap from the original text to the new one.
    """
    orig2new = SliceMap.identity(len(text))
    new_text = text

    for regex, replacement in _abbreviations:
        for match in re.finditer(regex, text):
            # Locate the match inside the (possibly already modified) new text through the running mapping
            new_sli = orig2new[slice(*match.span())]
            new_text = new_text[:new_sli.start] + replacement + new_text[new_sli.stop:]
            # Update the mapping: identity around the replacement, linear interpolation across it
            orig2new *= SliceMap.identity(new_sli.start) + \
                SliceMap.lerp(len(match.group()), len(replacement)) + \
                SliceMap.identity(orig2new.target_len - new_sli.stop)

    return new_text, orig2new


def collapse_whitespace(text: str):
    """
    Generator transform: collapses every whitespace run to a single space, yielding
    (new_text_part, part_mapping) chunks as consumed by apply_text_transforms_with_mapping.
    """
    for part in re.split(_whitespace_incl_re, text):
        match = re.search(_whitespace_excl_re, part)
        if match is not None:
            # Whitespace token: squash it down to one space
            new_part = re.sub(_whitespace_excl_re, " ", part)
            yield new_part, SliceMap.lerp(len(part), len(new_part))
        else:
            yield part, SliceMap.identity(len(part))


def standardize_characters(text: str):
    """
    Generator transform: applies unicode NFKC normalization per token (e.g. ligatures and
    full-width forms are replaced with their compatibility equivalents), yielding
    (new_text_part, part_mapping) chunks.
    """
    for part in re.split(_whitespace_incl_re, text):
        new_part = unicodedata.normalize("NFKC", part)
        transform = SliceMap.lerp(len(part), len(new_part))
        yield new_part, transform


def keep_pronounced_only(text: str):
    """
    Lowercases the text and keeps only the characters that influence pronunciation: alphanumericals,
    apostrophes and spaces (hyphens are first turned into spaces). Returns the new text and the
    mapping from the input text to it.
    """
    text = text.replace("-", " ")
    kept_idx = [i for i, c in enumerate(text) if c.isalnum() or c in ("'", " ")]
    new_text = "".join(text[i] for i in kept_idx).lower()
    # Build the kept->original mapping, then invert it to get original->kept
    new2orig = SliceMap.from_1to1_map(kept_idx, len(text))
    return new_text, new2orig.inverse()


def apply_text_transforms_with_mapping(
    text: str, funcs: List[Callable], fault_tolerant=False
) -> Tuple[str, SliceMap]:
    """
    Applies the text transforms in <funcs> in order, chaining their mappings into a single SliceMap
    from the original text to the final text.

    :param funcs: a list of Callables that take a text string and return a tuple (new_text, mapping), where the mapping
    must be a SliceMap from the text provided as argument to the new text. For convenience, the function can also be a
    generator function that yields outputs in chunks (new_text_part, mapping_part).
    :param fault_tolerant: when True, a transform that raises is skipped and one that returns an inconsistent mapping
    is approximated with a linear mapping, instead of raising.
    """
    # Backcompat: we'll support funcs=None as argument
    funcs = funcs or []

    orig2new = SliceMap.identity(len(text))
    for func in funcs:
        # Perform the cleaning operation and obtain the new mapping
        try:
            if inspect.isgeneratorfunction(func):
                # Generator transforms yield chunks: stitch the text and the mappings back together
                new_text = ""
                map_transform = SliceMap.empty()
                for new_text_part, map_transform_part in func(text):
                    new_text += new_text_part
                    map_transform += map_transform_part
            else:
                new_text, map_transform = func(text)
        except Exception as e:
            if fault_tolerant:
                logger.error(f"Exception in cleaning function {func.__name__}: {e}")
                continue
            else:
                raise

        # Update the mapping, verifying that it is valid
        if map_transform.source_len != len(text) or map_transform.target_len != len(new_text):
            if fault_tolerant:
                logger.error("Cleaning operations gave an incorrect mapping")
                map_transform = SliceMap.lerp(len(text), len(new_text))
            else:
                raise RuntimeError("Cleaning operations gave an incorrect mapping")
        orig2new *= map_transform
        text = new_text

    return text, orig2new
171 | :return: the tuple 172 | - clean_text: the cleaned text 173 | - raw2clean: the mapping from raw text to clean text 174 | """ 175 | # Define the ops to apply 176 | text_cleaning_ops = [standardize_characters] 177 | if Language.get(lang_id).language == "en": 178 | text_cleaning_ops.extend([expand_abbreviations, normalize_numbers]) 179 | text_cleaning_ops.extend([keep_pronounced_only, collapse_whitespace]) 180 | 181 | return apply_text_transforms_with_mapping(raw_text, text_cleaning_ops, fault_tolerant) 182 | -------------------------------------------------------------------------------- /transcription_diff/whisper_asr.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from functools import lru_cache 3 | from pathlib import Path 4 | from typing import overload, List, Union, Iterable, Tuple 5 | 6 | import librosa 7 | import numpy as np 8 | import torch 9 | import whisper 10 | from whisper.audio import SAMPLE_RATE as _WHISPER_SAMPLE_RATE, N_SAMPLES as _WHISPER_CHUNK_SIZE 11 | from whisper.tokenizer import LANGUAGES as _WHISPER_LANGUAGES 12 | 13 | from transcription_diff.find_lang_match import find_lang_match 14 | 15 | 16 | logger = logging.getLogger(__name__) 17 | _WHISPER_LANGUAGES = list(_WHISPER_LANGUAGES) 18 | 19 | 20 | # TODO: let the user handle the cache 21 | @lru_cache(maxsize=1) 22 | def get_whisper_model(model_size=3, english_only=False, device="cuda"): 23 | """ 24 | Available models: https://github.com/openai/whisper/blob/main/model-card.md 25 | 26 | :param model_size: controls the accuracy-speed tradeoff. Larger models are slower but more accurate. Ranges from 27 | 1 to 5. 28 | :param english_only: English-only models can only process input audio and output text in English, but they are 29 | more accurate. Do not use English-only models for other languages (not even for any-to-English translation) as 30 | you will get highly inaccurate results. 
31 | """ 32 | model_name = { 33 | 1: "tiny", 34 | 2: "base", 35 | 3: "small", 36 | 4: "medium", 37 | 5: "large", 38 | }[model_size] 39 | if english_only and model_size != 5: 40 | model_name += ".en" 41 | 42 | logger.info(f"Loading whisper model \"{model_name}\" on {device}") 43 | return whisper.load_model(model_name, device=device) 44 | 45 | 46 | @overload 47 | def whisper_asr( 48 | wav: np.ndarray, sr, *, audio_lang: str=None, whisper_model_size=2, custom_words=[], device="cuda" 49 | ) -> Tuple[str, str]: ... 50 | @overload 51 | def whisper_asr( 52 | wavs: Iterable[np.ndarray], sr, *, audio_lang: str=None, whisper_model_size=2, custom_words=[], device="cuda" 53 | ) -> Tuple[List[str], str]: ... 54 | @overload 55 | def whisper_asr( 56 | fpath: Union[str, Path], *, audio_lang: str=None, whisper_model_size=2, custom_words=[], device="cuda" 57 | ) -> Tuple[str, str]: ... 58 | @overload 59 | def whisper_asr( 60 | fpaths: Iterable[Union[str, Path]], *, audio_lang: str=None, whisper_model_size=2, custom_words=[], device="cuda" 61 | ) -> Tuple[List[str], str]: ... 62 | def whisper_asr( 63 | *args, audio_lang: str=None, whisper_model_size=2, custom_words=[], device="cuda" 64 | ) -> Union[Tuple[str, str], Tuple[List[str], str]]: 65 | """ 66 | Performs automatic speech recognition on the given audio(s). Supports most languages, and can perform automatic 67 | language detection. 68 | 69 | :param sr: samples rate of the waveforms, if provided 70 | :param audio_lang: the lang code of the input audio as an IETF language tag (e.g. "en-us", "fr", ...), if known. 71 | When None, the language is automatically determined by the model. If provided and the language is English, 72 | the English-only whisper model will be used. 73 | :param whisper_model_size: controls the accuracy-speed tradeoff. Ranges from 1 to 5, which 5 being the highest 74 | accuracy (largest model size) but the lowest inference speed. 
This parameter has a large impact, consider setting 75 | it as high as you can afford to. 76 | :param custom_words: a list of words likely to be unknown to Whisper. We'll attempt to make whisper aware of them 77 | by passing them to the initial prompt. 78 | :return: a tuple: 79 | - The transcription(s) as a string or list of strings 80 | - The detected language ID of the first sample if was None, the whisper equivalent of 81 | otherwise. 82 | """ 83 | # Audio args parsing 84 | if len(args) == 1: 85 | if single := (isinstance(args[0], str) or isinstance(args[0], Path)): 86 | fpaths = [args[0]] 87 | else: 88 | fpaths = list(args[0]) 89 | # TODO?: batched resampling using torchaudio for efficiency 90 | wavs = [librosa.core.load(str(fpath), sr=_WHISPER_SAMPLE_RATE)[0] for fpath in fpaths] 91 | sr = _WHISPER_SAMPLE_RATE 92 | else: 93 | wavs, sr = args 94 | if single := isinstance(wavs, np.ndarray): 95 | wavs = [wavs] 96 | wavs = [wav.astype(np.float32) for wav in wavs] 97 | 98 | # Lang args 99 | if audio_lang: 100 | lang_idx = find_lang_match(audio_lang, _WHISPER_LANGUAGES) 101 | if not lang_idx: 102 | raise ValueError(f"Language code {audio_lang} is not recognized or supported by Whisper.") 103 | audio_lang = _WHISPER_LANGUAGES[lang_idx[0]] 104 | 105 | # Resample 106 | # TODO?: batched resampling using torchaudio for efficiency 107 | wavs = [ 108 | librosa.core.resample(wav, orig_sr=sr, target_sr=_WHISPER_SAMPLE_RATE) 109 | for wav in wavs 110 | ] 111 | 112 | # Format inputs 113 | if any(len(wav) > _WHISPER_CHUNK_SIZE for wav in wavs): 114 | logger.warning( 115 | # TODO: support for >30s inputs 116 | "At least one input to whisper is larger than the chunk size (30s), this is not yet supported and the " 117 | "input will be trimmed." 
118 | ) 119 | wavs = [whisper.pad_or_trim(wav) for wav in wavs] 120 | mels = torch.stack([whisper.log_mel_spectrogram(wav) for wav in wavs]) 121 | 122 | # Ensuring the right device is selected 123 | if torch.device(device).type == "cuda" and not torch.cuda.is_available(): 124 | logger.warning( 125 | "CUDA is not available on your torch install, whisper will run on CPU instead. If you do have a " 126 | "CUDA-compatible GPU available, you may reinstall torch this way to enable CUDA:\n" 127 | "\tpip uninstall torch\n" 128 | "\tpip cache purge\n" 129 | "\tpip install torch -f https://download.pytorch.org/whl/torch_stable.html\n" 130 | ) 131 | device = "cpu" 132 | 133 | # Inference 134 | model = get_whisper_model(model_size=whisper_model_size, english_only=(audio_lang == "en"), device=device) 135 | device = next(model.parameters()).device 136 | options = whisper.DecodingOptions( 137 | language=audio_lang, 138 | # TODO?: support for timestamped ASR 139 | without_timestamps=True, 140 | fp16=device.type != "cpu", 141 | # TODO?: a more reliable way of expecting custom words? Maybe something with beam decoding? 142 | prompt=f"CUSTOM_WORDS={','.join(custom_words)}" if custom_words else None, 143 | ) 144 | with torch.inference_mode(): 145 | outputs = model.decode(mels.to(device), options) 146 | 147 | out_lang = audio_lang or outputs[0].language 148 | if single: 149 | return outputs[0].text, out_lang 150 | else: 151 | return [output.text for output in outputs], out_lang 152 | --------------------------------------------------------------------------------