├── data ├── gt_text │ ├── p088_1166.txt │ ├── p088_3706.txt │ ├── p088_3966.txt │ ├── p088_4067.txt │ ├── p088_2008.txt │ ├── p088_3232.txt │ ├── p088_9554.txt │ ├── p088_7953.txt │ ├── p088_3558.txt │ ├── p088_4616.txt │ ├── p088_5034.txt │ ├── p088_7234.txt │ ├── p088_7620.txt │ ├── p088_8331.txt │ ├── p088_4249.txt │ ├── p088_5568.txt │ ├── p088_8425.txt │ ├── p088_5892.txt │ ├── p088_9993.txt │ └── p088_8468.txt └── audio │ ├── p088_1166.wav │ ├── p088_2008.wav │ ├── p088_3232.wav │ ├── p088_3558.wav │ ├── p088_3706.wav │ ├── p088_3966.wav │ ├── p088_4067.wav │ ├── p088_4249.wav │ ├── p088_4616.wav │ ├── p088_5034.wav │ ├── p088_5568.wav │ ├── p088_5892.wav │ ├── p088_7234.wav │ ├── p088_7620.wav │ ├── p088_7953.wav │ ├── p088_8331.wav │ ├── p088_8425.wav │ ├── p088_8468.wav │ ├── p088_9554.wav │ └── p088_9993.wav ├── resources └── decoder-new.png ├── utils ├── rule_sim_matrix.npy ├── __pycache__ │ ├── wper.cpython-39.pyc │ └── decoder.cpython-39.pyc ├── wper.py └── decoder.py ├── config ├── ipa2cmu.json └── lexicon.json ├── README.md └── requirements.txt /data/gt_text/p088_1166.txt: -------------------------------------------------------------------------------- 1 | He walks to work. -------------------------------------------------------------------------------- /data/gt_text/p088_3706.txt: -------------------------------------------------------------------------------- 1 | Smile but frown. -------------------------------------------------------------------------------- /data/gt_text/p088_3966.txt: -------------------------------------------------------------------------------- 1 | Write fast now. 
-------------------------------------------------------------------------------- /data/gt_text/p088_4067.txt: -------------------------------------------------------------------------------- 1 | They left early -------------------------------------------------------------------------------- /data/gt_text/p088_2008.txt: -------------------------------------------------------------------------------- 1 | She reads fast now -------------------------------------------------------------------------------- /data/gt_text/p088_3232.txt: -------------------------------------------------------------------------------- 1 | We walked across town. -------------------------------------------------------------------------------- /data/gt_text/p088_9554.txt: -------------------------------------------------------------------------------- 1 | Could you please confirm your flight reservation number? -------------------------------------------------------------------------------- /data/gt_text/p088_7953.txt: -------------------------------------------------------------------------------- 1 | Is there a good restaurant nearby that serves local cuisine? -------------------------------------------------------------------------------- /data/gt_text/p088_3558.txt: -------------------------------------------------------------------------------- 1 | If we analyze the data carefully, we'll find meaningful patterns. -------------------------------------------------------------------------------- /data/gt_text/p088_4616.txt: -------------------------------------------------------------------------------- 1 | Could you please forward the email with the conference details to the team? -------------------------------------------------------------------------------- /data/gt_text/p088_5034.txt: -------------------------------------------------------------------------------- 1 | Could you recommend a good local restaurant that serves authentic cuisine? 
-------------------------------------------------------------------------------- /data/gt_text/p088_7234.txt: -------------------------------------------------------------------------------- 1 | The study reveals a strong correlation between exercise and mental health. -------------------------------------------------------------------------------- /data/gt_text/p088_7620.txt: -------------------------------------------------------------------------------- 1 | Scientists are working on developing more efficient renewable energy sources. -------------------------------------------------------------------------------- /data/gt_text/p088_8331.txt: -------------------------------------------------------------------------------- 1 | The museum curator explained the historical significance of the artifacts. -------------------------------------------------------------------------------- /data/gt_text/p088_4249.txt: -------------------------------------------------------------------------------- 1 | I'm sorry, but your flight has been delayed due to inclement weather conditions. -------------------------------------------------------------------------------- /data/gt_text/p088_5568.txt: -------------------------------------------------------------------------------- 1 | The research team is currently analyzing the data collected during the expedition. -------------------------------------------------------------------------------- /data/gt_text/p088_8425.txt: -------------------------------------------------------------------------------- 1 | How do you think artificial intelligence will impact job markets in the future? -------------------------------------------------------------------------------- /data/gt_text/p088_5892.txt: -------------------------------------------------------------------------------- 1 | The museum curator carefully restored the ancient artifacts found in the excavation. 
-------------------------------------------------------------------------------- /data/gt_text/p088_9993.txt: -------------------------------------------------------------------------------- 1 | The hotel offers a complimentary shuttle service for guests to and from the airport. -------------------------------------------------------------------------------- /data/gt_text/p088_8468.txt: -------------------------------------------------------------------------------- 1 | The more efficiently we manage our resources, the more competitive our business will become. -------------------------------------------------------------------------------- /data/audio/p088_1166.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Berkeley-Speech-Group/DysfluentWFST/HEAD/data/audio/p088_1166.wav -------------------------------------------------------------------------------- /data/audio/p088_2008.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Berkeley-Speech-Group/DysfluentWFST/HEAD/data/audio/p088_2008.wav -------------------------------------------------------------------------------- /data/audio/p088_3232.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Berkeley-Speech-Group/DysfluentWFST/HEAD/data/audio/p088_3232.wav -------------------------------------------------------------------------------- /data/audio/p088_3558.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Berkeley-Speech-Group/DysfluentWFST/HEAD/data/audio/p088_3558.wav -------------------------------------------------------------------------------- /data/audio/p088_3706.wav: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Berkeley-Speech-Group/DysfluentWFST/HEAD/data/audio/p088_3706.wav -------------------------------------------------------------------------------- /data/audio/p088_3966.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Berkeley-Speech-Group/DysfluentWFST/HEAD/data/audio/p088_3966.wav -------------------------------------------------------------------------------- /data/audio/p088_4067.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Berkeley-Speech-Group/DysfluentWFST/HEAD/data/audio/p088_4067.wav -------------------------------------------------------------------------------- /data/audio/p088_4249.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Berkeley-Speech-Group/DysfluentWFST/HEAD/data/audio/p088_4249.wav -------------------------------------------------------------------------------- /data/audio/p088_4616.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Berkeley-Speech-Group/DysfluentWFST/HEAD/data/audio/p088_4616.wav -------------------------------------------------------------------------------- /data/audio/p088_5034.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Berkeley-Speech-Group/DysfluentWFST/HEAD/data/audio/p088_5034.wav -------------------------------------------------------------------------------- /data/audio/p088_5568.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Berkeley-Speech-Group/DysfluentWFST/HEAD/data/audio/p088_5568.wav -------------------------------------------------------------------------------- /data/audio/p088_5892.wav: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Berkeley-Speech-Group/DysfluentWFST/HEAD/data/audio/p088_5892.wav -------------------------------------------------------------------------------- /data/audio/p088_7234.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Berkeley-Speech-Group/DysfluentWFST/HEAD/data/audio/p088_7234.wav -------------------------------------------------------------------------------- /data/audio/p088_7620.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Berkeley-Speech-Group/DysfluentWFST/HEAD/data/audio/p088_7620.wav -------------------------------------------------------------------------------- /data/audio/p088_7953.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Berkeley-Speech-Group/DysfluentWFST/HEAD/data/audio/p088_7953.wav -------------------------------------------------------------------------------- /data/audio/p088_8331.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Berkeley-Speech-Group/DysfluentWFST/HEAD/data/audio/p088_8331.wav -------------------------------------------------------------------------------- /data/audio/p088_8425.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Berkeley-Speech-Group/DysfluentWFST/HEAD/data/audio/p088_8425.wav -------------------------------------------------------------------------------- /data/audio/p088_8468.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Berkeley-Speech-Group/DysfluentWFST/HEAD/data/audio/p088_8468.wav -------------------------------------------------------------------------------- 
/data/audio/p088_9554.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Berkeley-Speech-Group/DysfluentWFST/HEAD/data/audio/p088_9554.wav -------------------------------------------------------------------------------- /data/audio/p088_9993.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Berkeley-Speech-Group/DysfluentWFST/HEAD/data/audio/p088_9993.wav -------------------------------------------------------------------------------- /resources/decoder-new.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Berkeley-Speech-Group/DysfluentWFST/HEAD/resources/decoder-new.png -------------------------------------------------------------------------------- /utils/rule_sim_matrix.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Berkeley-Speech-Group/DysfluentWFST/HEAD/utils/rule_sim_matrix.npy -------------------------------------------------------------------------------- /utils/__pycache__/wper.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Berkeley-Speech-Group/DysfluentWFST/HEAD/utils/__pycache__/wper.cpython-39.pyc -------------------------------------------------------------------------------- /utils/__pycache__/decoder.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Berkeley-Speech-Group/DysfluentWFST/HEAD/utils/__pycache__/decoder.cpython-39.pyc -------------------------------------------------------------------------------- /config/ipa2cmu.json: -------------------------------------------------------------------------------- 1 | { 2 | "a": "AA", 3 | "b": "B", 4 | "d": "D", 5 | "e": "EY", 6 | "f": "F", 7 | "h": "HH", 8 | 
"i": "IY", 9 | "j": "Y", 10 | "k": "K", 11 | "l": "L", 12 | "m": "M", 13 | "n": "N", 14 | "o": "OW", 15 | "p": "P", 16 | "r": "R", 17 | "s": "S", 18 | "t": "T", 19 | "u": "UW", 20 | "v": "V", 21 | "w": "W", 22 | "z": "Z", 23 | "æ": "AE", 24 | "ð": "DH", 25 | "ŋ": "NG", 26 | "ɑ": "AA", 27 | "ɔ": "AO", 28 | "ə": "AH", 29 | "ɚ": "ER", 30 | "ɛ": "EH", 31 | "ɜ": "ER", 32 | "ɡ": "G", 33 | "ɪ": "IH", 34 | "ɫ": "L", 35 | "ɹ": "R", 36 | "ɾ": "DX", 37 | "ʃ": "SH CH", 38 | "ʊ": "UH", 39 | "ʌ": "AH", 40 | "ʒ": "ZH JH", 41 | "θ": "TH", 42 | "ᵻ": "IH", 43 | "ɻ": "R", 44 | "ɒ": "AA", 45 | "ɔɪ": "OY", 46 | "ɐ": "AH", 47 | "aɪ": "AY", 48 | "aʊ": "AW" 49 | } 50 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Dysfluent WFST: A Framework for Zero-Shot Speech Dysfluency Transcription and Detection 👋 2 | 3 | 17 | 18 | Accepted by [Interspeech 2025](https://www.interspeech2025.org/home). [Paper](https://arxiv.org/abs/2505.16351) available. 19 | 20 | Basic workflow: 21 | 22 | ![workflow](resources/decoder-new.png) 23 | 24 | For inference, please check `main.ipynb` and `data` for example. 25 | 26 | For the calculation of Weight PER, please check `./utils/wper.py`, you can get similarity matrix there as well. 
import numpy as np


class W_PER:
    """Weighted Phoneme Error Rate (W-PER).

    Standard PER charges every substitution a cost of 1; W-PER instead
    charges ``1 - similarity(ref_phone, hyp_phone)`` using a precomputed
    phoneme-similarity matrix, so confusions between acoustically close
    phonemes are penalized less than arbitrary substitutions.
    """

    def __init__(self):
        # NOTE(review): the path is resolved against the current working
        # directory, so this only loads when instantiated from utils/ —
        # confirm whether it should be made relative to __file__ instead.
        self.sim_matrix = np.load('./rule_sim_matrix.npy')
        # CMU phoneme -> row/column index into the similarity matrix.
        # "|" (index 0) and "-" (index 40) are separator/blank tokens.
        self.phn2idx = {
            "|": 0, "OW": 1, "UW": 2, "EY": 3, "AW": 4, "AH": 5, "AO": 6, "AY": 7, "EH": 8, "K": 9,
            "NG": 10, "F": 11, "JH": 12, "M": 13, "CH": 14, "IH": 15, "UH": 16, "HH": 17, "L": 18,
            "AA": 19, "R": 20, "TH": 21, "AE": 22, "D": 23, "Z": 24, "OY": 25, "DH": 26, "IY": 27, "B": 28, "W": 29, "S": 30,
            "T": 31, "SH": 32, "ZH": 33, "ER": 34, "V": 35, "Y": 36, "N": 37, "G": 38, "P": 39, "-": 40
        }

    def weight_per(self, GT_list, hypo_phn_list):
        """Compute the weighted PER between two phoneme sequences.

        Args:
            GT_list: reference (ground-truth) phonemes; each symbol must
                be a key of ``self.phn2idx``.
            hypo_phn_list: hypothesis phonemes, same symbol set.

        Returns:
            The similarity-weighted edit distance normalized by the
            reference length. Returns 0.0 when both sequences are empty;
            for an empty reference with a non-empty hypothesis, the
            unnormalized insertion cost ``m`` is returned (the original
            code raised ZeroDivisionError in both cases).

        Raises:
            KeyError: if a phoneme is not present in ``self.phn2idx``.
        """
        sim_matrix = self.sim_matrix
        phn2idx = self.phn2idx
        n = len(GT_list)
        m = len(hypo_phn_list)

        # Guard the empty-reference case: dp[n][m] / n would divide by 0.
        if n == 0:
            return float(m)

        insertion_cost = 1
        deletion_cost = 1

        # dp[i][j] = min cost of aligning GT_list[:i] with hypo_phn_list[:j]
        # (Wagner-Fischer dynamic program with weighted substitutions).
        dp = np.zeros((n + 1, m + 1))
        for i in range(1, n + 1):
            dp[i][0] = i * deletion_cost
        for j in range(1, m + 1):
            dp[0][j] = j * insertion_cost

        for i in range(1, n + 1):
            for j in range(1, m + 1):
                ref_idx = phn2idx[GT_list[i - 1]]
                hyp_idx = phn2idx[hypo_phn_list[j - 1]]
                # Similar phonemes substitute cheaply (cost -> 0 as sim -> 1).
                substitution_cost = 1 - sim_matrix[ref_idx][hyp_idx]

                dp[i][j] = min(
                    dp[i - 1][j - 1] + substitution_cost,
                    dp[i - 1][j] + deletion_cost,
                    dp[i][j - 1] + insertion_cost,
                )

        return dp[n][m] / n
"\u00f0", "R", "\u0261", "\u025c", "g", "u\u02d0", "\u0254\u02d0", "a\u026a", "\u028c", "x", "a\u02d0", "\u0259\u028a", "\u0251\u02d0", "\u025c\u02d0", "n\u0329", "\u0251", "e\u02d0", "\u0272", "\u0292", "d\u0292", "l\u0329", "\u0251\u0303", "ai", "ts", "\u03b2", "\u0153", "o\u02d0", "a\u028a", "\u0254\u0303", "\u028f", "\u026a\u0259", "\u025b\u0259", "\u0265", "\u0282", "\u0255", "au", "\u0268", "c", "\u026f", "\u026b", "\u0294", "\u2019", "\u025b\u0303", "\u0288\u0282", "t\u0255", "\u00f8", "l\u02b2", "\u028e", "\u028a\u0259", "t\u02b2", "r\u02b2", "\u0254\u026a", "n\u02b2", "s\u02b2", "t\u02b0", "\u0295", "t\u0255\u02b0", "\u0288\u0282\u02b0", "\u0263", "\u0254y", "d\u02b2", "k\u02b0", "q", "\u0290", "\u0279", "\u0255\u02b2", "v\u02b2", "\u025d", "m\u02b2", "\u0127", "t\u0283", "i\u02b2", "ts\u02b0", "\u0290\u02b2", "\u0264", "\u025b\u02d0", "\u025f", "p\u02b0", "\u0266", "\u028b", "\u0288", "p\u02b2", "b\u02b2", "r\u031d", "\u025a", "a\u02b2", "s\u02e4", "o\u02b2", "\u0250\u0303", "t\u02e4", "\u0282\u02b2", "\u0289", "\u026d", "\u00f8\u02d0", "y\u02d0", "pf", "d\u02e4", "\u02b2", "\u0268\u02b2", "\u0271", "e\u02b2", "\u00e7", "\u0289\u02d0", "\u0273", "\u027d", "\u00f9", "\u0251\u0303\u02d0", "\u00e6\u02b2", "k\u02b2", "\u0261\u02b2", "\u00f0\u02e4", "f\u02b2", "y\u02b2", "\u0267", "\u02d0", "\u0261j", "z\u02b2", "\u00e1", "n\u032a", "\u0254\u0303\u02d0", "\u0251\u02d0\u030c", "\u00ed", "\u0275", "\u00fc", "a\u0302", "\u00f6", "\u00f3", "\u0252\u0303", "\u0283\u02b2", "r\u030c", "\u0153\u0303", "\u014d", "t\u032a", "\u026a\u030c", "\u00eb", "i\u02d0\u0302", "o\u02d0\u0302", "u\u02d0\u030c", "\u0292\u02b2", "\u026a\u0302", "\u00f5", "e\u0303", "oi", "t\u0283\u02b2", "\u00e6\u02d0\u030c", "d\u032a", "e\u030c", "e\u02d0\u0302", "o\u02d0\u030c", "\u0307", "i\u02d0\u030c", "ts\u02b2", "\u028a\u0302", "n\u030c", "\u00e3", "o\u0303", "\u025b\u0302", "\u00e6\u02d0", "\u028a\u030c", "b\u02b0", "\u029d", "t\u0320\u0283", "\u00e6e\u032f", "u\u02d0\u0302", "\u2205", 
"\u0254\u030c", "\u028a\u032f", "d\u02b0", "x\u02b2", "e\u02d0\u030c", "d\u0292\u02b2", "m\u02d0", "\u0268\u02d0", "n\u02b2\u030c", "r\u02b2\u030c", "\u0254\u0302", "b\u02b1", "\u0261\u02b0", "\u00e2", "\u02b0", "\u0283\u02b0", "\u00e6\u0303", "\u026b\u030c", "\u00e6\u0303\u02d0", "\u028a\u0303", "\u0292\u02b1", "a\u0303", "\u0256", "l\u02b2\u030c", "z\u02d0", "\u02b1", "l\u02d0", "n\u02d0", "ou", "l\u02b0", "\u0251\u02d0\u0302", "\u0288\u02b0", "g\u02b1", "d\u02d0", "\u0263\u02b2", "j\u02b2", "m\u030c", "m\u02b2\u030c", "dz", "u\u0303", "r\u02d0", "oi\u032f", "m\u0329", "\u028ai", "2", "\u027d\u02b1", "\u025f\u02d0", ":", "ts\u02d0", "j\u02d0", "i\u030c", "\u028ci", "\u06f7", "3", "t\u02d0", "\u0283\u02d0", "\u00e6\u02d0\u0302", "\u014b\u0329", "\u0256\u02b1", "g\u02b0", "4", "h\u02d0", "dz\u02d0", "\u0272\u02d0", "t\u0283\u02d0", "\u0289\u02b2"] -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # This file may be used to create an environment using: 2 | # $ conda create --name --file 3 | # platform: linux-64 4 | # created-by: conda 24.11.1 5 | _libgcc_mutex=0.1=main 6 | _openmp_mutex=5.1=1_gnu 7 | accelerate=1.3.0=pypi_0 8 | aiohappyeyeballs=2.4.4=pypi_0 9 | aiohttp=3.11.11=pypi_0 10 | aiosignal=1.3.2=pypi_0 11 | antlr4-python3-runtime=4.9.3=pypi_0 12 | anyio=4.8.0=pypi_0 13 | argon2-cffi=23.1.0=pypi_0 14 | argon2-cffi-bindings=21.2.0=pypi_0 15 | arrow=1.3.0=pypi_0 16 | asteroid-filterbanks=0.4.0=pypi_0 17 | asttokens=3.0.0=pypi_0 18 | async-lru=2.0.4=pypi_0 19 | async-timeout=5.0.1=pypi_0 20 | attrs=24.3.0=pypi_0 21 | audioread=3.0.1=pypi_0 22 | babel=2.16.0=pypi_0 23 | beautifulsoup4=4.12.3=pypi_0 24 | bibtexparser=2.0.0b8=pypi_0 25 | blas=1.0=mkl 26 | bleach=6.2.0=pypi_0 27 | blessed=1.20.0=pypi_0 28 | brotli-python=1.0.9=py39h6a678d5_9 29 | bzip2=1.0.8=h5eee18b_6 30 | ca-certificates=2024.12.31=h06a4308_0 31 | 
certifi=2024.12.14=py39h06a4308_0 32 | cffi=1.17.1=pypi_0 33 | charset-normalizer=3.3.2=pyhd3eb1b0_0 34 | ci-sdr=0.0.2=pypi_0 35 | click=8.1.8=pypi_0 36 | clldutils=3.24.0=pypi_0 37 | cmudict=1.0.32=pypi_0 38 | colorama=0.4.6=pypi_0 39 | colorlog=6.9.0=pypi_0 40 | comm=0.2.2=pypi_0 41 | configargparse=1.7=pypi_0 42 | contourpy=1.3.0=pypi_0 43 | csvw=3.5.1=pypi_0 44 | ctc-segmentation=1.7.4=pypi_0 45 | cuda-cudart=12.4.127=0 46 | cuda-cupti=12.4.127=0 47 | cuda-libraries=12.4.1=0 48 | cuda-nvcc=12.4.99=0 49 | cuda-nvrtc=12.4.127=0 50 | cuda-nvtx=12.4.127=0 51 | cuda-opencl=12.6.77=0 52 | cuda-runtime=12.4.1=0 53 | cuda-version=12.6=3 54 | cycler=0.12.1=pypi_0 55 | cython=3.0.11=pypi_0 56 | datasets=3.2.0=pypi_0 57 | debugpy=1.8.12=pypi_0 58 | decorator=5.1.1=pypi_0 59 | defusedxml=0.7.1=pypi_0 60 | dill=0.3.8=pypi_0 61 | distance=0.1.3=pypi_0 62 | dlinfo=2.0.0=pypi_0 63 | editdistance=0.8.1=pypi_0 64 | einops=0.8.0=pypi_0 65 | espnet=202412=pypi_0 66 | espnet-tts-frontend=0.0.3=pypi_0 67 | et-xmlfile=2.0.0=pypi_0 68 | exceptiongroup=1.2.2=pypi_0 69 | executing=2.1.0=pypi_0 70 | fast-bss-eval=0.1.3=pypi_0 71 | fastjsonschema=2.21.1=pypi_0 72 | ffmpeg=4.3=hf484d3e_0 73 | filelock=3.13.1=py39h06a4308_0 74 | flash-attn=2.7.4.post1=pypi_0 75 | fonttools=4.55.3=pypi_0 76 | fqdn=1.5.1=pypi_0 77 | freetype=2.12.1=h4a9f257_0 78 | frozenlist=1.5.0=pypi_0 79 | fsspec=2024.9.0=pypi_0 80 | g2p-en=2.1.0=pypi_0 81 | gdown=5.2.0=pypi_0 82 | giflib=5.2.2=h5eee18b_0 83 | gmp=6.2.1=h295c915_3 84 | gmpy2=2.1.2=py39heeb90bb_0 85 | gnutls=3.6.15=he1e5248_0 86 | gpustat=1.1.1=pypi_0 87 | graphviz=0.20.3=pypi_0 88 | h11=0.14.0=pypi_0 89 | h5py=3.12.1=pypi_0 90 | httpcore=1.0.7=pypi_0 91 | httpx=0.28.1=pypi_0 92 | huggingface-hub=0.27.1=pypi_0 93 | humanfriendly=10.0=pypi_0 94 | hydra-core=1.3.2=pypi_0 95 | idna=3.7=py39h06a4308_0 96 | importlib-metadata=4.13.0=pypi_0 97 | importlib-resources=6.5.2=pypi_0 98 | inflect=7.5.0=pypi_0 99 | intel-openmp=2023.1.0=hdb19cb5_46306 100 | 
ipykernel=6.29.5=pypi_0 101 | ipython=8.18.1=pypi_0 102 | ipywidgets=8.1.5=pypi_0 103 | isodate=0.7.2=pypi_0 104 | isoduration=20.11.0=pypi_0 105 | jaconv=0.4.0=pypi_0 106 | jamo=0.4.1=pypi_0 107 | jedi=0.19.2=pypi_0 108 | jinja2=3.1.4=py39h06a4308_1 109 | jiwer=3.0.5=pypi_0 110 | joblib=1.4.2=pypi_0 111 | jpeg=9e=h5eee18b_3 112 | json5=0.10.0=pypi_0 113 | jsonpointer=3.0.0=pypi_0 114 | jsonschema=4.23.0=pypi_0 115 | jsonschema-specifications=2024.10.1=pypi_0 116 | jupyter=1.1.1=pypi_0 117 | jupyter-client=8.6.3=pypi_0 118 | jupyter-console=6.6.3=pypi_0 119 | jupyter-core=5.7.2=pypi_0 120 | jupyter-events=0.11.0=pypi_0 121 | jupyter-lsp=2.2.5=pypi_0 122 | jupyter-server=2.15.0=pypi_0 123 | jupyter-server-terminals=0.5.3=pypi_0 124 | jupyterlab=4.3.4=pypi_0 125 | jupyterlab-pygments=0.3.0=pypi_0 126 | jupyterlab-server=2.27.3=pypi_0 127 | jupyterlab-widgets=3.0.13=pypi_0 128 | k2=1.24.4.dev20241127+cuda12.4.torch2.5.1=pypi_0 129 | kaldiio=2.18.0=pypi_0 130 | kiwisolver=1.4.7=pypi_0 131 | lame=3.100=h7b6447c_0 132 | language-tags=1.2.0=pypi_0 133 | lazy-loader=0.4=pypi_0 134 | lcms2=2.16=hb9589c4_0 135 | ld_impl_linux-64=2.40=h12ee557_0 136 | lerc=4.0.0=h6a678d5_0 137 | libcublas=12.4.5.8=0 138 | libcufft=11.2.1.3=0 139 | libcufile=1.11.1.6=0 140 | libcurand=10.3.7.77=0 141 | libcusolver=11.6.1.9=0 142 | libcusparse=12.3.1.170=0 143 | libdeflate=1.22=h5eee18b_0 144 | libffi=3.4.4=h6a678d5_1 145 | libgcc-ng=11.2.0=h1234567_1 146 | libgomp=11.2.0=h1234567_1 147 | libiconv=1.16=h5eee18b_3 148 | libidn2=2.3.4=h5eee18b_0 149 | libjpeg-turbo=2.0.0=h9bf148f_0 150 | libnpp=12.2.5.30=0 151 | libnvfatbin=12.6.77=0 152 | libnvjitlink=12.4.127=0 153 | libnvjpeg=12.3.1.117=0 154 | libpng=1.6.39=h5eee18b_0 155 | librosa=0.9.2=pypi_0 156 | libstdcxx-ng=11.2.0=h1234567_1 157 | libtasn1=4.19.0=h5eee18b_0 158 | libtiff=4.5.1=hffd6297_1 159 | libunistring=0.9.10=h27cfd23_0 160 | libwebp=1.3.2=h11a3e52_0 161 | libwebp-base=1.3.2=h5eee18b_1 162 | llvm-openmp=14.0.6=h9e868ea_0 163 | 
llvmlite=0.43.0=pypi_0 164 | loralib=0.1.2=pypi_0 165 | lxml=5.3.0=pypi_0 166 | lz4-c=1.9.4=h6a678d5_1 167 | markdown=3.7=pypi_0 168 | markupsafe=2.1.3=py39h5eee18b_1 169 | matplotlib=3.9.4=pypi_0 170 | matplotlib-inline=0.1.7=pypi_0 171 | mistune=3.1.0=pypi_0 172 | mkl=2023.1.0=h213fc3f_46344 173 | mkl-service=2.4.0=py39h5eee18b_2 174 | mkl_fft=1.3.11=py39h5eee18b_0 175 | mkl_random=1.2.8=py39h1128e8f_0 176 | more-itertools=10.6.0=pypi_0 177 | mpc=1.1.0=h10f8cd9_1 178 | mpfr=4.0.2=hb69a4c5_1 179 | mpmath=1.3.0=py39h06a4308_0 180 | msgpack=1.1.0=pypi_0 181 | multidict=6.1.0=pypi_0 182 | multiprocess=0.70.16=pypi_0 183 | nbclient=0.10.2=pypi_0 184 | nbconvert=7.16.5=pypi_0 185 | nbformat=5.10.4=pypi_0 186 | ncurses=6.4=h6a678d5_0 187 | nest-asyncio=1.6.0=pypi_0 188 | nettle=3.7.3=hbbd107a_1 189 | networkx=3.2.1=py39h06a4308_0 190 | nltk=3.9.1=pypi_0 191 | notebook=7.3.2=pypi_0 192 | notebook-shim=0.2.4=pypi_0 193 | numba=0.60.0=pypi_0 194 | numpy=1.23.5=pypi_0 195 | nvidia-ml-py=12.560.30=pypi_0 196 | omegaconf=2.3.0=pypi_0 197 | openh264=2.1.1=h4ff587b_0 198 | openjpeg=2.5.2=he7f1fd0_0 199 | openpyxl=3.1.5=pypi_0 200 | openssl=3.0.15=h5eee18b_0 201 | opt-einsum=3.4.0=pypi_0 202 | overrides=7.7.0=pypi_0 203 | packaging=24.2=pypi_0 204 | pandas=2.2.3=pypi_0 205 | pandocfilters=1.5.1=pypi_0 206 | parso=0.8.4=pypi_0 207 | pexpect=4.9.0=pypi_0 208 | phonemizer=3.3.0=pypi_0 209 | pillow=11.0.0=py39hcea889d_1 210 | pip=24.2=py39h06a4308_0 211 | platformdirs=4.3.6=pypi_0 212 | pooch=1.8.2=pypi_0 213 | praatio=6.2.0=pypi_0 214 | prometheus-client=0.21.1=pypi_0 215 | prompt-toolkit=3.0.48=pypi_0 216 | propcache=0.2.1=pypi_0 217 | protobuf=5.29.3=pypi_0 218 | psutil=6.1.1=pypi_0 219 | ptyprocess=0.7.0=pypi_0 220 | pure-eval=0.2.3=pypi_0 221 | pyarrow=19.0.0=pypi_0 222 | pycparser=2.22=pypi_0 223 | pydub=0.25.1=pypi_0 224 | pygments=2.19.1=pypi_0 225 | pylatexenc=2.10=pypi_0 226 | pyparsing=3.2.1=pypi_0 227 | pypinyin=0.44.0=pypi_0 228 | pysocks=1.7.1=py39h06a4308_0 229 | 
python=3.9.21=he870216_1 230 | python-dateutil=2.9.0.post0=pypi_0 231 | python-json-logger=3.2.1=pypi_0 232 | pytorch=2.5.1=py3.9_cuda12.4_cudnn9.1.0_0 233 | pytorch-cuda=12.4=hc786d27_7 234 | pytorch-mutex=1.0=cuda 235 | pytz=2024.2=pypi_0 236 | pyworld=0.3.5=pypi_0 237 | pyyaml=6.0.2=py39h5eee18b_0 238 | pyzmq=26.2.0=pypi_0 239 | rapidfuzz=3.11.0=pypi_0 240 | rdflib=7.1.3=pypi_0 241 | readline=8.2=h5eee18b_0 242 | referencing=0.36.1=pypi_0 243 | regex=2024.11.6=pypi_0 244 | requests=2.32.3=py39h06a4308_1 245 | resampy=0.4.3=pypi_0 246 | rfc3339-validator=0.1.4=pypi_0 247 | rfc3986=1.5.0=pypi_0 248 | rfc3986-validator=0.1.1=pypi_0 249 | rpds-py=0.22.3=pypi_0 250 | s3prl=0.4.17=pypi_0 251 | safetensors=0.5.2=pypi_0 252 | scikit-learn=1.6.1=pypi_0 253 | scipy=1.13.1=pypi_0 254 | seaborn=0.13.2=pypi_0 255 | segments=2.2.1=pypi_0 256 | send2trash=1.8.3=pypi_0 257 | sentencepiece=0.1.97=pypi_0 258 | setuptools=73.0.1=pypi_0 259 | six=1.17.0=pypi_0 260 | sniffio=1.3.1=pypi_0 261 | soundfile=0.13.0=pypi_0 262 | soupsieve=2.6=pypi_0 263 | soxr=0.5.0.post1=pypi_0 264 | sqlite=3.45.3=h5eee18b_0 265 | stack-data=0.6.3=pypi_0 266 | sympy=1.13.1=pypi_0 267 | tabulate=0.9.0=pypi_0 268 | tbb=2021.8.0=hdb19cb5_0 269 | tensorboardx=2.6.2.2=pypi_0 270 | terminado=0.18.1=pypi_0 271 | threadpoolctl=3.5.0=pypi_0 272 | tinycss2=1.4.0=pypi_0 273 | tk=8.6.14=h39e8969_0 274 | tokenizers=0.21.0=pypi_0 275 | tomli=2.2.1=pypi_0 276 | torch-complex=0.4.4=pypi_0 277 | torchaudio=2.5.1=py39_cu124 278 | torchcrepe=0.0.23=pypi_0 279 | torchtriton=3.1.0=py39 280 | torchvision=0.20.1=py39_cu124 281 | tornado=6.4.2=pypi_0 282 | tqdm=4.67.1=pypi_0 283 | traitlets=5.14.3=pypi_0 284 | transformers=4.48.0=pypi_0 285 | typeguard=4.4.1=pypi_0 286 | types-python-dateutil=2.9.0.20241206=pypi_0 287 | typing_extensions=4.12.2=py39h06a4308_0 288 | tzdata=2025.1=pypi_0 289 | unidecode=1.3.8=pypi_0 290 | uri-template=1.3.0=pypi_0 291 | uritemplate=4.1.1=pypi_0 292 | urllib3=2.2.3=py39h06a4308_0 293 | 
wcwidth=0.2.13=pypi_0 294 | webcolors=24.11.1=pypi_0 295 | webencodings=0.5.1=pypi_0 296 | websocket-client=1.8.0=pypi_0 297 | wheel=0.44.0=py39h06a4308_0 298 | widgetsnbextension=4.0.13=pypi_0 299 | xxhash=3.5.0=pypi_0 300 | xz=5.4.6=h5eee18b_1 301 | yaml=0.2.5=h7b6447c_0 302 | yarl=1.18.3=pypi_0 303 | zipp=3.21.0=pypi_0 304 | zlib=1.2.13=h5eee18b_1 305 | zstd=1.5.6=hc292b87_0 306 | -------------------------------------------------------------------------------- /utils/decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import k2 3 | import cmudict 4 | import json 5 | from jiwer import wer 6 | import math 7 | import torch.nn.functional as F 8 | import time 9 | import numpy as np 10 | 11 | class WFSTdecoder: 12 | def __init__(self, device: str, phoneme_lexicon: list, is_ipa=False): 13 | self.device = torch.device(device) 14 | self.lexicon = phoneme_lexicon 15 | self.is_ipa = is_ipa 16 | # read from npy file 17 | self.similarity_matrix = np.load('utils/rule_sim_matrix.npy') 18 | self.similarity_matrix = torch.from_numpy(self.similarity_matrix).to(self.device) 19 | self.phn2idx = { 20 | "|": 0, "OW": 1, "UW": 2, "EY": 3, "AW": 4, "AH": 5, "AO": 6, "AY": 7, "EH": 8, "K": 9, 21 | "NG": 10, "F": 11, "JH": 12, "M": 13, "CH": 14, "IH": 15, "UH": 16, "HH": 17, "L": 18, 22 | "AA": 19, "R": 20, "TH": 21, "AE": 22, "D": 23, "Z": 24, "OY": 25, "DH": 26, "IY": 27, "B": 28, "W": 29, "S": 30, 23 | "T": 31, "SH": 32, "ZH": 33, "ER": 34, "V": 35, "Y": 36, "N": 37, "G": 38, "P": 39, "-": 40 24 | } 25 | 26 | def ctc_topo(self, num_phonemes: int): 27 | return k2.ctc_topo(max_token=num_phonemes, modified=False) 28 | 29 | def create_dense_fsa_vec(self, log_probs: torch.Tensor, lengths: torch.Tensor) -> k2.DenseFsaVec: 30 | """ 31 | Create a DenseFsaVec from model outputs (log_probs) and sequence lengths. 
32 | 33 | Args: 34 | log_probs (torch.Tensor): A tensor of shape (B, T, C), where 35 | - B: Batch size, 36 | - T: Number of time steps, 37 | - C: Number of classes (including blank). 38 | lengths (torch.Tensor): A tensor of shape (B,) containing the valid lengths 39 | (number of time steps) for each sample in the batch. 40 | 41 | Returns: 42 | k2.DenseFsaVec: A DenseFsaVec object that represents the dense FSA for each sequence. 43 | 44 | Raises: 45 | ValueError: If the input tensors do not have compatible dimensions. 46 | """ 47 | # Validate the dimensions of log_probs 48 | if log_probs.ndim != 3: 49 | raise ValueError(f"log_probs must be a 3D tensor of shape (B, T, C), but got {log_probs.ndim}D tensor.") 50 | 51 | B, T, C = log_probs.shape 52 | 53 | if lengths.shape[0] != B: 54 | raise ValueError(f"The size of lengths must match the batch size, but got {lengths.shape[0]} and {B}.") 55 | lengths = lengths.to(dtype=torch.int32) 56 | 57 | 58 | log_probs = F.log_softmax(log_probs, dim=-1) 59 | 60 | supervision_segments = [] 61 | for i in range(B): 62 | supervision_segments.append([i, 0, lengths[i].item()]) 63 | supervision_segments = torch.tensor(supervision_segments, dtype=torch.int32) 64 | 65 | dense_fsa_vec = k2.DenseFsaVec(log_probs, supervision_segments) 66 | return dense_fsa_vec 67 | 68 | # def create_emit_graph(self, probNT): 69 | # if probNT.ndim != 2: 70 | # raise ValueError(f"probNT must be a 2D tensor, but got {probNT.ndim}D tensor.") 71 | # lines = [] 72 | # for i in range(probNT.shape[0]): 73 | # for j in range(probNT.shape[1]-1): 74 | # lines.append(f"{i} {i+1} {j} {probNT[i, j]}") 75 | # lines.append(f"{probNT.shape[0]} {probNT.shape[0]+1} {-1} {1}") 76 | # lines.append(f"{probNT.shape[0]+1}") 77 | # return '\n'.join(lines) 78 | 79 | def cmu2ipa(self, phoneme_seq, map='config/ipa2cmu.json'): 80 | map_dict = json.load(open(map)) 81 | ipa_seq = [] 82 | flag = False 83 | for phoneme in phoneme_seq: 84 | if phoneme == '': 85 | print(f"Phoneme {phoneme} not 
found in the CMU dictionary") 86 | continue 87 | flag = False 88 | for k, v in map_dict.items(): 89 | if phoneme in v: 90 | flag = True 91 | if ' ' in k: 92 | k = k.split() 93 | ipa_seq.extend(k) 94 | break 95 | ipa_seq.append(k) 96 | break 97 | if not flag: 98 | raise ValueError(f"Phoneme {phoneme} not found in the map") 99 | return ipa_seq 100 | 101 | # def ipa_to_cmu(self, ipa_list): 102 | # cmu_list = [] 103 | # for ipa in ipa_list: 104 | # if ipa in ipa2cmu: 105 | # cmu_value = ipa2cmu[ipa].split()[0] 106 | # cmu_list.append(cmu_value) 107 | # else: 108 | # continue 109 | # # cmu_list.append(f"") 110 | # return cmu_list 111 | # { "a": "AA", 112 | # "b": "B", 113 | # "d": "D", 114 | # "e": "EY",...} 115 | def ipa2cmu(self, phoneme, map='config/ipa2cmu.json'): 116 | map_dict = json.load(open(map)) 117 | if phoneme in map_dict: 118 | cmu_value = map_dict[phoneme].split()[0] 119 | return cmu_value 120 | else: 121 | raise ValueError(f"Phoneme {phoneme} not found in the map") 122 | 123 | def get_phoneme_sequence(self, ref_text): 124 | phoneme_sequence = [] 125 | ref_text = ref_text.lower() 126 | ref_text = ref_text.replace('.', '').replace(',', '').replace('?', '').replace('!', '') 127 | for word in ref_text.split(): 128 | phonemes = cmudict.dict().get(word.lower(), [''])[0] 129 | if phonemes == '': 130 | if word == 'quivers': 131 | phonemes = 'K W IH V ER S'.split() 132 | else: 133 | print(f"Word {word} not found in the CMU dictionary") 134 | continue 135 | phonemes = [phoneme.rstrip('012') for phoneme in phonemes] 136 | phoneme_sequence.extend(phonemes) 137 | if self.is_ipa: 138 | ipa_sequence = self.cmu2ipa(phoneme_sequence) 139 | else: 140 | ipa_sequence = phoneme_sequence 141 | return ipa_sequence 142 | 143 | def get_phoneme_id(self, phoneme): 144 | if phoneme == '|' or phoneme == '-' or phoneme == '' or phoneme == '' or phoneme == '' or phoneme == '' or phoneme == 'SIL' or phoneme == 'SPN': 145 | return 0 146 | return self.lexicon.index(phoneme) 147 | 148 | 
    def get_phoneme_ids(self, phoneme_sequence):
        """Map each phoneme in the sequence to its lexicon id (0 for specials)."""
        return [self.get_phoneme_id(phoneme) for phoneme in phoneme_sequence]

    def create_fsa_graph(self, phonemes, beta, skip=False, back=True, sub=True):
        """Build an editable reference FSA (k2 text format) over ``phonemes``.

        The "correct" arc i -> i+1 for each phoneme carries weight
        ``alpha = 1 - 10**(-beta)``; the residual ``1 - alpha`` mass is spread
        over optional error arcs — forward skips (``skip``), backward
        repetitions (``back``) and substitutions by acoustically similar
        phonemes (``sub``). Error arcs get freshly appended lexicon tokens as
        output labels so the decoded path records which transition fired.

        NOTE(review): ``''.join([str(i), str(j)])`` makes tokens like '112'
        ambiguous (i=1,j=12 vs i=11,j=2) — confirm the intended token format
        (angle-bracket delimiters may have been lost in export).
        NOTE(review): ``self.lexicon`` grows on every call; repeated decoding
        keeps appending tokens without cleanup.

        Args:
            phonemes: Reference phoneme ids (lexicon indices).
            beta: Controls how much probability mass stays on correct arcs.
            skip: Allow forward skip arcs (max jump of 3 states).
            back: Allow backward repetition arcs (max jump of 2 states).
            sub: Allow substitution arcs via the similarity matrix.

        Returns:
            str: The FSA in k2's text format (final arc + final state line).
        """
        alpha = 1 - 10**(-beta)
        error_score = (1-alpha)
        lines = []
        for i, phone in enumerate(phonemes):
            for j in range(len(phonemes)+1):
                if i == j:
                    continue
                if j == i+1:
                    # Correct transition: consume this phoneme, move one state.
                    lines.append(f"{i} {j} {phone} {phone} {alpha}")
                    if skip:
                        mis_token = ''.join([str(i), str(j)])
                        self.lexicon.append(mis_token)
                        mis_id = self.lexicon.index(mis_token)
                        # Gaussian-decayed penalty on the jump distance.
                        lines.append(f"{i} {j} {0} {mis_id} {error_score * math.exp(-(i-j)**2/2)}")
                    continue
                if sub:
                    # phoneme id 2 phoneme
                    phone_text = self.lexicon[phone]
                    if self.is_ipa:
                        phone_text = self.ipa2cmu(phone_text)
                    phoneme_id_sim = self.phn2idx[phone_text]
                    # Top-2 most similar phoneme ids (despite the 'top3' name).
                    top3_sim_id = torch.topk(self.similarity_matrix[phoneme_id_sim], 2).indices
                    # use self.phn2idx to get phoneme, remeber id is the value and we want key
                    for sim_id in top3_sim_id:
                        if sim_id == phoneme_id_sim:
                            continue
                        sim_phoneme = list(self.phn2idx.keys())[sim_id]
                        # print(f"sim_phoneme: {sim_phoneme}")
                        # NOTE(review): the '' comparisons appear to be special
                        # tokens stripped in export — see get_phoneme_id.
                        if sim_phoneme == '|' or sim_phoneme == '-' or sim_phoneme == '' or sim_phoneme == '' or sim_phoneme == '' or sim_phoneme == '' or sim_phoneme == 'SIL' or sim_phoneme == 'SPN':
                            continue
                        if self.is_ipa:
                            sim_phoneme = self.cmu2ipa([sim_phoneme])[0]
                        try:
                            sim_phoneme_id = self.get_phoneme_id(sim_phoneme)
                        except ValueError:
                            print(f"Phoneme {sim_phoneme} not found in the lexicon")
                            continue
                        sub_token = ''.join([str(i), str(j)])
                        self.lexicon.append(sub_token)
                        sub_id = self.lexicon.index(sub_token)
                        # Flat, heavily discounted substitution score.
                        lines.append(f"{i} {j} {0} {sub_id} {error_score/10000}")
                else:
                    if alpha == 1:
                        # No error mass left — nothing more to add for this pair.
                        continue
                if j > i and skip:
                    if j - i > 3:
                        continue
                    mis_token = ''.join([str(i), str(j)])
                    self.lexicon.append(mis_token)
                    mis_id = self.lexicon.index(mis_token)
                    lines.append(f"{i} {j} {0} {mis_id} {error_score * math.exp(-(i-j)**2/2)}")
                    continue
                if j < i and back:
                    if i - j > 2:
                        continue
                    rep_token = ''.join([str(i), str(j)])
                    self.lexicon.append(rep_token)
                    rep_id = self.lexicon.index(rep_token)
                    lines.append(f"{i} {j} {0} {rep_id} {error_score * math.exp(-(i-j)**2/2)}")
                    continue

        # Final arc (label -1) into the superfinal state, then the final state.
        lines.append(f"{len(phonemes)} {len(phonemes)+1} {-1} {-1} {0}")
        lines.append(f"{len(phonemes)+1}")
        return '\n'.join(lines)

    def extract_phoneme_states(self, transition_list):
        """Merge consecutive state-transition marker items in a decoded sequence.

        NOTE(review): the "" marker strings below appear to have lost their
        text (likely an angle-bracket delimiter) during export; as written,
        '"" in item' is always true and str.split('') raises ValueError —
        confirm against the original source before relying on this.
        """
        merged_list = []
        current_merge = None

        for item in transition_list:
            if "" in item:
                if current_merge is None:
                    current_merge = item
                else:
                    # Fuse with the pending marker, keeping the outermost ends.
                    current_merge = current_merge.split("")[0] + "" + item.split("")[-1]
            else:
                if current_merge is not None:
                    merged_list.append(current_merge)
                    current_merge = None
                merged_list.append(item)

        # Handle case where the last item was a merge
        if current_merge is not None:
            merged_list.append(current_merge)

        return merged_list

    def detect_dysfluency(self, phoneme_seq):
        """Classify each decoded phoneme as normal/repetition/insertion/deletion.

        A transition marker item (containing the — apparently stripped —
        delimiter token) jumps ``current_state``; plain phoneme items advance
        one state. Revisiting a start state is a repetition; starting before
        the earliest seen state is an insertion; a gap after the previous end
        implies a deletion of the skipped reference phonemes.

        Returns:
            list[dict]: One record per phoneme with keys "phoneme",
            "start_state", "end_state", "dysfluency_type".
        """
        dysfluency_results = []
        state_history = set()
        prev_end = -1

        clean_states = []
        current_state = 0

        for elem in phoneme_seq:
            # NOTE(review): '' membership/split — stripped delimiter token,
            # see extract_phoneme_states; as written split('') would raise.
            if '' in elem:
                _, j = elem.split('')
                current_state = int(j)
            else:
                start = current_state
                end = start + 1
                clean_states.append((start, end, elem))
                current_state = end

        # Detect dysfluency
        for item in clean_states:
            start, end, phoneme = item
            # Earliest start state seen so far (-1 when none yet).
            min_time = min(state_history) if state_history else -1

            if start in state_history:
                dysfluency_results.append({
                    "phoneme": phoneme,
                    "start_state": start,
                    "end_state": end,
                    "dysfluency_type": "repetition"
                })
            # Check for insertion (insertion occurs when start is earlier than any previous time)
            elif start < min_time:
                dysfluency_results.append({
                    "phoneme": phoneme,
                    "start_state": start,
                    "end_state": end,
                    "dysfluency_type": "insertion"
                })
            # Gap since the previous end: report the deletion, then the phoneme.
            elif start > prev_end + 1:
                dysfluency_results.append({
                    "phoneme": "",
                    "start_state": prev_end,
                    "end_state": start,
                    "dysfluency_type": "deletion"
                })
                dysfluency_results.append({
                    "phoneme": phoneme,
                    "start_state": start,
                    "end_state": end,
                    "dysfluency_type": "normal"
                })
            else:
                dysfluency_results.append({
                    "phoneme": phoneme,
                    "start_state": start,
                    "end_state": end,
                    "dysfluency_type": "normal"
                })

            state_history.add(start)
            prev_end = end

        return dysfluency_results

    def _deduplicate_and_filter(self, phoneme_list):
        """Deduplicate consecutive phonemes and filter out unwanted tokens.

        NOTE(review): 'STL' here looks like a typo for 'SIL' (used everywhere
        else), and the '' entries appear to be stripped special tokens —
        confirm against the original source.
        """
        filtered_list = []
        prev_label = None
        for phoneme in phoneme_list:
            if phoneme != prev_label and phoneme not in ['|', '-', "", "", "", "", 'STL', 'SPN', '', '']:
                filtered_list.append(phoneme)
                prev_label = phoneme
        return filtered_list

    def _build_lattice(self, emission, length, ref_text,
                       beta, back, skip, num_beam):
        """Create the k2 lattice for one utterance and return it together
        with the reference phoneme sequence."""
        emission = emission.to(self.device)

        # 1. Dense FSA from model posteriors
        dense_fsa = self.create_dense_fsa_vec(
            emission.unsqueeze(0),
            torch.tensor([length], dtype=torch.int32)
        ).to(self.device)

        # 2. Reference text → phoneme IDs → FSA
        phoneme_sequence = self.get_phoneme_sequence(ref_text)
        phoneme_ids = self.get_phoneme_ids(phoneme_sequence)
        ref_fsa_str = self.create_fsa_graph(
            phoneme_ids, beta=beta, skip=skip, back=back
        )
        ref_fsa = k2.Fsa.from_str(ref_fsa_str, acceptor=False)
        ref_fsa = k2.arc_sort(ref_fsa).to(self.device)

        # 3. CTC topology (lexicon size may have grown in create_fsa_graph)
        ctc_fsa = self.ctc_topo(len(self.lexicon))
        ctc_fsa = k2.arc_sort(ctc_fsa).to(self.device)

        # 4. Compose & intersect to obtain the lattice
        # (composition is done on CPU, then moved back to the device)
        composed = k2.compose(
            ctc_fsa.to("cpu"),
            ref_fsa.to("cpu"),
            treat_epsilons_specially=True
        ).to(self.device)
        lattice = k2.intersect_dense(
            composed, dense_fsa, output_beam=num_beam
        ).to(self.device)

        return lattice, phoneme_sequence

    def _compute_loss(self, lattice):
        """Negative mean total log-score of the lattice (training objective)."""
        loss = lattice.get_tot_scores(
            log_semiring=True, use_double_scores=True
        )
        return -loss.mean()

    def _decode_lattice(self, lattice):
        """Return dysfluency annotations derived from the shortest path."""
        shortest = k2.shortest_path(lattice, use_double_scores=True)
        # Drop the final -1 label, then map aux label ids back to lexicon tokens.
        phoneme_seq = [self.lexicon[i] for i in shortest[0].aux_labels[:-1]]

        phoneme_seq = self._deduplicate_and_filter(phoneme_seq)
        state_seq = self.extract_phoneme_states(phoneme_seq)
        return self.detect_dysfluency(state_seq)

    def decode(self, batch, beta, num_beam=2, back=True, skip=False, train=False):
        """
        Decode a batch of sequences using the WFST decoder.
        Args:
            batch (dict): A dictionary containing the following
                - "id": List of IDs in the batch.
                - "tensor": Batched emission tensors (padded).
                - "ref_text": List of reference texts.
                - "lengths": Original lengths of each sequence in the batch.
            beta (float): Parameter for the FSA graph.
            num_beam (int): Number of beams for decoding.
            back (bool): Whether to allow back transitions in the FSA graph.
            skip (bool): Whether to allow skip transitions in the FSA graph.
            train (bool): Whether the model is in training mode.
        Returns:
            results (list): A list of dictionaries containing the following
                - "id": Sample ID.
                - "ref_phonemes": Reference phoneme sequence.
                - "dys_detect": List of detected dysfluencies.
                - "decode_phonemes": List of decoded phonemes.

        NOTE(review): with ``train=True`` the loss of only the FIRST sample is
        computed and returned immediately — the rest of the batch is never
        processed. Confirm whether per-batch loss was intended.
        """
        ids = batch["id"]
        emissions = batch["tensor"]
        ref_texts = batch["ref_text"]
        lengths = batch["lengths"]

        results = []
        for idx, sample_id in enumerate(ids):
            # Trim padding before building the lattice.
            emission = emissions[idx, : lengths[idx]]
            ref_text = ref_texts[idx]

            lattice, ref_phonemes = self._build_lattice(
                emission, lengths[idx], ref_text,
                beta, back, skip, num_beam
            )

            if train:
                loss = self._compute_loss(lattice)
                # Clean up GPU memory before returning
                del lattice
                torch.cuda.empty_cache()
                return loss

            dys_info = self._decode_lattice(lattice)
            results.append({
                "id": sample_id,
                "ref_phonemes": ref_phonemes,
                "dys_detect": dys_info,
                "decode_phonemes": [item["phoneme"] for item in dys_info],
            })

            del lattice
            torch.cuda.empty_cache()

        return results