├── .github └── ISSUE_TEMPLATE │ ├── ask_questions.md │ ├── bug_report.md │ ├── config.yaml │ └── error_docs.md ├── .gitignore ├── LICENSE ├── README.md ├── README_ja.md ├── README_zh.md ├── api.py ├── data ├── train_example.jsonl └── val_example.jsonl ├── deepspeed_conf └── ds_stage1.json ├── demo1.py ├── demo2.py ├── demo_libtorch.py ├── demo_onnx.py ├── export.py ├── export_meta.py ├── finetune.sh ├── image ├── aed_figure.png ├── asr_results.png ├── asr_results1.png ├── asr_results2.png ├── dingding_funasr.png ├── dingding_sv.png ├── inference.png ├── sensevoice.png ├── sensevoice2.png ├── ser_figure.png ├── ser_table.png ├── webui.png └── wechat.png ├── model.py ├── requirements.txt ├── utils ├── __init__.py ├── ctc_alignment.py ├── export_utils.py ├── frontend.py ├── infer_utils.py └── model_bin.py └── webui.py /.github/ISSUE_TEMPLATE/ask_questions.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: ❓ Questions/Help 3 | about: If you have questions, please first search existing issues and docs 4 | labels: 'question, needs triage' 5 | --- 6 | 7 | Notice: In order to resolve issues more efficiently, please raise issue following the template. 8 | (注意:为了更加高效率解决您遇到的问题,请按照模板提问,补充细节) 9 | 10 | ## ❓ Questions and Help 11 | 12 | 13 | ### Before asking: 14 | 1. search the issues. 15 | 2. search the docs. 16 | 17 | 18 | 19 | #### What is your question? 20 | 21 | #### Code 22 | 23 | 24 | 25 | #### What have you tried? 26 | 27 | #### What's your environment? 28 | 29 | - OS (e.g., Linux): 30 | - FunASR Version (e.g., 1.0.0): 31 | - ModelScope Version (e.g., 1.11.0): 32 | - PyTorch Version (e.g., 2.0.0): 33 | - How you installed funasr (`pip`, source): 34 | - Python version: 35 | - GPU (e.g., V100M32) 36 | - CUDA/cuDNN version (e.g., cuda11.7): 37 | - Docker version (e.g., funasr-runtime-sdk-cpu-0.4.1) 38 | - Any other relevant information: -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: 🐛 Bug Report 3 | about: Submit a bug report to help us improve 4 | labels: 'bug, needs triage' 5 | --- 6 | 7 | Notice: In order to resolve issues more efficiently, please raise issue following the template. 8 | (注意:为了更加高效率解决您遇到的问题,请按照模板提问,补充细节) 9 | 10 | ## 🐛 Bug 11 | 12 | 13 | 14 | ### To Reproduce 15 | 16 | Steps to reproduce the behavior (**always include the command you ran**): 17 | 18 | 1. Run cmd '....' 19 | 2. 
See error 20 | 21 | 22 | 23 | 24 | #### Code sample 25 | 27 | 28 | ### Expected behavior 29 | 30 | 31 | 32 | ### Environment 33 | 34 | - OS (e.g., Linux): 35 | - FunASR Version (e.g., 1.0.0): 36 | - ModelScope Version (e.g., 1.11.0): 37 | - PyTorch Version (e.g., 2.0.0): 38 | - How you installed funasr (`pip`, source): 39 | - Python version: 40 | - GPU (e.g., V100M32) 41 | - CUDA/cuDNN version (e.g., cuda11.7): 42 | - Docker version (e.g., funasr-runtime-sdk-cpu-0.4.1) 43 | - Any other relevant information: 44 | 45 | ### Additional context 46 | 47 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/config.yaml: -------------------------------------------------------------------------------- 1 | blank_issues_enabled: false -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/error_docs.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: 📚 Documentation/Typos 3 | about: Report an issue related to documentation or a typo 4 | labels: 'documentation, needs triage' 5 | --- 6 | 7 | ## 📚 Documentation 8 | 9 | For typos and doc fixes, please go ahead and: 10 | 11 | 1. Create an issue. 12 | 2. Fix the typo. 13 | 3. Submit a PR. 14 | 15 | Thanks! -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | ./__pycache__/ 3 | */__pycache__/ 4 | */*/__pycache__/ 5 | */*/*/__pycache__/ 6 | .DS_Store 7 | init_model/ 8 | *.tar.gz 9 | test_local/ 10 | RapidASR 11 | export/* 12 | *.pyc 13 | .eggs 14 | MaaS-lib 15 | .gitignore 16 | .egg* 17 | dist 18 | build 19 | funasr.egg-info 20 | docs/_build 21 | modelscope 22 | samples 23 | .ipynb_checkpoints 24 | outputs* 25 | emotion2vec* 26 | GPT-SoVITS* 27 | modelscope_models 28 | examples/aishell/llm_asr_nar/* 29 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Ref to https://github.com/modelscope/FunASR?tab=readme-ov-file#license 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ([简体中文](./README_zh.md)|English|[日本語](./README_ja.md)) 2 | 3 | 4 | # Introduction 5 | 6 | SenseVoice is a speech foundation model with multiple speech understanding capabilities, including automatic speech recognition (ASR), spoken language identification (LID), speech emotion recognition (SER), and audio event detection (AED). 7 | 8 |
9 | 10 |
11 | 12 | [//]: # (
) 13 | 14 |
15 |

 16 | Homepage 17 | | What's New 18 | | Benchmarks 19 | | Install 20 | | Usage 21 | | Community 22 |

23 | 24 | Model Zoo: 25 | [modelscope](https://www.modelscope.cn/models/iic/SenseVoiceSmall), [huggingface](https://huggingface.co/FunAudioLLM/SenseVoiceSmall) 26 | 27 | Online Demo: 28 | [modelscope demo](https://www.modelscope.cn/studios/iic/SenseVoice), [huggingface space](https://huggingface.co/spaces/FunAudioLLM/SenseVoice) 29 | 30 | 31 |
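If you prefer to fetch the model weights ahead of time instead of letting `AutoModel` download them on first use, the checkpoints linked above can be pulled locally. This is a minimal sketch, not an official instruction; it assumes you have `modelscope` (or `huggingface_hub`) installed, and the snapshot goes to each library's default cache directory.

```python
# Minimal sketch: pre-download the SenseVoiceSmall checkpoint.
# Assumes `modelscope` and/or `huggingface_hub` are installed.
from modelscope import snapshot_download

local_dir = snapshot_download("iic/SenseVoiceSmall")
print("ModelScope copy at:", local_dir)

# Alternatively, from the Hugging Face mirror:
from huggingface_hub import snapshot_download as hf_snapshot_download

hf_dir = hf_snapshot_download(repo_id="FunAudioLLM/SenseVoiceSmall")
print("Hugging Face copy at:", hf_dir)
```

The returned directory can then be passed as `model_dir` in the usage examples below.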
 32 |  33 |  34 |  35 | # Highlights 🎯 36 | **SenseVoice** focuses on high-accuracy multilingual speech recognition, speech emotion recognition, and audio event detection. 37 | - **Multilingual Speech Recognition:** Trained on over 400,000 hours of data and supporting more than 50 languages, its recognition performance surpasses that of the Whisper model. 38 | - **Rich transcription:** 39 |   - Possesses excellent emotion recognition capabilities, matching and surpassing the performance of the current best emotion recognition models on test data. 40 |   - Offers sound event detection capabilities, supporting the detection of various common human-computer interaction events such as background music (BGM), applause, laughter, crying, coughing, and sneezing. 41 | - **Efficient Inference:** The SenseVoice-Small model utilizes a non-autoregressive end-to-end framework, leading to exceptionally low inference latency. It requires only 70 ms to process 10 seconds of audio, which is 15 times faster than Whisper-Large. 42 | - **Convenient Finetuning:** Provides convenient finetuning scripts and strategies, allowing users to easily address long-tail sample issues according to their business scenarios. 43 | - **Service Deployment:** Offers a service deployment pipeline that supports multiple concurrent requests, with client-side languages including Python, C++, HTML, Java, and C#, among others. 44 |  45 |  46 | # What's New 🔥 47 | - 2024/11: Added support for timestamps based on CTC alignment. 48 | - 2024/7: Added export features for [ONNX](./demo_onnx.py) and [libtorch](./demo_libtorch.py), as well as Python-version runtimes: [funasr-onnx-0.4.0](https://pypi.org/project/funasr-onnx/), [funasr-torch-0.1.1](https://pypi.org/project/funasr-torch/). 49 | - 2024/7: The [SenseVoice-Small](https://www.modelscope.cn/models/iic/SenseVoiceSmall) voice understanding model was open-sourced. It offers high-precision multilingual speech recognition, emotion recognition, and audio event detection for Mandarin, Cantonese, English, Japanese, and Korean, with exceptionally low inference latency. 50 | - 2024/7: CosyVoice, a model for natural speech generation with multi-language, timbre, and emotion control, was released. CosyVoice excels in multilingual voice generation, zero-shot voice generation, cross-lingual voice cloning, and instruction-following capabilities. See the [CosyVoice repo](https://github.com/FunAudioLLM/CosyVoice) and [CosyVoice space](https://www.modelscope.cn/studios/iic/CosyVoice-300M). 51 | - 2024/7: [FunASR](https://github.com/modelscope/FunASR) is a fundamental speech recognition toolkit that offers a variety of features, including speech recognition (ASR), voice activity detection (VAD), punctuation restoration, language models, speaker verification, speaker diarization, and multi-talker ASR. 52 |  53 |  54 | # Benchmarks 📝 55 |  56 | ## Multilingual Speech Recognition 57 | We compared the multilingual speech recognition performance of SenseVoice and Whisper on open-source benchmark datasets, including AISHELL-1, AISHELL-2, Wenetspeech, LibriSpeech, and Common Voice. For Chinese and Cantonese recognition, the SenseVoice-Small model shows a clear advantage. 58 |  59 |
60 | 61 |
62 | 63 | ## Speech Emotion Recognition 64 | 65 | Due to the current lack of widely-used benchmarks and methods for speech emotion recognition, we conducted evaluations across various metrics on multiple test sets and performed a comprehensive comparison with numerous results from recent benchmarks. The selected test sets encompass data in both Chinese and English, and include multiple styles such as performances, films, and natural conversations. Without finetuning on the target data, SenseVoice was able to achieve and exceed the performance of the current best speech emotion recognition models. 66 | 67 |
68 | 69 |
70 | 71 | Furthermore, we compared multiple open-source speech emotion recognition models on the test sets, and the results indicate that the SenseVoice-Large model achieved the best performance on nearly all datasets, while the SenseVoice-Small model also surpassed other open-source models on the majority of the datasets. 72 | 73 |
74 | 75 |
76 | 77 | ## Audio Event Detection 78 | 79 | Although trained exclusively on speech data, SenseVoice can still function as a standalone event detection model. We compared its performance on the environmental sound classification ESC-50 dataset against the widely used industry models BEATS and PANN. The SenseVoice model achieved commendable results on these tasks. However, due to limitations in training data and methodology, its event classification performance has some gaps compared to specialized AED models. 80 | 81 |
82 | 83 |
84 | 85 | ## Computational Efficiency 86 | 87 | The SenseVoice-Small model deploys a non-autoregressive end-to-end architecture, resulting in extremely low inference latency. With a similar number of parameters to the Whisper-Small model, it infers more than 5 times faster than Whisper-Small and 15 times faster than Whisper-Large. 88 | 89 |
90 | 91 |
92 | 93 | 94 | # Requirements 95 | 96 | ```shell 97 | pip install -r requirements.txt 98 | ``` 99 | 100 | 101 | # Usage 102 | 103 | ## Inference 104 | 105 | Supports input of audio in any format and of any duration. 106 | 107 | ```python 108 | from funasr import AutoModel 109 | from funasr.utils.postprocess_utils import rich_transcription_postprocess 110 | 111 | model_dir = "iic/SenseVoiceSmall" 112 | 113 | 114 | model = AutoModel( 115 | model=model_dir, 116 | trust_remote_code=True, 117 | remote_code="./model.py", 118 | vad_model="fsmn-vad", 119 | vad_kwargs={"max_single_segment_time": 30000}, 120 | device="cuda:0", 121 | ) 122 | 123 | # en 124 | res = model.generate( 125 | input=f"{model.model_path}/example/en.mp3", 126 | cache={}, 127 | language="auto", # "zh", "en", "yue", "ja", "ko", "nospeech" 128 | use_itn=True, 129 | batch_size_s=60, 130 | merge_vad=True, # 131 | merge_length_s=15, 132 | ) 133 | text = rich_transcription_postprocess(res[0]["text"]) 134 | print(text) 135 | ``` 136 | 137 |
Parameter Description (Click to Expand) 138 | 139 | - `model_dir`: The name of the model, or the path to the model on the local disk. 140 | - `trust_remote_code`: 141 | - When `True`, it means that the model's code implementation is loaded from `remote_code`, which specifies the exact location of the `model` code (for example, `model.py` in the current directory). It supports absolute paths, relative paths, and network URLs. 142 | - When `False`, it indicates that the model's code implementation is the integrated version within [FunASR](https://github.com/modelscope/FunASR). At this time, modifications made to `model.py` in the current directory will not be effective, as the version loaded is the internal one from FunASR. For the model code, [click here to view](https://github.com/modelscope/FunASR/tree/main/funasr/models/sense_voice). 143 | - `vad_model`: This indicates the activation of VAD (Voice Activity Detection). The purpose of VAD is to split long audio into shorter clips. In this case, the inference time includes both VAD and SenseVoice total consumption, and represents the end-to-end latency. If you wish to test the SenseVoice model's inference time separately, the VAD model can be disabled. 144 | - `vad_kwargs`: Specifies the configurations for the VAD model. `max_single_segment_time`: denotes the maximum duration for audio segmentation by the `vad_model`, with the unit being milliseconds (ms). 145 | - `use_itn`: Whether the output result includes punctuation and inverse text normalization. 146 | - `batch_size_s`: Indicates the use of dynamic batching, where the total duration of audio in the batch is measured in seconds (s). 147 | - `merge_vad`: Whether to merge short audio fragments segmented by the VAD model, with the merged length being `merge_length_s`, in seconds (s). 148 | - `ban_emo_unk`: Whether to ban the output of the `emo_unk` token. 149 |
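The raw `text` field returned by the model carries special tokens for language, emotion, and event (for example `<|en|>`, `<|NEUTRAL|>`, `<|Speech|>`), which `rich_transcription_postprocess` renders into a readable transcription. If you only want a plain transcript with no markers at all, a small sketch along the lines of `api.py` in this repository (which strips the same `<|...|>` tokens with a regular expression) is:

```python
import re

from funasr.utils.postprocess_utils import rich_transcription_postprocess

# `res` is the list returned by model.generate(...) above.
raw_text = res[0]["text"]                        # e.g. "<|en|><|NEUTRAL|><|Speech|><|woitn|>..." (illustrative)
clean_text = re.sub(r"<\|.*?\|>", "", raw_text)  # drop every <|...|> special token
rich_text = rich_transcription_postprocess(raw_text)

print(clean_text)  # plain transcript
print(rich_text)   # transcript with emotion/event rendering
```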
150 | 151 | If all inputs are short audios (<30s), and batch inference is needed to speed up inference efficiency, the VAD model can be removed, and `batch_size` can be set accordingly. 152 | ```python 153 | model = AutoModel(model=model_dir, trust_remote_code=True, device="cuda:0") 154 | 155 | res = model.generate( 156 | input=f"{model.model_path}/example/en.mp3", 157 | cache={}, 158 | language="zh", # "zh", "en", "yue", "ja", "ko", "nospeech" 159 | use_itn=False, 160 | batch_size=64, 161 | ) 162 | ``` 163 | 164 | For more usage, please refer to [docs](https://github.com/modelscope/FunASR/blob/main/docs/tutorial/README.md) 165 | 166 | ### Inference directly 167 | 168 | Supports input of audio in any format, with an input duration limit of 30 seconds or less. 169 | 170 | ```python 171 | from model import SenseVoiceSmall 172 | from funasr.utils.postprocess_utils import rich_transcription_postprocess 173 | 174 | model_dir = "iic/SenseVoiceSmall" 175 | m, kwargs = SenseVoiceSmall.from_pretrained(model=model_dir, device="cuda:0") 176 | m.eval() 177 | 178 | res = m.inference( 179 | data_in=f"{kwargs['model_path']}/example/en.mp3", 180 | language="auto", # "zh", "en", "yue", "ja", "ko", "nospeech" 181 | use_itn=False, 182 | ban_emo_unk=False, 183 | **kwargs, 184 | ) 185 | 186 | text = rich_transcription_postprocess(res[0][0]["text"]) 187 | print(text) 188 | ``` 189 | 190 | ### Export and Test 191 |
ONNX and Libtorch Export 192 | 193 | #### ONNX 194 | ```python 195 | # pip3 install -U funasr funasr-onnx 196 | from pathlib import Path 197 | from funasr_onnx import SenseVoiceSmall 198 | from funasr_onnx.utils.postprocess_utils import rich_transcription_postprocess 199 | 200 | 201 | model_dir = "iic/SenseVoiceSmall" 202 | 203 | model = SenseVoiceSmall(model_dir, batch_size=10, quantize=True) 204 | 205 | # inference 206 | wav_or_scp = ["{}/.cache/modelscope/hub/{}/example/en.mp3".format(Path.home(), model_dir)] 207 | 208 | res = model(wav_or_scp, language="auto", use_itn=True) 209 | print([rich_transcription_postprocess(i) for i in res]) 210 | ``` 211 | Note: ONNX model is exported to the original model directory. 212 | 213 | #### Libtorch 214 | ```python 215 | from pathlib import Path 216 | from funasr_torch import SenseVoiceSmall 217 | from funasr_torch.utils.postprocess_utils import rich_transcription_postprocess 218 | 219 | 220 | model_dir = "iic/SenseVoiceSmall" 221 | 222 | model = SenseVoiceSmall(model_dir, batch_size=10, device="cuda:0") 223 | 224 | wav_or_scp = ["{}/.cache/modelscope/hub/{}/example/en.mp3".format(Path.home(), model_dir)] 225 | 226 | res = model(wav_or_scp, language="auto", use_itn=True) 227 | print([rich_transcription_postprocess(i) for i in res]) 228 | ``` 229 | Note: Libtorch model is exported to the original model directory. 230 |
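Both exporters write their artifacts next to the original checkpoint, which by default lives under the ModelScope cache (the same path pattern used in the snippets above). The following is a small sketch to confirm what was produced; the exact file names can vary between funasr versions, so the glob patterns below are deliberately loose assumptions:

```python
from pathlib import Path

# Default ModelScope cache location of the checkpoint (see the wav paths used above).
model_dir = Path.home() / ".cache" / "modelscope" / "hub" / "iic" / "SenseVoiceSmall"

# List whatever the ONNX / Libtorch exporters dropped into the model directory.
for pattern in ("*.onnx", "*.torchscript*", "*.pt"):
    for f in sorted(model_dir.glob(pattern)):
        print(f"{f.name}\t{f.stat().st_size / 1e6:.1f} MB")
```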
231 | 232 | ## Service 233 | 234 | ### Deployment with FastAPI 235 | ```shell 236 | export SENSEVOICE_DEVICE=cuda:0 237 | fastapi run --port 50000 238 | ``` 239 | 240 | ## Finetune 241 | 242 | ### Requirements 243 | 244 | ```shell 245 | git clone https://github.com/alibaba/FunASR.git && cd FunASR 246 | pip3 install -e ./ 247 | ``` 248 | 249 | ### Data prepare 250 | 251 | Data examples 252 | 253 | ```text 254 | {"key": "YOU0000008470_S0000238_punc_itn", "text_language": "<|en|>", "emo_target": "<|NEUTRAL|>", "event_target": "<|Speech|>", "with_or_wo_itn": "<|withitn|>", "target": "Including legal due diligence, subscription agreement, negotiation.", "source": "/cpfs01/shared/Group-speech/beinian.lzr/data/industrial_data/english_all/audio/YOU0000008470_S0000238.wav", "target_len": 7, "source_len": 140} 255 | {"key": "AUD0000001556_S0007580", "text_language": "<|en|>", "emo_target": "<|NEUTRAL|>", "event_target": "<|Speech|>", "with_or_wo_itn": "<|woitn|>", "target": "there is a tendency to identify the self or take interest in what one has got used to", "source": "/cpfs01/shared/Group-speech/beinian.lzr/data/industrial_data/english_all/audio/AUD0000001556_S0007580.wav", "target_len": 18, "source_len": 360} 256 | ``` 257 | 258 | Full ref to `data/train_example.jsonl` 259 | 260 |
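Before launching finetuning, it can save time to sanity-check the manifest. The snippet below is an illustrative sketch rather than an official tool: it assumes your jsonl files follow the schema shown above and simply reports lines that are missing required fields.

```python
import json

REQUIRED_FIELDS = {
    "key", "source", "source_len", "target", "target_len",
    "text_language", "emo_target", "event_target", "with_or_wo_itn",
}

def check_manifest(path: str) -> None:
    """Report jsonl lines that do not carry all fields used by SenseVoice finetuning."""
    with open(path, encoding="utf-8") as f:
        for lineno, line in enumerate(f, 1):
            item = json.loads(line)
            missing = sorted(REQUIRED_FIELDS.difference(item))
            if missing:
                print(f"{path}:{lineno} missing fields: {missing}")

check_manifest("data/train_example.jsonl")
check_manifest("data/val_example.jsonl")
```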
Data Prepare Details 261 | 262 | Description: 263 | - `key`: audio file unique ID 264 | - `source`:path to the audio file 265 | - `source_len`:number of fbank frames of the audio file 266 | - `target`:transcription 267 | - `target_len`:length of target 268 | - `text_language`:language id of the audio file 269 | - `emo_target`:emotion label of the audio file 270 | - `event_target`:event label of the audio file 271 | - `with_or_wo_itn`:whether includes punctuation and inverse text normalization 272 | 273 | 274 | `train_text.txt` 275 | 276 | 277 | ```bash 278 | BAC009S0764W0121 甚至出现交易几乎停滞的情况 279 | BAC009S0916W0489 湖北一公司以员工名义贷款数十员工负债千万 280 | asr_example_cn_en 所有只要处理 data 不管你是做 machine learning 做 deep learning 做 data analytics 做 data science 也好 scientist 也好通通都要都做的基本功啊那 again 先先对有一些>也许对 281 | ID0012W0014 he tried to think how it could be 282 | ``` 283 | 284 | `train_wav.scp` 285 | 286 | 287 | 288 | ```bash 289 | BAC009S0764W0121 https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/BAC009S0764W0121.wav 290 | BAC009S0916W0489 https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/BAC009S0916W0489.wav 291 | asr_example_cn_en https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_cn_en.wav 292 | ID0012W0014 https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_en.wav 293 | ``` 294 | 295 | `train_text_language.txt` 296 | 297 | The language ids include `<|zh|>`、`<|en|>`、`<|yue|>`、`<|ja|>` and `<|ko|>`. 298 | 299 | ```bash 300 | BAC009S0764W0121 <|zh|> 301 | BAC009S0916W0489 <|zh|> 302 | asr_example_cn_en <|zh|> 303 | ID0012W0014 <|en|> 304 | ``` 305 | 306 | `train_emo.txt` 307 | 308 | The emotion labels include`<|HAPPY|>`、`<|SAD|>`、`<|ANGRY|>`、`<|NEUTRAL|>`、`<|FEARFUL|>`、`<|DISGUSTED|>` and `<|SURPRISED|>`. 309 | 310 | ```bash 311 | BAC009S0764W0121 <|NEUTRAL|> 312 | BAC009S0916W0489 <|NEUTRAL|> 313 | asr_example_cn_en <|NEUTRAL|> 314 | ID0012W0014 <|NEUTRAL|> 315 | ``` 316 | 317 | `train_event.txt` 318 | 319 | The event labels include`<|BGM|>`、`<|Speech|>`、`<|Applause|>`、`<|Laughter|>`、`<|Cry|>`、`<|Sneeze|>`、`<|Breath|>` and `<|Cough|>`. 320 | 321 | ```bash 322 | BAC009S0764W0121 <|Speech|> 323 | BAC009S0916W0489 <|Speech|> 324 | asr_example_cn_en <|Speech|> 325 | ID0012W0014 <|Speech|> 326 | ``` 327 | 328 | `Command` 329 | ```shell 330 | # generate train.jsonl and val.jsonl from wav.scp, text.txt, text_language.txt, emo_target.txt, event_target.txt 331 | sensevoice2jsonl \ 332 | ++scp_file_list='["../../../data/list/train_wav.scp", "../../../data/list/train_text.txt", "../../../data/list/train_text_language.txt", "../../../data/list/train_emo.txt", "../../../data/list/train_event.txt"]' \ 333 | ++data_type_list='["source", "target", "text_language", "emo_target", "event_target"]' \ 334 | ++jsonl_file_out="../../../data/list/train.jsonl" 335 | ``` 336 | 337 | If there is no `train_text_language.txt`, `train_emo_target.txt` and `train_event_target.txt`, the language, emotion and event label will be predicted automatically by using the `SenseVoice` model. 338 | ```shell 339 | # generate train.jsonl and val.jsonl from wav.scp and text.txt 340 | sensevoice2jsonl \ 341 | ++scp_file_list='["../../../data/list/train_wav.scp", "../../../data/list/train_text.txt"]' \ 342 | ++data_type_list='["source", "target"]' \ 343 | ++jsonl_file_out="../../../data/list/train.jsonl" \ 344 | ++model_dir='iic/SenseVoiceSmall' 345 | ``` 346 |
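If `sensevoice2jsonl` does not fit your pipeline, the manifest can also be assembled directly in Python. The sketch below only approximates what the official tool does and rests on a few assumptions: the audio paths are local files readable by `torchaudio`, `source_len` is estimated as the number of 10 ms fbank frames, and `target_len` is a whitespace token count (the official examples count characters for Chinese), so adjust these to your setup.

```python
import json
import torchaudio

def read_kv(path):
    """Read 'ID value...' files such as train_wav.scp or train_text.txt."""
    out = {}
    with open(path, encoding="utf-8") as f:
        for line in f:
            key, value = line.strip().split(maxsplit=1)
            out[key] = value
    return out

wavs   = read_kv("data/list/train_wav.scp")
texts  = read_kv("data/list/train_text.txt")
langs  = read_kv("data/list/train_text_language.txt")
emos   = read_kv("data/list/train_emo.txt")
events = read_kv("data/list/train_event.txt")

with open("data/list/train.jsonl", "w", encoding="utf-8") as out:
    for key, wav in wavs.items():
        info = torchaudio.info(wav)  # assumes a local audio file, not a URL
        source_len = int(info.num_frames / info.sample_rate * 100)  # assumed 10 ms frame shift
        item = {
            "key": key,
            "text_language": langs[key],
            "emo_target": emos[key],
            "event_target": events[key],
            "with_or_wo_itn": "<|woitn|>",
            "target": texts[key],
            "source": wav,
            "target_len": len(texts[key].split()),  # use character count for Chinese text
            "source_len": source_len,
        }
        out.write(json.dumps(item, ensure_ascii=False) + "\n")
```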
347 | 348 | ### Finetune 349 | 350 | Ensure to modify the train_tool in finetune.sh to the absolute path of `funasr/bin/train_ds.py` from the FunASR installation directory you have set up earlier. 351 | 352 | ```shell 353 | bash finetune.sh 354 | ``` 355 | 356 | ## WebUI 357 | 358 | ```shell 359 | python webui.py 360 | ``` 361 | 362 |
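For completeness, the FastAPI service started in the Service section above (`fastapi run --port 50000`) can also be exercised from Python. This is a minimal client sketch based on the form fields defined in `api.py` (`files`, `keys`, `lang`); the host, port, and audio path below are placeholders.

```python
import requests

url = "http://127.0.0.1:50000/api/v1/asr"  # adjust host/port to your deployment

with open("en.mp3", "rb") as f:  # any 16 kHz wav/mp3 file
    resp = requests.post(
        url,
        files=[("files", ("en.mp3", f, "audio/mpeg"))],
        data={"keys": "en_sample", "lang": "auto"},
    )

# Each result item contains "raw_text", "clean_text", and the post-processed "text".
print(resp.json())
```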
363 | 364 | 365 | ## Remarkable Third-Party Work 366 | - Triton (GPU) Deployment Best Practices: Using Triton + TensorRT, tested with FP32, achieving an acceleration ratio of 526 on V100 GPU. FP16 support is in progress. [Repository](https://github.com/modelscope/FunASR/blob/main/runtime/triton_gpu/README.md) 367 | - Sherpa-onnx Deployment Best Practices: Supports using SenseVoice in 10 programming languages: C++, C, Python, C#, Go, Swift, Kotlin, Java, JavaScript, and Dart. Also supports deploying SenseVoice on platforms like iOS, Android, and Raspberry Pi. [Repository](https://k2-fsa.github.io/sherpa/onnx/sense-voice/index.html) 368 | - [SenseVoice.cpp](https://github.com/lovemefan/SenseVoice.cpp). Inference of SenseVoice in pure C/C++ based on GGML, supporting 3-bit, 4-bit, 5-bit, 8-bit quantization, etc. with no third-party dependencies. 369 | - [streaming-sensevoice](https://github.com/pengzhendong/streaming-sensevoice) processes inference in chunks. To achieve pseudo-streaming, it employs a truncated attention mechanism, sacrificing some accuracy. Additionally, this technology supports CTC prefix beam search and hot-word boosting features. 370 | - [OmniSenseVoice](https://github.com/lifeiteng/OmniSenseVoice) is optimized for lightning-fast inference and batching process. 371 | - [SenseVoice Hotword](https://www.modelscope.cn/models/dengcunqin/SenseVoiceSmall_hotword),Neural Network Hotword Enhancement,[Contextualized End-to-End Speech Recognition with Contextual Phrase Prediction Network](https://mp.weixin.qq.com/s/1QkIvh8j7rrUjRyWOgAvdA)。 372 | 373 | # Community 374 | If you encounter problems in use, you can directly raise Issues on the github page. 375 | 376 | You can also scan the following DingTalk group QR code to join the community group for communication and discussion. 377 | 378 | | FunASR | 379 | |:--------------------------------------------------------:| 380 | | | 381 | 382 | 383 | -------------------------------------------------------------------------------- /README_ja.md: -------------------------------------------------------------------------------- 1 | # SenseVoice 2 | 3 | 「[简体中文](./README_zh.md)」|「[English](./README.md)」|「日本語」 4 | 5 | SenseVoiceは、音声認識(ASR)、言語識別(LID)、音声感情認識(SER)、および音響イベント分類(AEC)または音響イベント検出(AED)を含む音声理解能力を備えた音声基盤モデルです。本プロジェクトでは、SenseVoiceモデルの紹介と、複数のタスクテストセットでのベンチマーク、およびモデルの体験に必要な環境のインストールと推論方法を提供します。 6 | 7 |
8 | 9 |
10 | [//]: # (
) 11 | 12 |
13 |

14 | ホームページ 15 | | 最新情報 16 | | 性能評価 17 | | 環境インストール 18 | | 使用方法チュートリアル 19 | | お問い合わせ 20 |

21 | 22 | モデルリポジトリ:[modelscope](https://www.modelscope.cn/models/iic/SenseVoiceSmall),[huggingface](https://huggingface.co/FunAudioLLM/SenseVoiceSmall) 23 | 24 | オンライン体験: 25 | [modelscope demo](https://www.modelscope.cn/studios/iic/SenseVoice), [huggingface space](https://huggingface.co/spaces/FunAudioLLM/SenseVoice) 26 | 27 |
28 | 29 | 30 | # コア機能 🎯 31 | **SenseVoice**は、高精度な多言語音声認識、感情認識、および音声イベント検出に焦点を当てています。 32 | - **多言語認識:** 40万時間以上のデータを使用してトレーニングされ、50以上の言語をサポートし、認識性能はWhisperモデルを上回ります。 33 | - **リッチテキスト認識:** 34 | - 優れた感情認識能力を持ち、テストデータで現在の最良の感情認識モデルの効果を達成および上回ります。 35 | - 音声イベント検出能力を提供し、音楽、拍手、笑い声、泣き声、咳、くしゃみなどのさまざまな一般的な人間とコンピュータのインタラクションイベントを検出します。 36 | - **効率的な推論:** SenseVoice-Smallモデルは非自己回帰エンドツーエンドフレームワークを採用しており、推論遅延が非常に低く、10秒の音声の推論に70msしかかかりません。Whisper-Largeより15倍高速です。 37 | - **簡単な微調整:** 便利な微調整スクリプトと戦略を提供し、ユーザーがビジネスシナリオに応じてロングテールサンプルの問題を簡単に解決できるようにします。 38 | - **サービス展開:** マルチコンカレントリクエストをサポートする完全なサービス展開パイプラインを提供し、クライアントサイドの言語にはPython、C++、HTML、Java、C#などがあります。 39 | 40 | 41 | # 最新情報 🔥 42 | - 2024/7:新しく[ONNX](./demo_onnx.py)と[libtorch](./demo_libtorch.py)のエクスポート機能を追加し、Pythonバージョンのランタイム:[funasr-onnx-0.4.0](https://pypi.org/project/funasr-onnx/)、[funasr-torch-0.1.1](https://pypi.org/project/funasr-torch/)も提供開始。 43 | - 2024/7: [SenseVoice-Small](https://www.modelscope.cn/models/iic/SenseVoiceSmall) 多言語音声理解モデルがオープンソース化されました。中国語、広東語、英語、日本語、韓国語の多言語音声認識、感情認識、およびイベント検出能力をサポートし、非常に低い推論遅延を実現しています。 44 | - 2024/7: CosyVoiceは自然な音声生成に取り組んでおり、多言語、音色、感情制御をサポートします。多言語音声生成、ゼロショット音声生成、クロスランゲージ音声クローン、および指示に従う能力に優れています。[CosyVoice repo](https://github.com/FunAudioLLM/CosyVoice) and [CosyVoice オンライン体験](https://www.modelscope.cn/studios/iic/CosyVoice-300M). 45 | - 2024/7: [FunASR](https://github.com/modelscope/FunASR) は、音声認識(ASR)、音声活動検出(VAD)、句読点復元、言語モデル、話者検証、話者分離、およびマルチトーカーASRなどの機能を提供する基本的な音声認識ツールキットです。 46 | 47 | 48 | # ベンチマーク 📝 49 | 50 | ## 多言語音声認識 51 | 52 | オープンソースのベンチマークデータセット(AISHELL-1、AISHELL-2、Wenetspeech、Librispeech、Common Voiceを含む)でSenseVoiceとWhisperの多言語音声認識性能と推論効率を比較しました。中国語と広東語の認識効果において、SenseVoice-Smallモデルは明らかな効果の優位性を持っています。 53 | 54 |
55 | 56 |
57 | 58 | ## 感情認識 59 | 60 | 現在、広く使用されている感情認識のテスト指標と方法が不足しているため、複数のテストセットでさまざまな指標をテストし、最近のベンチマークの複数の結果と包括的に比較しました。選択されたテストセットには、中国語/英語の両方の言語と、パフォーマンス、映画、自然な会話などのさまざまなスタイルのデータが含まれています。ターゲットデータの微調整を行わない前提で、SenseVoiceはテストデータで現在の最良の感情認識モデルの効果を達成および上回ることができました。 61 | 62 |
63 | 64 |
65 | 66 | さらに、テストセットで複数のオープンソースの感情認識モデルを比較し、結果はSenseVoice-Largeモデルがほぼすべてのデータで最良の効果を達成し、SenseVoice-Smallモデルも多数のデータセットで他のオープンソースモデルを上回る効果を達成したことを示しています。 67 | 68 |
69 | 70 |
71 | 72 | ## イベント検出 73 | 74 | SenseVoiceは音声データのみでトレーニングされていますが、イベント検出モデルとして単独で使用することもできます。環境音分類ESC-50データセットで、現在業界で広く使用されているBEATSおよびPANNモデルの効果と比較しました。SenseVoiceモデルはこれらのタスクで良好な効果を達成しましたが、トレーニングデータとトレーニング方法の制約により、イベント分類の効果は専門のイベント検出モデルと比較してまだ一定の差があります。 75 | 76 |
77 | 78 |
79 | 80 | ## 推論効率 81 | 82 | SenseVoice-smallモデルは非自己回帰エンドツーエンドアーキテクチャを採用しており、推論遅延が非常に低いです。Whisper-Smallモデルと同等のパラメータ量で、Whisper-Smallモデルより5倍高速で、Whisper-Largeモデルより15倍高速です。同時に、SenseVoice-smallモデルは音声の長さが増加しても、推論時間に明らかな増加はありません。 83 | 84 |
85 | 86 |
87 | 88 | 89 | # 環境インストール 🐍 90 | 91 | ```shell 92 | pip install -r requirements.txt 93 | ``` 94 | 95 | 96 | # 使用方法 🛠️ 97 | 98 | ## 推論 99 | 100 | 任意の形式の音声入力をサポートし、任意の長さの入力をサポートします。 101 | 102 | ```python 103 | from funasr import AutoModel 104 | from funasr.utils.postprocess_utils import rich_transcription_postprocess 105 | 106 | model_dir = "iic/SenseVoiceSmall" 107 | 108 | 109 | model = AutoModel( 110 | model=model_dir, 111 | trust_remote_code=True, 112 | remote_code="./model.py", 113 | vad_model="fsmn-vad", 114 | vad_kwargs={"max_single_segment_time": 30000}, 115 | device="cuda:0", 116 | ) 117 | 118 | # en 119 | res = model.generate( 120 | input=f"{model.model_path}/example/en.mp3", 121 | cache={}, 122 | language="auto", # "zh", "en", "yue", "ja", "ko", "nospeech" 123 | use_itn=True, 124 | batch_size_s=60, 125 | merge_vad=True, # 126 | merge_length_s=15, 127 | ) 128 | text = rich_transcription_postprocess(res[0]["text"]) 129 | print(text) 130 | ``` 131 | 132 |
パラメータの説明(クリックして展開) 133 | 134 | - `model_dir`:モデル名、またはローカルディスク上のモデルパス。 135 | - `trust_remote_code`: 136 | - `True`は、modelコードの実装が`remote_code`からロードされることを意味し、`remote_code`は`model`コードの正確な位置を指定します(例:現在のディレクトリの`model.py`)。絶対パス、相対パス、およびネットワークURLをサポートします。 137 | - `False`は、modelコードの実装が[FunASR](https://github.com/modelscope/FunASR)内部に統合されたバージョンであることを意味し、この場合、現在のディレクトリの`model.py`を変更しても効果がありません。FunASR内部バージョンがロードされるためです。モデルコード[こちらを参照](https://github.com/modelscope/FunASR/tree/main/funasr/models/sense_voice)。 138 | - `vad_model`:VAD(音声活動検出)を有効にすることを示します。VADの目的は、長い音声を短いクリップに分割することです。この場合、推論時間にはVADとSenseVoiceの合計消費が含まれ、エンドツーエンドの遅延を表します。SenseVoiceモデルの推論時間を個別にテストする場合は、VADモデルを無効にできます。 139 | - `vad_kwargs`:VADモデルの設定を指定します。`max_single_segment_time`:`vad_model`による音声セグメントの最大長を示し、単位はミリ秒(ms)です。 140 | - `use_itn`:出力結果に句読点と逆テキスト正規化が含まれるかどうか。 141 | - `batch_size_s`:動的バッチの使用を示し、バッチ内の音声の合計長を秒(s)で測定します。 142 | - `merge_vad`:VADモデルによって分割された短い音声フラグメントをマージするかどうか。マージ後の長さは`merge_length_s`で、単位は秒(s)です。 143 | - `ban_emo_unk`:emo_unkラベルを無効にする。 144 |
145 | 146 | すべての入力が短い音声(30秒未満)であり、バッチ推論が必要な場合、推論効率を向上させるためにVADモデルを削除し、`batch_size`を設定できます。 147 | 148 | ```python 149 | model = AutoModel(model=model_dir, trust_remote_code=True, device="cuda:0") 150 | 151 | res = model.generate( 152 | input=f"{model.model_path}/example/en.mp3", 153 | cache={}, 154 | language="auto", # "zh", "en", "yue", "ja", "ko", "nospeech" 155 | use_itn=True, 156 | batch_size=64, 157 | ) 158 | ``` 159 | 160 | 詳細な使用方法については、[ドキュメント](https://github.com/modelscope/FunASR/blob/main/docs/tutorial/README.md)を参照してください。 161 | 162 | ### 直接推論 163 | 164 | 任意の形式の音声入力をサポートし、入力音声の長さは30秒以下に制限されます。 165 | 166 | ```python 167 | from model import SenseVoiceSmall 168 | from funasr.utils.postprocess_utils import rich_transcription_postprocess 169 | 170 | model_dir = "iic/SenseVoiceSmall" 171 | m, kwargs = SenseVoiceSmall.from_pretrained(model=model_dir, device="cuda:0") 172 | m.eval() 173 | 174 | res = m.inference( 175 | data_in=f"{kwargs['model_path']}/example/en.mp3", 176 | language="auto", # "zh", "en", "yue", "ja", "ko", "nospeech" 177 | use_itn=False, 178 | ban_emo_unk=False, 179 | **kwargs, 180 | ) 181 | 182 | text = rich_transcription_postprocess(res[0][0]["text"]) 183 | print(text) 184 | ``` 185 | 186 | ## サービス展開 187 | 188 | 未完了 189 | 190 | ### エクスポートとテスト 191 |
ONNXとLibtorchのエクスポート 192 | 193 | #### ONNX 194 | ```python 195 | # pip3 install -U funasr funasr-onnx 196 | from pathlib import Path 197 | from funasr_onnx import SenseVoiceSmall 198 | from funasr_onnx.utils.postprocess_utils import rich_transcription_postprocess 199 | 200 | 201 | model_dir = "iic/SenseVoiceSmall" 202 | 203 | model = SenseVoiceSmall(model_dir, batch_size=10, quantize=True) 204 | 205 | # inference 206 | wav_or_scp = ["{}/.cache/modelscope/hub/{}/example/en.mp3".format(Path.home(), model_dir)] 207 | 208 | res = model(wav_or_scp, language="auto", use_itn=True) 209 | print([rich_transcription_postprocess(i) for i in res]) 210 | ``` 211 | 備考:ONNXモデルは元のモデルディレクトリにエクスポートされます。 212 | 213 | #### Libtorch 214 | ```python 215 | from pathlib import Path 216 | from funasr_torch import SenseVoiceSmall 217 | from funasr_torch.utils.postprocess_utils import rich_transcription_postprocess 218 | 219 | 220 | model_dir = "iic/SenseVoiceSmall" 221 | 222 | model = SenseVoiceSmall(model_dir, batch_size=10, device="cuda:0") 223 | 224 | wav_or_scp = ["{}/.cache/modelscope/hub/{}/example/en.mp3".format(Path.home(), model_dir)] 225 | 226 | res = model(wav_or_scp, language="auto", use_itn=True) 227 | print([rich_transcription_postprocess(i) for i in res]) 228 | ``` 229 | 備考:Libtorchモデルは元のモデルディレクトリにエクスポートされます。 230 | 231 |
232 | 233 | ### 展開 234 | 235 | ### FastAPIを使った展開 236 | ```shell 237 | export SENSEVOICE_DEVICE=cuda:0 238 | fastapi run --port 50000 239 | ``` 240 | 241 | ## 微調整 242 | 243 | ### トレーニング環境のインストール 244 | 245 | ```shell 246 | git clone https://github.com/alibaba/FunASR.git && cd FunASR 247 | pip3 install -e ./ 248 | ``` 249 | 250 | ### データ準備 251 | 252 | データ例 253 | ```text 254 | {"key": "YOU0000008470_S0000238_punc_itn", "text_language": "<|en|>", "emo_target": "<|NEUTRAL|>", "event_target": "<|Speech|>", "with_or_wo_itn": "<|withitn|>", "target": "Including legal due diligence, subscription agreement, negotiation.", "source": "/cpfs01/shared/Group-speech/beinian.lzr/data/industrial_data/english_all/audio/YOU0000008470_S0000238.wav", "target_len": 7, "source_len": 140} 255 | {"key": "AUD0000001556_S0007580", "text_language": "<|en|>", "emo_target": "<|NEUTRAL|>", "event_target": "<|Speech|>", "with_or_wo_itn": "<|woitn|>", "target": "there is a tendency to identify the self or take interest in what one has got used to", "source": "/cpfs01/shared/Group-speech/beinian.lzr/data/industrial_data/english_all/audio/AUD0000001556_S0007580.wav", "target_len": 18, "source_len": 360} 256 | ``` 257 | 詳細は `data/train_example.jsonl` を参照してください。 258 | 259 |
データ準備の詳細 260 | 261 | 説明: 262 | - `key`:音声ファイルのユニークID 263 | - `source`:音声ファイルのパス 264 | - `source_len`:音声ファイルのfbankフレーム数 265 | - `target`:文字起こし結果 266 | - `target_len`:target(文字起こし)の長さ 267 | - `text_language`:音声ファイルの言語ID 268 | - `emo_target`:音声ファイルの感情ラベル 269 | - `event_target`:音声ファイルのイベントラベル 270 | - `with_or_wo_itn`:句読点と逆テキスト正規化を含むかどうか 271 | 272 | `train_text.txt` 273 | ```bash 274 | BAC009S0764W0121 甚至出现交易几乎停滞的情况 275 | BAC009S0916W0489 湖北一公司以员工名义贷款数十员工负债千万 276 | asr_example_cn_en 所有只要处理 data 不管你是做 machine learning 做 deep learning 做 data analytics 做 data science 也好 scientist 也好通通都要都做的基本功啊那 again 先先对有一些>也许对 277 | ID0012W0014 he tried to think how it could be 278 | ``` 279 | `train_wav.scp` 280 | ```bash 281 | BAC009S0764W0121 https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/BAC009S0764W0121.wav 282 | BAC009S0916W0489 https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/BAC009S0916W0489.wav 283 | asr_example_cn_en https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_cn_en.wav 284 | ID0012W0014 https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_en.wav 285 | ``` 286 | `train_text_language.txt` 287 | 言語IDは `<|zh|>`、`<|en|>`、`<|yue|>`、`<|ja|>`、および `<|ko|>`を含みます。 288 | ```bash 289 | BAC009S0764W0121 <|zh|> 290 | BAC009S0916W0489 <|zh|> 291 | asr_example_cn_en <|zh|> 292 | ID0012W0014 <|en|> 293 | ``` 294 | `train_emo.txt` 295 | 感情ラベルは、`<|HAPPY|>`、`<|SAD|>`、`<|ANGRY|>`、`<|NEUTRAL|>`、`<|FEARFUL|>`、`<|DISGUSTED|>` および `<|SURPRISED|>`を含みます。 296 | ```bash 297 | BAC009S0764W0121 <|NEUTRAL|> 298 | BAC009S0916W0489 <|NEUTRAL|> 299 | asr_example_cn_en <|NEUTRAL|> 300 | ID0012W0014 <|NEUTRAL|> 301 | ``` 302 | `train_event.txt` 303 | イベントラベルは、 `<|BGM|>`、`<|Speech|>`、`<|Applause|>`、`<|Laughter|>`、`<|Cry|>`、`<|Sneeze|>`、`<|Breath|>` および `<|Cough|>`を含みます。 304 | ```bash 305 | BAC009S0764W0121 <|Speech|> 306 | BAC009S0916W0489 <|Speech|> 307 | asr_example_cn_en <|Speech|> 308 | ID0012W0014 <|Speech|> 309 | ``` 310 | `コマンド` 311 | ```shell 312 | # wav.scp、text.txt、text_language.txt、emo_target.txt、event_target.txt から train.jsonl と val.jsonl を生成します 313 | sensevoice2jsonl \ 314 | ++scp_file_list='["../../../data/list/train_wav.scp", "../../../data/list/train_text.txt", "../../../data/list/train_text_language.txt", "../../../data/list/train_emo.txt", "../../../data/list/train_event.txt"]' \ 315 | ++data_type_list='["source", "target", "text_language", "emo_target", "event_target"]' \ 316 | ++jsonl_file_out="../../../data/list/train.jsonl" 317 | ``` 318 | `train_text_language.txt`、`train_emo_target.txt`、`train_event_target.txt` がない場合、`SenseVoice` モデルを使用して言語、感情、およびイベントラベルが自動的に予測されます。 319 | ```shell 320 | # wav.scp と text.txt から train.jsonl と val.jsonl を生成します 321 | sensevoice2jsonl \ 322 | ++scp_file_list='["../../../data/list/train_wav.scp", "../../../data/list/train_text.txt"]' \ 323 | ++data_type_list='["source", "target"]' \ 324 | ++jsonl_file_out="../../../data/list/train.jsonl" 325 | ``` 326 |
327 | 328 | ### トレーニングの開始 329 | 330 | `finetune.sh`の`train_tool`を、前述のFunASRパス内の`funasr/bin/train_ds.py`の絶対パスに変更することを忘れないでください。 331 | 332 | ```shell 333 | bash finetune.sh 334 | ``` 335 | 336 | ## WebUI 337 | 338 | ```shell 339 | python webui.py 340 | ``` 341 | 342 |
343 | 344 | ## 注目すべきサードパーティの取り組み 345 | - Triton (GPU) デプロイメントのベストプラクティス:Triton + TensorRT を使用し、FP32 でテスト。V100 GPU で加速比 526 を達成。FP16 のサポートは進行中です。[リポジトリ](https://github.com/modelscope/FunASR/blob/main/runtime/triton_gpu/README.md) 346 | - Sherpa-onnx デプロイメントのベストプラクティス:SenseVoice を10種類のプログラミング言語(C++, C, Python, C#, Go, Swift, Kotlin, Java, JavaScript, Dart)で使用可能。また、iOS, Android, Raspberry Pi などのプラットフォームでも SenseVoice をデプロイできます。[リポジトリ](https://k2-fsa.github.io/sherpa/onnx/sense-voice/index.html) 347 | - [SenseVoice.cpp](https://github.com/lovemefan/SenseVoice.cpp) GGMLに基づいて純粋なC/C++でSenseVoiceを推測し、3ビット、4ビット、5ビット、8ビット量子化などをサポートし、サードパーティの依存関係はありません。 348 | - [streaming-sensevoice](https://github.com/pengzhendong/streaming-sensevoice) ストリーム型SenseVoiceは、チャンク(chunk)方式で推論を行います。擬似ストリーミング処理を実現するために、一部の精度を犠牲にして切り捨て注意機構(truncated attention)を採用しています。さらに、この技術はCTCプレフィックスビームサーチ(CTC prefix beam search)とホットワード強化機能もサポートしています。 349 | - [OmniSenseVoice](https://github.com/lifeiteng/OmniSenseVoice) は、超高速推論とバッチ処理のために最適化されています。 350 | - [SenseVoice Hotword](https://www.modelscope.cn/models/dengcunqin/SenseVoiceSmall_hotword),ニューラルネットワークホットワード強化,[WeNetにおけるCPPNベースのニューラルネットワークホットワード強化のオープンソース](https://mp.weixin.qq.com/s/1QkIvh8j7rrUjRyWOgAvdA)。 351 | # お問い合わせ 352 | 353 | 使用中に問題が発生した場合は、githubページで直接Issuesを提起できます。音声に興味のある方は、以下のDingTalkグループQRコードをスキャンしてコミュニティグループに参加し、交流と議論を行ってください。 354 | 355 | | FunASR | 356 | |:--------------------------------------------------------:| 357 | | | 358 | 359 | -------------------------------------------------------------------------------- /README_zh.md: -------------------------------------------------------------------------------- 1 | # SenseVoice 2 | 3 | 「简体中文」|「[English](./README.md)」|「[日本語](./README_ja.md)」 4 | 5 | SenseVoice 是具有音频理解能力的音频基础模型,包括语音识别(ASR)、语种识别(LID)、语音情感识别(SER)和声学事件分类(AEC)或声学事件检测(AED)。本项目提供 SenseVoice 模型的介绍以及在多个任务测试集上的 benchmark,以及体验模型所需的环境安装的与推理方式。 6 | 7 |
8 | 9 |
10 | 11 |
12 |

13 | Homepage 14 | | 最新动态 15 | | 性能评测 16 | | 环境安装 17 | | 用法教程 18 | | 联系我们 19 | 20 |

21 | 22 | 模型仓库:[modelscope](https://www.modelscope.cn/models/iic/SenseVoiceSmall),[huggingface](https://huggingface.co/FunAudioLLM/SenseVoiceSmall) 23 | 24 | 在线体验: 25 | [modelscope demo](https://www.modelscope.cn/studios/iic/SenseVoice), [huggingface space](https://huggingface.co/spaces/FunAudioLLM/SenseVoice) 26 | 27 |
28 | 29 | 30 | 31 | # 核心功能 🎯 32 | 33 | **SenseVoice** 专注于高精度多语言语音识别、情感辨识和音频事件检测 34 | 35 | - **多语言识别:** 采用超过 40 万小时数据训练,支持超过 50 种语言,识别效果上优于 Whisper 模型。 36 | - **富文本识别:** 37 | - 具备优秀的情感识别,能够在测试数据上达到和超过目前最佳情感识别模型的效果。 38 | - 支持声音事件检测能力,支持音乐、掌声、笑声、哭声、咳嗽、喷嚏等多种常见人机交互事件进行检测。 39 | - **高效推理:** SenseVoice-Small 模型采用非自回归端到端框架,推理延迟极低,10s 音频推理仅耗时 70ms,15 倍优于 Whisper-Large。 40 | - **微调定制:** 具备便捷的微调脚本与策略,方便用户根据业务场景修复长尾样本问题。 41 | - **服务部署:** 具有完整的服务部署链路,支持多并发请求,支持客户端语言有,python、c++、html、java 与 c# 等。 42 | 43 | 44 | 45 | # 最新动态 🔥 46 | 47 | - 2024/7:新增加导出 [ONNX](./demo_onnx.py) 与 [libtorch](./demo_libtorch.py) 功能,以及 python 版本 runtime:[funasr-onnx-0.4.0](https://pypi.org/project/funasr-onnx/),[funasr-torch-0.1.1](https://pypi.org/project/funasr-torch/) 48 | - 2024/7: [SenseVoice-Small](https://www.modelscope.cn/models/iic/SenseVoiceSmall) 多语言音频理解模型开源,支持中、粤、英、日、韩语的多语言语音识别,情感识别和事件检测能力,具有极低的推理延迟。。 49 | - 2024/7: CosyVoice 致力于自然语音生成,支持多语言、音色和情感控制,擅长多语言语音生成、零样本语音生成、跨语言语音克隆以及遵循指令的能力。[CosyVoice repo](https://github.com/FunAudioLLM/CosyVoice) and [CosyVoice 在线体验](https://www.modelscope.cn/studios/iic/CosyVoice-300M). 50 | - 2024/7: [FunASR](https://github.com/modelscope/FunASR) 是一个基础语音识别工具包,提供多种功能,包括语音识别(ASR)、语音端点检测(VAD)、标点恢复、语言模型、说话人验证、说话人分离和多人对话语音识别等。 51 | 52 | 53 | 54 | # 性能评测 📝 55 | 56 | ## 多语言语音识别 57 | 58 | 我们在开源基准数据集(包括 AISHELL-1、AISHELL-2、Wenetspeech、Librispeech 和 Common Voice)上比较了 SenseVoice 与 Whisper 的多语言语音识别性能和推理效率。在中文和粤语识别效果上,SenseVoice-Small 模型具有明显的效果优势。 59 | 60 |
61 | 62 |
63 | 64 | ## 情感识别 65 | 66 | 由于目前缺乏被广泛使用的情感识别测试指标和方法,我们在多个测试集的多种指标进行测试,并与近年来 Benchmark 上的多个结果进行了全面的对比。所选取的测试集同时包含中文 / 英文两种语言以及表演、影视剧、自然对话等多种风格的数据,在不进行目标数据微调的前提下,SenseVoice 能够在测试数据上达到和超过目前最佳情感识别模型的效果。 67 | 68 |
69 | 70 |
71 | 72 | 同时,我们还在测试集上对多个开源情感识别模型进行对比,结果表明,SenseVoice-Large 模型可以在几乎所有数据上都达到了最佳效果,而 SenseVoice-Small 模型同样可以在多数数据集上取得超越其他开源模型的效果。 73 | 74 |
75 | 76 |
77 | 78 | ## 事件检测 79 | 80 | 尽管 SenseVoice 只在语音数据上进行训练,它仍然可以作为事件检测模型进行单独使用。我们在环境音分类 ESC-50 数据集上与目前业内广泛使用的 BEATS 与 PANN 模型的效果进行了对比。SenseVoice 模型能够在这些任务上取得较好的效果,但受限于训练数据与训练方式,其事件分类效果专业的事件检测模型相比仍然有一定的差距。 81 | 82 |
83 | 84 |
85 | 86 | ## 推理效率 87 | 88 | SenseVoice-small 模型采用非自回归端到端架构,推理延迟极低。在参数量与 Whisper-Small 模型相当的情况下,比 Whisper-Small 模型推理速度快 5 倍,比 Whisper-Large 模型快 15 倍。同时 SenseVoice-small 模型在音频时长增加的情况下,推理耗时也无明显增加。 89 | 90 |
91 | 92 |
93 | 94 | 95 | 96 | # 安装依赖环境 🐍 97 | 98 | ```shell 99 | pip install -r requirements.txt 100 | ``` 101 | 102 | 103 | 104 | # 用法 🛠️ 105 | 106 | ## 推理 107 | 108 | ### 使用 funasr 推理 109 | 110 | 支持任意格式音频输入,支持任意时长输入 111 | 112 | ```python 113 | from funasr import AutoModel 114 | from funasr.utils.postprocess_utils import rich_transcription_postprocess 115 | 116 | model_dir = "iic/SenseVoiceSmall" 117 | 118 | 119 | model = AutoModel( 120 | model=model_dir, 121 | trust_remote_code=True, 122 | remote_code="./model.py", 123 | vad_model="fsmn-vad", 124 | vad_kwargs={"max_single_segment_time": 30000}, 125 | device="cuda:0", 126 | ) 127 | 128 | # en 129 | res = model.generate( 130 | input=f"{model.model_path}/example/en.mp3", 131 | cache={}, 132 | language="auto", # "zh", "en", "yue", "ja", "ko", "nospeech" 133 | use_itn=True, 134 | batch_size_s=60, 135 | merge_vad=True, 136 | merge_length_s=15, 137 | ) 138 | text = rich_transcription_postprocess(res[0]["text"]) 139 | print(text) 140 | ``` 141 | 142 |
参数说明(点击展开) 143 | 144 | - `model_dir`:模型名称,或本地磁盘中的模型路径。 145 | - `trust_remote_code`: 146 | - `True` 表示 model 代码实现从 `remote_code` 处加载,`remote_code` 指定 `model` 具体代码的位置(例如,当前目录下的 `model.py`),支持绝对路径与相对路径,以及网络 url。 147 | - `False` 表示,model 代码实现为 [FunASR](https://github.com/modelscope/FunASR) 内部集成版本,此时修改当前目录下的 `model.py` 不会生效,因为加载的是 funasr 内部版本,模型代码 [点击查看](https://github.com/modelscope/FunASR/tree/main/funasr/models/sense_voice)。 148 | - `vad_model`:表示开启 VAD,VAD 的作用是将长音频切割成短音频,此时推理耗时包括了 VAD 与 SenseVoice 总耗时,为链路耗时,如果需要单独测试 SenseVoice 模型耗时,可以关闭 VAD 模型。 149 | - `vad_kwargs`:表示 VAD 模型配置,`max_single_segment_time`: 表示 `vad_model` 最大切割音频时长,单位是毫秒 ms。 150 | - `use_itn`:输出结果中是否包含标点与逆文本正则化。 151 | - `batch_size_s` 表示采用动态 batch,batch 中总音频时长,单位为秒 s。 152 | - `merge_vad`:是否将 vad 模型切割的短音频碎片合成,合并后长度为 `merge_length_s`,单位为秒 s。 153 | - `ban_emo_unk`:禁用 emo_unk 标签,禁用后所有的句子都会被赋与情感标签。默认 `False` 154 | 155 |
156 | 157 | 如果输入均为短音频(小于 30s),并且需要批量化推理,为了加快推理效率,可以移除 vad 模型,并设置 `batch_size` 158 | 159 | ```python 160 | model = AutoModel(model=model_dir, trust_remote_code=True, device="cuda:0") 161 | 162 | res = model.generate( 163 | input=f"{model.model_path}/example/en.mp3", 164 | cache={}, 165 | language="auto", # "zh", "en", "yue", "ja", "ko", "nospeech" 166 | use_itn=True, 167 | batch_size=64, 168 | ) 169 | ``` 170 | 171 | 更多详细用法,请参考 [文档](https://github.com/modelscope/FunASR/blob/main/docs/tutorial/README.md) 172 | 173 | ### 直接推理 174 | 175 | 支持任意格式音频输入,输入音频时长限制在 30s 以下 176 | 177 | ```python 178 | from model import SenseVoiceSmall 179 | from funasr.utils.postprocess_utils import rich_transcription_postprocess 180 | 181 | model_dir = "iic/SenseVoiceSmall" 182 | m, kwargs = SenseVoiceSmall.from_pretrained(model=model_dir, device="cuda:0") 183 | m.eval() 184 | 185 | res = m.inference( 186 | data_in=f"{kwargs ['model_path']}/example/en.mp3", 187 | language="auto", # "zh", "en", "yue", "ja", "ko", "nospeech" 188 | use_itn=False, 189 | ban_emo_unk=False, 190 | **kwargs, 191 | ) 192 | 193 | text = rich_transcription_postprocess(res [0][0]["text"]) 194 | print(text) 195 | ``` 196 | 197 | ## 服务部署 198 | 199 | Undo 200 | 201 | ### 导出与测试 202 | 203 |
ONNX 与 Libtorch 导出 204 | 205 | #### ONNX 206 | 207 | ```python 208 | # pip3 install -U funasr funasr-onnx 209 | from pathlib import Path 210 | from funasr_onnx import SenseVoiceSmall 211 | from funasr_onnx.utils.postprocess_utils import rich_transcription_postprocess 212 | 213 | 214 | model_dir = "iic/SenseVoiceSmall" 215 | 216 | model = SenseVoiceSmall(model_dir, batch_size=10, quantize=True) 217 | 218 | # inference 219 | wav_or_scp = ["{}/.cache/modelscope/hub/{}/example/en.mp3".format(Path.home(), model_dir)] 220 | 221 | res = model(wav_or_scp, language="auto", use_itn=True) 222 | print([rich_transcription_postprocess(i) for i in res]) 223 | ``` 224 | 225 | 备注:ONNX 模型导出到原模型目录中 226 | 227 | #### Libtorch 228 | 229 | ```python 230 | from pathlib import Path 231 | from funasr_torch import SenseVoiceSmall 232 | from funasr_torch.utils.postprocess_utils import rich_transcription_postprocess 233 | 234 | 235 | model_dir = "iic/SenseVoiceSmall" 236 | 237 | model = SenseVoiceSmall(model_dir, batch_size=10, device="cuda:0") 238 | 239 | wav_or_scp = ["{}/.cache/modelscope/hub/{}/example/en.mp3".format(Path.home(), model_dir)] 240 | 241 | res = model(wav_or_scp, language="auto", use_itn=True) 242 | print([rich_transcription_postprocess (i) for i in res]) 243 | ``` 244 | 245 | 备注:Libtorch 模型导出到原模型目录中 246 | 247 |
248 | 249 | ### 部署 250 | 251 | ### 使用 FastAPI 部署 252 | 253 | ```shell 254 | export SENSEVOICE_DEVICE=cuda:0 255 | fastapi run --port 50000 256 | ``` 257 | 258 | ## 微调 259 | 260 | ### 安装训练环境 261 | 262 | ```shell 263 | git clone https://github.com/alibaba/FunASR.git && cd FunASR 264 | pip3 install -e ./ 265 | ``` 266 | 267 | ### 数据准备 268 | 269 | 数据格式需要包括如下几个字段: 270 | 271 | ```text 272 | {"key": "YOU0000008470_S0000238_punc_itn", "text_language": "<|en|>", "emo_target": "<|NEUTRAL|>", "event_target": "<|Speech|>", "with_or_wo_itn": "<|withitn|>", "target": "Including legal due diligence, subscription agreement, negotiation.", "source": "/cpfs01/shared/Group-speech/beinian.lzr/data/industrial_data/english_all/audio/YOU0000008470_S0000238.wav", "target_len": 7, "source_len": 140} 273 | {"key": "AUD0000001556_S0007580", "text_language": "<|en|>", "emo_target": "<|NEUTRAL|>", "event_target": "<|Speech|>", "with_or_wo_itn": "<|woitn|>", "target": "there is a tendency to identify the self or take interest in what one has got used to", "source": "/cpfs01/shared/Group-speech/beinian.lzr/data/industrial_data/english_all/audio/AUD0000001556_S0007580.wav", "target_len": 18, "source_len": 360} 274 | ``` 275 | 276 | 详细可以参考:`data/train_example.jsonl` 277 | 278 |
数据准备细节介绍 279 | 280 | - `key`: 数据唯一 ID 281 | - `source`:音频文件的路径 282 | - `source_len`:音频文件的 fbank 帧数 283 | - `target`:音频文件标注文本 284 | - `target_len`:音频文件标注文本长度 285 | - `text_language`:音频文件的语种标签 286 | - `emo_target`:音频文件的情感标签 287 | - `event_target`:音频文件的事件标签 288 | - `with_or_wo_itn`:标注文本中是否包含标点与逆文本正则化 289 | 290 | 可以用指令 `sensevoice2jsonl` 从 train_wav.scp、train_text.txt、train_text_language.txt、train_emo_target.txt 和 train_event_target.txt 生成,准备过程如下: 291 | 292 | `train_text.txt` 293 | 294 | 左边为数据唯一 ID,需与 `train_wav.scp` 中的 `ID` 一一对应 295 | 右边为音频文件标注文本,格式如下: 296 | 297 | ```bash 298 | BAC009S0764W0121 甚至出现交易几乎停滞的情况 299 | BAC009S0916W0489 湖北一公司以员工名义贷款数十员工负债千万 300 | asr_example_cn_en 所有只要处理 data 不管你是做 machine learning 做 deep learning 做 data analytics 做 data science 也好 scientist 也好通通都要都做的基本功啊那 again 先先对有一些 > 也许对 301 | ID0012W0014 he tried to think how it could be 302 | ``` 303 | 304 | `train_wav.scp` 305 | 306 | 左边为数据唯一 ID,需与 `train_text.txt` 中的 `ID` 一一对应 307 | 右边为音频文件的路径,格式如下 308 | 309 | ```bash 310 | BAC009S0764W0121 https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/BAC009S0764W0121.wav 311 | BAC009S0916W0489 https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/BAC009S0916W0489.wav 312 | asr_example_cn_en https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_cn_en.wav 313 | ID0012W0014 https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_en.wav 314 | ``` 315 | 316 | `train_text_language.txt` 317 | 318 | 左边为数据唯一 ID,需与 `train_text_language.txt` 中的 `ID` 一一对应 319 | 右边为音频文件的语种标签,支持 `<|zh|>`、`<|en|>`、`<|yue|>`、`<|ja|>` 和 `<|ko|>`,格式如下 320 | 321 | ```bash 322 | BAC009S0764W0121 <|zh|> 323 | BAC009S0916W0489 <|zh|> 324 | asr_example_cn_en <|zh|> 325 | ID0012W0014 <|en|> 326 | ``` 327 | 328 | `train_emo.txt` 329 | 330 | 左边为数据唯一 ID,需与 `train_emo.txt` 中的 `ID` 一一对应 331 | 右边为音频文件的情感标签,支持 `<|HAPPY|>`、`<|SAD|>`、`<|ANGRY|>`、`<|NEUTRAL|>`、`<|FEARFUL|>`、`<|DISGUSTED|>` 和 `<|SURPRISED|>`,格式如下 332 | 333 | ```bash 334 | BAC009S0764W0121 <|NEUTRAL|> 335 | BAC009S0916W0489 <|NEUTRAL|> 336 | asr_example_cn_en <|NEUTRAL|> 337 | ID0012W0014 <|NEUTRAL|> 338 | ``` 339 | 340 | `train_event.txt` 341 | 342 | 左边为数据唯一 ID,需与 `train_event.txt` 中的 `ID` 一一对应 343 | 右边为音频文件的事件标签,支持 `<|BGM|>`、`<|Speech|>`、`<|Applause|>`、`<|Laughter|>`、`<|Cry|>`、`<|Sneeze|>`、`<|Breath|>` 和 `<|Cough|>`,格式如下 344 | 345 | ```bash 346 | BAC009S0764W0121 <|Speech|> 347 | BAC009S0916W0489 <|Speech|> 348 | asr_example_cn_en <|Speech|> 349 | ID0012W0014 <|Speech|> 350 | ``` 351 | 352 | `生成指令` 353 | 354 | ```shell 355 | # generate train.jsonl and val.jsonl from wav.scp, text.txt, text_language.txt, emo_target.txt, event_target.txt 356 | sensevoice2jsonl \ 357 | ++scp_file_list='["../../../data/list/train_wav.scp", "../../../data/list/train_text.txt", "../../../data/list/train_text_language.txt", "../../../data/list/train_emo.txt", "../../../data/list/train_event.txt"]' \ 358 | ++data_type_list='["source", "target", "text_language", "emo_target", "event_target"]' \ 359 | ++jsonl_file_out="../../../data/list/train.jsonl" 360 | ``` 361 | 362 | 若无 train_text_language.txt、train_emo_target.txt 和 train_event_target.txt,则自动通过使用 `SenseVoice` 模型对语种、情感和事件打标。 363 | 364 | ```shell 365 | # generate train.jsonl and val.jsonl from wav.scp and text.txt 366 | sensevoice2jsonl \ 367 | ++scp_file_list='["../../../data/list/train_wav.scp", "../../../data/list/train_text.txt"]' \ 368 | ++data_type_list='["source", "target"]' \ 369 | ++jsonl_file_out="../../../data/list/train.jsonl" \ 370 | 
++model_dir='iic/SenseVoiceSmall' 371 | ``` 372 | 373 |
374 | 375 | ### 启动训练 376 | 377 | 注意修改 `finetune.sh` 中 `train_tool` 为你前面安装 FunASR 路径中 `funasr/bin/train_ds.py` 绝对路径 378 | 379 | ```shell 380 | bash finetune.sh 381 | ``` 382 | 383 | ## WebUI 384 | 385 | ```shell 386 | python webui.py 387 | ``` 388 | 389 |
390 | 391 | ## 优秀三方工作 392 | 393 | - Triton(GPU)部署最佳实践,triton + tensorrt,fp32 测试,V100 GPU 上加速比 526,fp16 支持中,[repo](https://github.com/modelscope/FunASR/blob/main/runtime/triton_gpu/README.md) 394 | - sherpa-onnx 部署最佳实践,支持在 10 种编程语言里面使用 SenseVoice, 即 C++, C, Python, C#, Go, Swift, Kotlin, Java, JavaScript, Dart. 支持在 iOS, Android, Raspberry Pi 等平台使用 SenseVoice,[repo](https://k2-fsa.github.io/sherpa/onnx/sense-voice/index.html) 395 | - [SenseVoice.cpp](https://github.com/lovemefan/SenseVoice.cpp) 基于GGML,在纯C/C++中推断SenseVoice,支持3位、4位、5位、8位量化等,无需第三方依赖。 396 | - [流式SenseVoice](https://github.com/pengzhendong/streaming-sensevoice),通过分块(chunk)的方式进行推理,为了实现伪流式处理,采用了截断注意力机制(truncated attention),牺牲了部分精度。此外,该技术还支持CTC前缀束搜索(CTC prefix beam search)以及热词增强功能。 397 | - [OmniSenseVoice](https://github.com/lifeiteng/OmniSenseVoice) 轻量化推理库,支持batch推理。 398 | - [SenseVoice Hotword](https://www.modelscope.cn/models/dengcunqin/SenseVoiceSmall_hotword),神经网络热词增强,[WeNet 中开源基于 CPPN 的神经网络热词增强](https://mp.weixin.qq.com/s/1QkIvh8j7rrUjRyWOgAvdA)。 399 | # 联系我们 400 | 401 | 如果您在使用中遇到问题,可以直接在 github 页面提 Issues。欢迎语音兴趣爱好者扫描以下的钉钉群二维码加入社区群,进行交流和讨论。 402 | 403 | | FunASR | 404 | |:--------------------------------------------------------:| 405 | | | 406 | -------------------------------------------------------------------------------- /api.py: -------------------------------------------------------------------------------- 1 | # Set the device with environment, default is cuda:0 2 | # export SENSEVOICE_DEVICE=cuda:1 3 | 4 | import os, re 5 | from fastapi import FastAPI, File, Form 6 | from fastapi.responses import HTMLResponse 7 | from typing_extensions import Annotated 8 | from typing import List 9 | from enum import Enum 10 | import torchaudio 11 | from model import SenseVoiceSmall 12 | from funasr.utils.postprocess_utils import rich_transcription_postprocess 13 | from io import BytesIO 14 | 15 | 16 | class Language(str, Enum): 17 | auto = "auto" 18 | zh = "zh" 19 | en = "en" 20 | yue = "yue" 21 | ja = "ja" 22 | ko = "ko" 23 | nospeech = "nospeech" 24 | 25 | model_dir = "iic/SenseVoiceSmall" 26 | m, kwargs = SenseVoiceSmall.from_pretrained(model=model_dir, device=os.getenv("SENSEVOICE_DEVICE", "cuda:0")) 27 | m.eval() 28 | 29 | regex = r"<\|.*\|>" 30 | 31 | app = FastAPI() 32 | 33 | 34 | @app.get("/", response_class=HTMLResponse) 35 | async def root(): 36 | return """ 37 | 38 | 39 | 40 | 41 | Api information 42 | 43 | 44 | Documents of API 45 | 46 | 47 | """ 48 | 49 | @app.post("/api/v1/asr") 50 | async def turn_audio_to_text(files: Annotated[List[bytes], File(description="wav or mp3 audios in 16KHz")], keys: Annotated[str, Form(description="name of each audio joined with comma")], lang: Annotated[Language, Form(description="language of audio content")] = "auto"): 51 | audios = [] 52 | audio_fs = 0 53 | for file in files: 54 | file_io = BytesIO(file) 55 | data_or_path_or_list, audio_fs = torchaudio.load(file_io) 56 | data_or_path_or_list = data_or_path_or_list.mean(0) 57 | audios.append(data_or_path_or_list) 58 | file_io.close() 59 | if lang == "": 60 | lang = "auto" 61 | if keys == "": 62 | key = ["wav_file_tmp_name"] 63 | else: 64 | key = keys.split(",") 65 | res = m.inference( 66 | data_in=audios, 67 | language=lang, # "zh", "en", "yue", "ja", "ko", "nospeech" 68 | use_itn=False, 69 | ban_emo_unk=False, 70 | key=key, 71 | fs=audio_fs, 72 | **kwargs, 73 | ) 74 | if len(res) == 0: 75 | return {"result": []} 76 | for it in res[0]: 77 | it["raw_text"] = it["text"] 78 | it["clean_text"] = re.sub(regex, "", it["text"], 0, re.MULTILINE) 79 
| it["text"] = rich_transcription_postprocess(it["text"]) 80 | return {"result": res[0]} 81 | -------------------------------------------------------------------------------- /data/train_example.jsonl: -------------------------------------------------------------------------------- 1 | {"key": "YOU0000008470_S0000238_punc_itn", "text_language": "<|en|>", "emo_target": "<|NEUTRAL|>", "event_target": "<|Speech|>", "with_or_wo_itn": "<|withitn|>", "target": "Including legal due diligence, subscription agreement, negotiation.", "source": "/cpfs01/shared/Group-speech/beinian.lzr/data/industrial_data/english_all/audio/YOU0000008470_S0000238.wav", "target_len": 7, "source_len": 140} 2 | {"key": "AUD0000001556_S0007580", "text_language": "<|en|>", "emo_target": "<|NEUTRAL|>", "event_target": "<|Speech|>", "with_or_wo_itn": "<|woitn|>", "target": "there is a tendency to identify the self or take interest in what one has got used to", "source": "/cpfs01/shared/Group-speech/beinian.lzr/data/industrial_data/english_all/audio/AUD0000001556_S0007580.wav", "target_len": 18, "source_len": 360} 3 | {"key": "19208207_HJwKrcFJ8_o_segment720", "text_language": "<|en|>", "emo_target": "<|EMO_UNKNOWN|>", "event_target": "<|Speech|>", "with_or_wo_itn": "<|woitn|>", "target": "fourth foul up top now to austin three for leonard and in and out no good rebounded by murray looking for some help and he almost throws it out of bounds", "source": "/cpfs_speech/data/shared/Group-speech/beinian.lzr/data/multilingual/lcd_data/english/production/20240222/wer_0_5/taskid_19208207/wav/19208207_HJwKrcFJ8_o_segment720.wav", "target_len": 31, "source_len": 620} 4 | {"key": "wav000_0872_bb6f4a79bb9f49249083465445b1cafa_01c38131ed3a4c609e8270a09170525d", "text_language": "<|zh|>", "emo_target": "<|NEUTRAL|>", "event_target": "<|Speech|>", "with_or_wo_itn": "<|woitn|>", "target": "案件受理费减半收取计一千六百三十一元赵会龙已预交由赵会龙负担一百零七元", "source": "/cpfs01/shared/Group-speech/beinian.lzr/data/industrial_data/16k_common/audio/wav000_0872_bb6f4a79bb9f49249083465445b1cafa_01c38131ed3a4c609e8270a09170525d.wav", "target_len": 35, "source_len": 700} 5 | {"key": "Speaker0244_iPhone_s0_227", "text_language": "<|ko|>", "emo_target": "<|NEUTRAL|>", "event_target": "<|Speech|>", "with_or_wo_itn": "<|woitn|>", "target": "저녁 다 해결합니다", "source": "/cpfs_speech/data/shared/Group-speech/beinian.lzr/data/multilingual/korean/audio/Speaker0244_iPhone_s0_227.wav", "target_len": 10, "source_len": 200} 6 | {"key": "data2sim_speed_part1_channel0_CHANNEL0_SPEAKER0948_SESSION1_009481629_speed12", "text_language": "<|en|>", "emo_target": "<|NEUTRAL|>", "event_target": "<|Speech|>", "with_or_wo_itn": "<|woitn|>", "target": "the money was entrust to him in february this year before he resign in june according to the documents", "source": "/cpfs01/shared/Group-speech/beinian.lzr/data/industrial_data/english_all/audio/data2sim_speed_part1_channel0_CHANNEL0_SPEAKER0948_SESSION1_009481629_speed12.wav", "target_len": 19, "source_len": 380} 7 | {"key": "SPEAKER0272_SESSION1_002721613_punc_itn", "text_language": "<|en|>", "emo_target": "<|EMO_UNKNOWN|>", "event_target": "<|Speech|>", "with_or_wo_itn": "<|withitn|>", "target": "Current proposals don't go far enough.", "source": "/cpfs01/shared/Group-speech/beinian.lzr/data/industrial_data/english_all/audio/SPEAKER0272_SESSION1_002721613.wav", "target_len": 6, "source_len": 120} 8 | {"key": "wav004_0490_04c4f9cb2cb347a2a156c5cad1a903aa", "text_language": "<|zh|>", "emo_target": "<|NEUTRAL|>", "event_target": "<|Speech|>", "with_or_wo_itn": 
"<|woitn|>", "target": "小凳子", "source": "/cpfs01/shared/Group-speech/beinian.lzr/data/industrial_data/16k_common/audio/wav004_0490_04c4f9cb2cb347a2a156c5cad1a903aa.wav", "target_len": 3, "source_len": 60} 9 | {"key": "18874657_MSnA4nfDC7Q_segment680", "text_language": "<|en|>", "emo_target": "<|NEUTRAL|>", "event_target": "<|Speech|>", "with_or_wo_itn": "<|woitn|>", "target": "and anything that you would like to know please just put it in the email", "source": "/cpfs_speech/data/shared/Group-speech/beinian.lzr/data/multilingual/lcd_data/english/production/20240105/wer_0_5/taskid_18874657/wav/18874657_MSnA4nfDC7Q_segment680.wav", "target_len": 15, "source_len": 300} 10 | {"key": "POD0000007250_S0000518", "text_language": "<|en|>", "emo_target": "<|EMO_UNKNOWN|>", "event_target": "<|Speech|>", "with_or_wo_itn": "<|woitn|>", "target": "i use netflix but and that's not an app", "source": "/cpfs01/shared/Group-speech/beinian.lzr/data/industrial_data/english_all/audio/POD0000007250_S0000518.wav", "target_len": 9, "source_len": 180} 11 | -------------------------------------------------------------------------------- /data/val_example.jsonl: -------------------------------------------------------------------------------- 1 | {"key": "datasim_speed_Speaker0129_winPhone_s0_149_speed-10_punc_itn", "text_language": "<|ko|>", "emo_target": "<|NEUTRAL|>", "event_target": "<|Speech|>", "with_or_wo_itn": "<|withitn|>", "target": "이거 올리고 리베옹과 안 좋은 사이가 되었다는.", "source": "/cpfs_speech/data/shared/Group-speech/beinian.lzr/data/multilingual/korean/audio/datasim_speed_Speaker0129_winPhone_s0_149_speed-10.wav", "target_len": 26, "source_len": 520} 2 | {"key": "data2sim_noise_rir_new_Speaker0048_winPhone_s0_102_punc_itn", "text_language": "<|ko|>", "emo_target": "<|NEUTRAL|>", "event_target": "<|Speech|>", "with_or_wo_itn": "<|withitn|>", "target": "개통 대리점이랑 얘길 해봐야 합니다.", "source": "/cpfs_speech/data/shared/Group-speech/beinian.lzr/data/multilingual/korean/audio/data2sim_noise_rir_new_Speaker0048_winPhone_s0_102.wav", "target_len": 20, "source_len": 400} 3 | {"key": "wav005_0655_1225906248786196892_punc_itn", "text_language": "<|yue|>", "emo_target": "<|EMO_UNKNOWN|>", "event_target": "<|Speech|>", "with_or_wo_itn": "<|withitn|>", "target": "万科租售中心。", "source": "/cpfs01/shared/Group-speech/beinian.lzr/data/multilingual/cantonese/audio/wav005_0655_1225906248786196892.wav", "target_len": 7, "source_len": 140} 4 | {"key": "datasim_speed_Speaker0732_S0_Android_533_speed-10", "text_language": "<|ko|>", "emo_target": "<|NEUTRAL|>", "event_target": "<|Speech|>", "with_or_wo_itn": "<|woitn|>", "target": "댁 공개 석상에서 에이즈 환자랑 껴안고 뭐 키스하고", "source": "/cpfs_speech/data/shared/Group-speech/beinian.lzr/data/multilingual/korean/audio/datasim_speed_Speaker0732_S0_Android_533_speed-10.wav", "target_len": 28, "source_len": 560} 5 | {"key": "wav001_0437_lATPJxDjwp5xwcnOQp9lTs51zRmz", "text_language": "<|zh|>", "emo_target": "<|NEUTRAL|>", "event_target": "<|Speech|>", "with_or_wo_itn": "<|woitn|>", "target": "郑超啊你到我这儿来一下", "source": "/cpfs01/shared/Group-speech/beinian.lzr/data/industrial_data/16k_common/audio/wav001_0437_lATPJxDjwp5xwcnOQp9lTs51zRmz.wav", "target_len": 11, "source_len": 220} 6 | {"key": "wav010_0212_Speaker0045_iOS_s0_088_punc_itn", "text_language": "<|en|>", "emo_target": "<|EMO_UNKNOWN|>", "event_target": "<|Speech|>", "with_or_wo_itn": "<|withitn|>", "target": "In dark moments, I speak with her.", "source": 
"/cpfs01/shared/Group-speech/beinian.lzr/data/industrial_data/english_all/audio/wav010_0212_Speaker0045_iOS_s0_088.wav", "target_len": 7, "source_len": 140} 7 | {"key": "18934860_GSD17-Sz1vw_segment107", "text_language": "<|en|>", "emo_target": "<|NEUTRAL|>", "event_target": "<|Speech|>", "with_or_wo_itn": "<|woitn|>", "target": "some states also include suspected domestic violence within mandatory reporting laws", "source": "/cpfs_speech/data/shared/Group-speech/beinian.lzr/data/multilingual/lcd_data/english/production/20240114/wer_0_5/taskid_18934860/wav/18934860_GSD17-Sz1vw_segment107.wav", "target_len": 11, "source_len": 220} 8 | {"key": "wav003_0734_XSG5MY6Eym.mp4_160", "text_language": "<|zh|>", "emo_target": "<|NEUTRAL|>", "event_target": "<|Speech|>", "with_or_wo_itn": "<|woitn|>", "target": "四万多四万多", "source": "/cpfs01/shared/Group-speech/beinian.lzr/data/industrial_data/dialect/audio/wav003_0734_XSG5MY6Eym.mp4_160.wav", "target_len": 6, "source_len": 120} 9 | {"key": "19208463_4_dWQ34YNU4_segment311", "text_language": "<|en|>", "emo_target": "<|EMO_UNKNOWN|>", "event_target": "<|Speech|>", "with_or_wo_itn": "<|woitn|>", "target": "i said well you have to see that movie history of the world", "source": "/cpfs_speech/data/shared/Group-speech/beinian.lzr/data/multilingual/lcd_data/english/production/20240222/wer_0_5/taskid_19208463/wav/19208463_4_dWQ34YNU4_segment311.wav", "target_len": 13, "source_len": 260} 10 | {"key": "wav005_0682_lATPJv8gRtSVeIDOUGfqsc5DF8yn_13_punc_itn", "text_language": "<|zh|>", "emo_target": "<|NEUTRAL|>", "event_target": "<|Speech|>", "with_or_wo_itn": "<|withitn|>", "target": "这也是我们呃。", "source": "/cpfs01/shared/Group-speech/beinian.lzr/data/industrial_data/16k_common/audio/wav005_0682_lATPJv8gRtSVeIDOUGfqsc5DF8yn_13.wav", "target_len": 7, "source_len": 140} 11 | -------------------------------------------------------------------------------- /deepspeed_conf/ds_stage1.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_micro_batch_size_per_gpu": 1, 3 | "gradient_accumulation_steps": 1, 4 | "steps_per_print": 100, 5 | "gradient_clipping": 5, 6 | "fp16": { 7 | "enabled": false, 8 | "auto_cast": false, 9 | "loss_scale": 0, 10 | "initial_scale_power": 16, 11 | "loss_scale_window": 1000, 12 | "hysteresis": 2, 13 | "consecutive_hysteresis": false, 14 | "min_loss_scale": 1 15 | }, 16 | "bf16": { 17 | "enabled": true 18 | }, 19 | "zero_force_ds_cpu_optimizer": false, 20 | "zero_optimization": { 21 | "stage": 1, 22 | "offload_optimizer": { 23 | "device": "none", 24 | "pin_memory": true 25 | }, 26 | "allgather_partitions": true, 27 | "allgather_bucket_size": 5e8, 28 | "overlap_comm": true, 29 | "reduce_scatter": true, 30 | "reduce_bucket_size": 5e8, 31 | "contiguous_gradients" : true 32 | } 33 | } 34 | -------------------------------------------------------------------------------- /demo1.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- encoding: utf-8 -*- 3 | # Copyright FunASR (https://github.com/FunAudioLLM/SenseVoice). All Rights Reserved. 
4 | # MIT License (https://opensource.org/licenses/MIT) 5 | 6 | from funasr import AutoModel 7 | from funasr.utils.postprocess_utils import rich_transcription_postprocess 8 | 9 | model_dir = "iic/SenseVoiceSmall" 10 | 11 | 12 | model = AutoModel( 13 | model=model_dir, 14 | trust_remote_code=True, 15 | remote_code="./model.py", 16 | vad_model="fsmn-vad", 17 | vad_kwargs={"max_single_segment_time": 30000}, 18 | device="cuda:0", 19 | ) 20 | 21 | # en 22 | res = model.generate( 23 | input=f"{model.model_path}/example/en.mp3", 24 | cache={}, 25 | language="auto", # "zh", "en", "yue", "ja", "ko", "nospeech" 26 | use_itn=True, 27 | batch_size_s=60, 28 | merge_vad=True, # 29 | merge_length_s=15, 30 | ) 31 | text = rich_transcription_postprocess(res[0]["text"]) 32 | print(text) 33 | 34 | # zh 35 | res = model.generate( 36 | input=f"{model.model_path}/example/zh.mp3", 37 | cache={}, 38 | language="auto", # "zh", "en", "yue", "ja", "ko", "nospeech" 39 | use_itn=True, 40 | batch_size_s=60, 41 | merge_vad=True, # 42 | merge_length_s=15, 43 | ) 44 | text = rich_transcription_postprocess(res[0]["text"]) 45 | print(text) 46 | 47 | # yue 48 | res = model.generate( 49 | input=f"{model.model_path}/example/yue.mp3", 50 | cache={}, 51 | language="auto", # "zh", "en", "yue", "ja", "ko", "nospeech" 52 | use_itn=True, 53 | batch_size_s=60, 54 | merge_vad=True, # 55 | merge_length_s=15, 56 | ) 57 | text = rich_transcription_postprocess(res[0]["text"]) 58 | print(text) 59 | 60 | # ja 61 | res = model.generate( 62 | input=f"{model.model_path}/example/ja.mp3", 63 | cache={}, 64 | language="auto", # "zh", "en", "yue", "ja", "ko", "nospeech" 65 | use_itn=True, 66 | batch_size_s=60, 67 | merge_vad=True, # 68 | merge_length_s=15, 69 | ) 70 | text = rich_transcription_postprocess(res[0]["text"]) 71 | print(text) 72 | 73 | 74 | # ko 75 | res = model.generate( 76 | input=f"{model.model_path}/example/ko.mp3", 77 | cache={}, 78 | language="auto", # "zh", "en", "yue", "ja", "ko", "nospeech" 79 | use_itn=True, 80 | batch_size_s=60, 81 | merge_vad=True, # 82 | merge_length_s=15, 83 | ) 84 | text = rich_transcription_postprocess(res[0]["text"]) 85 | print(text) 86 | -------------------------------------------------------------------------------- /demo2.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- encoding: utf-8 -*- 3 | # Copyright FunASR (https://github.com/FunAudioLLM/SenseVoice). All Rights Reserved. 
4 | # MIT License (https://opensource.org/licenses/MIT) 5 | 6 | from model import SenseVoiceSmall 7 | from funasr.utils.postprocess_utils import rich_transcription_postprocess 8 | 9 | 10 | model_dir = "iic/SenseVoiceSmall" 11 | m, kwargs = SenseVoiceSmall.from_pretrained(model=model_dir, device="cuda:0") 12 | m.eval() 13 | 14 | res = m.inference( 15 | data_in=f"{kwargs['model_path']}/example/en.mp3", 16 | language="auto", # "zh", "en", "yue", "ja", "ko", "nospeech" 17 | use_itn=False, 18 | ban_emo_unk=False, 19 | **kwargs, 20 | ) 21 | 22 | text = rich_transcription_postprocess(res[0][0]["text"]) 23 | print(text) 24 | 25 | res = m.inference( 26 | data_in=f"{kwargs['model_path']}/example/en.mp3", 27 | language="auto", # "zh", "en", "yue", "ja", "ko", "nospeech" 28 | use_itn=False, 29 | ban_emo_unk=False, 30 | output_timestamp=True, 31 | **kwargs, 32 | ) 33 | 34 | timestamp = res[0][0]["timestamp"] 35 | text = rich_transcription_postprocess(res[0][0]["text"]) 36 | print(text) 37 | print(timestamp) 38 | -------------------------------------------------------------------------------- /demo_libtorch.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- encoding: utf-8 -*- 3 | # Copyright FunASR (https://github.com/FunAudioLLM/SenseVoice). All Rights Reserved. 4 | # MIT License (https://opensource.org/licenses/MIT) 5 | 6 | from pathlib import Path 7 | from funasr_torch import SenseVoiceSmall 8 | from funasr_torch.utils.postprocess_utils import rich_transcription_postprocess 9 | 10 | 11 | model_dir = "iic/SenseVoiceSmall" 12 | 13 | model = SenseVoiceSmall(model_dir, batch_size=10, device="cuda:0") 14 | 15 | wav_or_scp = ["{}/.cache/modelscope/hub/{}/example/en.mp3".format(Path.home(), model_dir)] 16 | 17 | res = model(wav_or_scp, language="auto", use_itn=True) 18 | print([rich_transcription_postprocess(i) for i in res]) 19 | -------------------------------------------------------------------------------- /demo_onnx.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- encoding: utf-8 -*- 3 | # Copyright FunASR (https://github.com/FunAudioLLM/SenseVoice). All Rights Reserved. 4 | # MIT License (https://opensource.org/licenses/MIT) 5 | 6 | from pathlib import Path 7 | from funasr_onnx import SenseVoiceSmall 8 | from funasr_onnx.utils.postprocess_utils import rich_transcription_postprocess 9 | 10 | 11 | model_dir = "iic/SenseVoiceSmall" 12 | 13 | model = SenseVoiceSmall(model_dir, batch_size=10, quantize=True) 14 | 15 | # inference 16 | wav_or_scp = ["{}/.cache/modelscope/hub/{}/example/en.mp3".format(Path.home(), model_dir)] 17 | 18 | res = model(wav_or_scp, language="auto", textnorm="withitn") 19 | print([rich_transcription_postprocess(i) for i in res]) 20 | -------------------------------------------------------------------------------- /export.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- encoding: utf-8 -*- 3 | # Copyright FunASR (https://github.com/FunAudioLLM/SenseVoice). All Rights Reserved. 
4 | # MIT License (https://opensource.org/licenses/MIT) 5 | 6 | import os 7 | import torch 8 | from model import SenseVoiceSmall 9 | from utils import export_utils 10 | from utils.model_bin import SenseVoiceSmallONNX 11 | from funasr.utils.postprocess_utils import rich_transcription_postprocess 12 | 13 | quantize = False 14 | 15 | model_dir = "iic/SenseVoiceSmall" 16 | model, kwargs = SenseVoiceSmall.from_pretrained(model=model_dir, device="cuda:0") 17 | 18 | rebuilt_model = model.export(type="onnx", quantize=False) 19 | model_path = kwargs.get("output_dir", os.path.dirname(kwargs.get("init_param"))) 20 | 21 | model_file = os.path.join(model_path, "model.onnx") 22 | if quantize: 23 | model_file = os.path.join(model_path, "model_quant.onnx") 24 | 25 | # export model 26 | if not os.path.exists(model_file): 27 | with torch.no_grad(): 28 | del kwargs['model'] 29 | export_dir = export_utils.export(model=rebuilt_model, **kwargs) 30 | print("Export model onnx to {}".format(model_file)) 31 | 32 | # export model init 33 | model_bin = SenseVoiceSmallONNX(model_path) 34 | 35 | # build tokenizer 36 | try: 37 | from funasr.tokenizer.sentencepiece_tokenizer import SentencepiecesTokenizer 38 | tokenizer = SentencepiecesTokenizer(bpemodel=os.path.join(model_path, "chn_jpn_yue_eng_ko_spectok.bpe.model")) 39 | except: 40 | tokenizer = None 41 | 42 | # inference 43 | wav_or_scp = "/Users/shixian/Downloads/asr_example_hotword.wav" 44 | language_list = [0] 45 | textnorm_list = [15] 46 | res = model_bin(wav_or_scp, language_list, textnorm_list, tokenizer=tokenizer) 47 | print([rich_transcription_postprocess(i) for i in res]) 48 | -------------------------------------------------------------------------------- /export_meta.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- encoding: utf-8 -*- 3 | # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. 
4 | # MIT License (https://opensource.org/licenses/MIT) 5 | 6 | import types 7 | import torch 8 | from funasr.utils.torch_function import sequence_mask 9 | 10 | 11 | def export_rebuild_model(model, **kwargs): 12 | model.device = kwargs.get("device") 13 | model.make_pad_mask = sequence_mask(kwargs["max_seq_len"], flip=False) 14 | model.forward = types.MethodType(export_forward, model) 15 | model.export_dummy_inputs = types.MethodType(export_dummy_inputs, model) 16 | model.export_input_names = types.MethodType(export_input_names, model) 17 | model.export_output_names = types.MethodType(export_output_names, model) 18 | model.export_dynamic_axes = types.MethodType(export_dynamic_axes, model) 19 | model.export_name = types.MethodType(export_name, model) 20 | return model 21 | 22 | def export_forward( 23 | self, 24 | speech: torch.Tensor, 25 | speech_lengths: torch.Tensor, 26 | language: torch.Tensor, 27 | textnorm: torch.Tensor, 28 | **kwargs, 29 | ): 30 | # speech = speech.to(device="cuda") 31 | # speech_lengths = speech_lengths.to(device="cuda") 32 | language_query = self.embed(language.to(speech.device)).unsqueeze(1) 33 | textnorm_query = self.embed(textnorm.to(speech.device)).unsqueeze(1) 34 | print(textnorm_query.shape, speech.shape) 35 | speech = torch.cat((textnorm_query, speech), dim=1) 36 | speech_lengths += 1 37 | 38 | event_emo_query = self.embed(torch.LongTensor([[1, 2]]).to(speech.device)).repeat( 39 | speech.size(0), 1, 1 40 | ) 41 | input_query = torch.cat((language_query, event_emo_query), dim=1) 42 | speech = torch.cat((input_query, speech), dim=1) 43 | speech_lengths += 3 44 | 45 | encoder_out, encoder_out_lens = self.encoder(speech, speech_lengths) 46 | if isinstance(encoder_out, tuple): 47 | encoder_out = encoder_out[0] 48 | 49 | ctc_logits = self.ctc.ctc_lo(encoder_out) 50 | 51 | return ctc_logits, encoder_out_lens 52 | 53 | def export_dummy_inputs(self): 54 | speech = torch.randn(2, 30, 560) 55 | speech_lengths = torch.tensor([6, 30], dtype=torch.int32) 56 | language = torch.tensor([0, 0], dtype=torch.int32) 57 | textnorm = torch.tensor([15, 15], dtype=torch.int32) 58 | return (speech, speech_lengths, language, textnorm) 59 | 60 | def export_input_names(self): 61 | return ["speech", "speech_lengths", "language", "textnorm"] 62 | 63 | def export_output_names(self): 64 | return ["ctc_logits", "encoder_out_lens"] 65 | 66 | def export_dynamic_axes(self): 67 | return { 68 | "speech": {0: "batch_size", 1: "feats_length"}, 69 | "speech_lengths": {0: "batch_size"}, 70 | "language": {0: "batch_size"}, 71 | "textnorm": {0: "batch_size"}, 72 | "ctc_logits": {0: "batch_size", 1: "logits_length"}, 73 | "encoder_out_lens": {0: "batch_size"}, 74 | } 75 | 76 | def export_name(self): 77 | return "model.onnx" 78 | 79 | -------------------------------------------------------------------------------- /finetune.sh: -------------------------------------------------------------------------------- 1 | # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. 
2 | # MIT License (https://opensource.org/licenses/MIT) 3 | 4 | workspace=`pwd` 5 | 6 | # which gpu to train or finetune 7 | export CUDA_VISIBLE_DEVICES="0,1" 8 | gpu_num=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') 9 | 10 | # model_name from model_hub, or model_dir in local path 11 | 12 | ## option 1, download model automatically 13 | model_name_or_model_dir="iic/SenseVoiceSmall" 14 | 15 | ## option 2, download model by git 16 | #local_path_root=${workspace}/modelscope_models 17 | #mkdir -p ${local_path_root}/${model_name_or_model_dir} 18 | #git clone https://www.modelscope.cn/${model_name_or_model_dir}.git ${local_path_root}/${model_name_or_model_dir} 19 | #model_name_or_model_dir=${local_path_root}/${model_name_or_model_dir} 20 | 21 | 22 | # data dir, which contains: train.json, val.json 23 | train_data=${workspace}/data/train_example.jsonl 24 | val_data=${workspace}/data/val_example.jsonl 25 | 26 | # exp output dir 27 | output_dir="./outputs" 28 | log_file="${output_dir}/log.txt" 29 | 30 | deepspeed_config=${workspace}/deepspeed_conf/ds_stage1.json 31 | 32 | mkdir -p ${output_dir} 33 | echo "log_file: ${log_file}" 34 | 35 | DISTRIBUTED_ARGS=" 36 | --nnodes ${WORLD_SIZE:-1} \ 37 | --nproc_per_node $gpu_num \ 38 | --node_rank ${RANK:-0} \ 39 | --master_addr ${MASTER_ADDR:-127.0.0.1} \ 40 | --master_port ${MASTER_PORT:-26669} 41 | " 42 | 43 | echo $DISTRIBUTED_ARGS 44 | 45 | # funasr trainer path 46 | train_tool=`dirname $(which funasr)`/train_ds.py 47 | 48 | torchrun $DISTRIBUTED_ARGS \ 49 | ${train_tool} \ 50 | ++model="${model_name_or_model_dir}" \ 51 | ++trust_remote_code=true \ 52 | ++train_data_set_list="${train_data}" \ 53 | ++valid_data_set_list="${val_data}" \ 54 | ++dataset_conf.data_split_num=1 \ 55 | ++dataset_conf.batch_sampler="BatchSampler" \ 56 | ++dataset_conf.batch_size=6000 \ 57 | ++dataset_conf.sort_size=1024 \ 58 | ++dataset_conf.batch_type="token" \ 59 | ++dataset_conf.num_workers=4 \ 60 | ++train_conf.max_epoch=50 \ 61 | ++train_conf.log_interval=1 \ 62 | ++train_conf.resume=true \ 63 | ++train_conf.validate_interval=2000 \ 64 | ++train_conf.save_checkpoint_interval=2000 \ 65 | ++train_conf.keep_nbest_models=20 \ 66 | ++train_conf.avg_nbest_model=10 \ 67 | ++train_conf.use_deepspeed=false \ 68 | ++train_conf.deepspeed_config=${deepspeed_config} \ 69 | ++optim_conf.lr=0.0002 \ 70 | ++output_dir="${output_dir}" &> ${log_file} -------------------------------------------------------------------------------- /image/aed_figure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FunAudioLLM/SenseVoice/3ecc6f6a8f8524a77e04be51846d840cb35da2b4/image/aed_figure.png -------------------------------------------------------------------------------- /image/asr_results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FunAudioLLM/SenseVoice/3ecc6f6a8f8524a77e04be51846d840cb35da2b4/image/asr_results.png -------------------------------------------------------------------------------- /image/asr_results1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FunAudioLLM/SenseVoice/3ecc6f6a8f8524a77e04be51846d840cb35da2b4/image/asr_results1.png -------------------------------------------------------------------------------- /image/asr_results2.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/FunAudioLLM/SenseVoice/3ecc6f6a8f8524a77e04be51846d840cb35da2b4/image/asr_results2.png -------------------------------------------------------------------------------- /image/dingding_funasr.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FunAudioLLM/SenseVoice/3ecc6f6a8f8524a77e04be51846d840cb35da2b4/image/dingding_funasr.png -------------------------------------------------------------------------------- /image/dingding_sv.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FunAudioLLM/SenseVoice/3ecc6f6a8f8524a77e04be51846d840cb35da2b4/image/dingding_sv.png -------------------------------------------------------------------------------- /image/inference.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FunAudioLLM/SenseVoice/3ecc6f6a8f8524a77e04be51846d840cb35da2b4/image/inference.png -------------------------------------------------------------------------------- /image/sensevoice.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FunAudioLLM/SenseVoice/3ecc6f6a8f8524a77e04be51846d840cb35da2b4/image/sensevoice.png -------------------------------------------------------------------------------- /image/sensevoice2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FunAudioLLM/SenseVoice/3ecc6f6a8f8524a77e04be51846d840cb35da2b4/image/sensevoice2.png -------------------------------------------------------------------------------- /image/ser_figure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FunAudioLLM/SenseVoice/3ecc6f6a8f8524a77e04be51846d840cb35da2b4/image/ser_figure.png -------------------------------------------------------------------------------- /image/ser_table.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FunAudioLLM/SenseVoice/3ecc6f6a8f8524a77e04be51846d840cb35da2b4/image/ser_table.png -------------------------------------------------------------------------------- /image/webui.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FunAudioLLM/SenseVoice/3ecc6f6a8f8524a77e04be51846d840cb35da2b4/image/webui.png -------------------------------------------------------------------------------- /image/wechat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FunAudioLLM/SenseVoice/3ecc6f6a8f8524a77e04be51846d840cb35da2b4/image/wechat.png -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | 2 | import time 3 | import torch 4 | from torch import nn 5 | import torch.nn.functional as F 6 | from typing import Iterable, Optional 7 | 8 | from funasr.register import tables 9 | from funasr.models.ctc.ctc import CTC 10 | from funasr.utils.datadir_writer import DatadirWriter 11 | from funasr.models.paraformer.search import Hypothesis 12 | from funasr.train_utils.device_funcs import force_gatherable 13 | from funasr.losses.label_smoothing_loss import LabelSmoothingLoss 14 | from funasr.metrics.compute_acc import 
compute_accuracy, th_accuracy 15 | from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank 16 | from utils.ctc_alignment import ctc_forced_align 17 | 18 | class SinusoidalPositionEncoder(torch.nn.Module): 19 | """ """ 20 | 21 | def __int__(self, d_model=80, dropout_rate=0.1): 22 | pass 23 | 24 | def encode( 25 | self, positions: torch.Tensor = None, depth: int = None, dtype: torch.dtype = torch.float32 26 | ): 27 | batch_size = positions.size(0) 28 | positions = positions.type(dtype) 29 | device = positions.device 30 | log_timescale_increment = torch.log(torch.tensor([10000], dtype=dtype, device=device)) / ( 31 | depth / 2 - 1 32 | ) 33 | inv_timescales = torch.exp( 34 | torch.arange(depth / 2, device=device).type(dtype) * (-log_timescale_increment) 35 | ) 36 | inv_timescales = torch.reshape(inv_timescales, [batch_size, -1]) 37 | scaled_time = torch.reshape(positions, [1, -1, 1]) * torch.reshape( 38 | inv_timescales, [1, 1, -1] 39 | ) 40 | encoding = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=2) 41 | return encoding.type(dtype) 42 | 43 | def forward(self, x): 44 | batch_size, timesteps, input_dim = x.size() 45 | positions = torch.arange(1, timesteps + 1, device=x.device)[None, :] 46 | position_encoding = self.encode(positions, input_dim, x.dtype).to(x.device) 47 | 48 | return x + position_encoding 49 | 50 | 51 | class PositionwiseFeedForward(torch.nn.Module): 52 | """Positionwise feed forward layer. 53 | 54 | Args: 55 | idim (int): Input dimenstion. 56 | hidden_units (int): The number of hidden units. 57 | dropout_rate (float): Dropout rate. 58 | 59 | """ 60 | 61 | def __init__(self, idim, hidden_units, dropout_rate, activation=torch.nn.ReLU()): 62 | """Construct an PositionwiseFeedForward object.""" 63 | super(PositionwiseFeedForward, self).__init__() 64 | self.w_1 = torch.nn.Linear(idim, hidden_units) 65 | self.w_2 = torch.nn.Linear(hidden_units, idim) 66 | self.dropout = torch.nn.Dropout(dropout_rate) 67 | self.activation = activation 68 | 69 | def forward(self, x): 70 | """Forward function.""" 71 | return self.w_2(self.dropout(self.activation(self.w_1(x)))) 72 | 73 | 74 | class MultiHeadedAttentionSANM(nn.Module): 75 | """Multi-Head Attention layer. 76 | 77 | Args: 78 | n_head (int): The number of heads. 79 | n_feat (int): The number of features. 80 | dropout_rate (float): Dropout rate. 
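        in_feat (int): Input feature size of the fused q/k/v projection.
        kernel_size (int): Kernel size of the depth-wise Conv1d FSMN memory block.
        sanm_shfit (int): Extra left padding applied to the FSMN memory block.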
81 | 82 | """ 83 | 84 | def __init__( 85 | self, 86 | n_head, 87 | in_feat, 88 | n_feat, 89 | dropout_rate, 90 | kernel_size, 91 | sanm_shfit=0, 92 | lora_list=None, 93 | lora_rank=8, 94 | lora_alpha=16, 95 | lora_dropout=0.1, 96 | ): 97 | """Construct an MultiHeadedAttention object.""" 98 | super().__init__() 99 | assert n_feat % n_head == 0 100 | # We assume d_v always equals d_k 101 | self.d_k = n_feat // n_head 102 | self.h = n_head 103 | # self.linear_q = nn.Linear(n_feat, n_feat) 104 | # self.linear_k = nn.Linear(n_feat, n_feat) 105 | # self.linear_v = nn.Linear(n_feat, n_feat) 106 | 107 | self.linear_out = nn.Linear(n_feat, n_feat) 108 | self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3) 109 | self.attn = None 110 | self.dropout = nn.Dropout(p=dropout_rate) 111 | 112 | self.fsmn_block = nn.Conv1d( 113 | n_feat, n_feat, kernel_size, stride=1, padding=0, groups=n_feat, bias=False 114 | ) 115 | # padding 116 | left_padding = (kernel_size - 1) // 2 117 | if sanm_shfit > 0: 118 | left_padding = left_padding + sanm_shfit 119 | right_padding = kernel_size - 1 - left_padding 120 | self.pad_fn = nn.ConstantPad1d((left_padding, right_padding), 0.0) 121 | 122 | def forward_fsmn(self, inputs, mask, mask_shfit_chunk=None): 123 | b, t, d = inputs.size() 124 | if mask is not None: 125 | mask = torch.reshape(mask, (b, -1, 1)) 126 | if mask_shfit_chunk is not None: 127 | mask = mask * mask_shfit_chunk 128 | inputs = inputs * mask 129 | 130 | x = inputs.transpose(1, 2) 131 | x = self.pad_fn(x) 132 | x = self.fsmn_block(x) 133 | x = x.transpose(1, 2) 134 | x += inputs 135 | x = self.dropout(x) 136 | if mask is not None: 137 | x = x * mask 138 | return x 139 | 140 | def forward_qkv(self, x): 141 | """Transform query, key and value. 142 | 143 | Args: 144 | query (torch.Tensor): Query tensor (#batch, time1, size). 145 | key (torch.Tensor): Key tensor (#batch, time2, size). 146 | value (torch.Tensor): Value tensor (#batch, time2, size). 147 | 148 | Returns: 149 | torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k). 150 | torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k). 151 | torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k). 152 | 153 | """ 154 | b, t, d = x.size() 155 | q_k_v = self.linear_q_k_v(x) 156 | q, k, v = torch.split(q_k_v, int(self.h * self.d_k), dim=-1) 157 | q_h = torch.reshape(q, (b, t, self.h, self.d_k)).transpose( 158 | 1, 2 159 | ) # (batch, head, time1, d_k) 160 | k_h = torch.reshape(k, (b, t, self.h, self.d_k)).transpose( 161 | 1, 2 162 | ) # (batch, head, time2, d_k) 163 | v_h = torch.reshape(v, (b, t, self.h, self.d_k)).transpose( 164 | 1, 2 165 | ) # (batch, head, time2, d_k) 166 | 167 | return q_h, k_h, v_h, v 168 | 169 | def forward_attention(self, value, scores, mask, mask_att_chunk_encoder=None): 170 | """Compute attention context vector. 171 | 172 | Args: 173 | value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k). 174 | scores (torch.Tensor): Attention score (#batch, n_head, time1, time2). 175 | mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2). 176 | 177 | Returns: 178 | torch.Tensor: Transformed value (#batch, time1, d_model) 179 | weighted by the attention score (#batch, time1, time2). 
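            Masked positions are filled with -inf before the softmax and zeroed out
            afterwards, so padded frames contribute nothing to the context vector.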
180 | 181 | """ 182 | n_batch = value.size(0) 183 | if mask is not None: 184 | if mask_att_chunk_encoder is not None: 185 | mask = mask * mask_att_chunk_encoder 186 | 187 | mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2) 188 | 189 | min_value = -float( 190 | "inf" 191 | ) # float(numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min) 192 | scores = scores.masked_fill(mask, min_value) 193 | attn = torch.softmax(scores, dim=-1).masked_fill( 194 | mask, 0.0 195 | ) # (batch, head, time1, time2) 196 | else: 197 | attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2) 198 | 199 | p_attn = self.dropout(attn) 200 | x = torch.matmul(p_attn, value) # (batch, head, time1, d_k) 201 | x = ( 202 | x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k) 203 | ) # (batch, time1, d_model) 204 | 205 | return self.linear_out(x) # (batch, time1, d_model) 206 | 207 | def forward(self, x, mask, mask_shfit_chunk=None, mask_att_chunk_encoder=None): 208 | """Compute scaled dot product attention. 209 | 210 | Args: 211 | query (torch.Tensor): Query tensor (#batch, time1, size). 212 | key (torch.Tensor): Key tensor (#batch, time2, size). 213 | value (torch.Tensor): Value tensor (#batch, time2, size). 214 | mask (torch.Tensor): Mask tensor (#batch, 1, time2) or 215 | (#batch, time1, time2). 216 | 217 | Returns: 218 | torch.Tensor: Output tensor (#batch, time1, d_model). 219 | 220 | """ 221 | q_h, k_h, v_h, v = self.forward_qkv(x) 222 | fsmn_memory = self.forward_fsmn(v, mask, mask_shfit_chunk) 223 | q_h = q_h * self.d_k ** (-0.5) 224 | scores = torch.matmul(q_h, k_h.transpose(-2, -1)) 225 | att_outs = self.forward_attention(v_h, scores, mask, mask_att_chunk_encoder) 226 | return att_outs + fsmn_memory 227 | 228 | def forward_chunk(self, x, cache=None, chunk_size=None, look_back=0): 229 | """Compute scaled dot product attention. 230 | 231 | Args: 232 | query (torch.Tensor): Query tensor (#batch, time1, size). 233 | key (torch.Tensor): Key tensor (#batch, time2, size). 234 | value (torch.Tensor): Value tensor (#batch, time2, size). 235 | mask (torch.Tensor): Mask tensor (#batch, 1, time2) or 236 | (#batch, time1, time2). 237 | 238 | Returns: 239 | torch.Tensor: Output tensor (#batch, time1, d_model). 
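            dict: Updated key/value cache ("k"/"v" tensors) for the next chunk.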
240 | 241 | """ 242 | q_h, k_h, v_h, v = self.forward_qkv(x) 243 | if chunk_size is not None and look_back > 0 or look_back == -1: 244 | if cache is not None: 245 | k_h_stride = k_h[:, :, : -(chunk_size[2]), :] 246 | v_h_stride = v_h[:, :, : -(chunk_size[2]), :] 247 | k_h = torch.cat((cache["k"], k_h), dim=2) 248 | v_h = torch.cat((cache["v"], v_h), dim=2) 249 | 250 | cache["k"] = torch.cat((cache["k"], k_h_stride), dim=2) 251 | cache["v"] = torch.cat((cache["v"], v_h_stride), dim=2) 252 | if look_back != -1: 253 | cache["k"] = cache["k"][:, :, -(look_back * chunk_size[1]) :, :] 254 | cache["v"] = cache["v"][:, :, -(look_back * chunk_size[1]) :, :] 255 | else: 256 | cache_tmp = { 257 | "k": k_h[:, :, : -(chunk_size[2]), :], 258 | "v": v_h[:, :, : -(chunk_size[2]), :], 259 | } 260 | cache = cache_tmp 261 | fsmn_memory = self.forward_fsmn(v, None) 262 | q_h = q_h * self.d_k ** (-0.5) 263 | scores = torch.matmul(q_h, k_h.transpose(-2, -1)) 264 | att_outs = self.forward_attention(v_h, scores, None) 265 | return att_outs + fsmn_memory, cache 266 | 267 | 268 | class LayerNorm(nn.LayerNorm): 269 | def __init__(self, *args, **kwargs): 270 | super().__init__(*args, **kwargs) 271 | 272 | def forward(self, input): 273 | output = F.layer_norm( 274 | input.float(), 275 | self.normalized_shape, 276 | self.weight.float() if self.weight is not None else None, 277 | self.bias.float() if self.bias is not None else None, 278 | self.eps, 279 | ) 280 | return output.type_as(input) 281 | 282 | 283 | def sequence_mask(lengths, maxlen=None, dtype=torch.float32, device=None): 284 | if maxlen is None: 285 | maxlen = lengths.max() 286 | row_vector = torch.arange(0, maxlen, 1).to(lengths.device) 287 | matrix = torch.unsqueeze(lengths, dim=-1) 288 | mask = row_vector < matrix 289 | mask = mask.detach() 290 | 291 | return mask.type(dtype).to(device) if device is not None else mask.type(dtype) 292 | 293 | 294 | class EncoderLayerSANM(nn.Module): 295 | def __init__( 296 | self, 297 | in_size, 298 | size, 299 | self_attn, 300 | feed_forward, 301 | dropout_rate, 302 | normalize_before=True, 303 | concat_after=False, 304 | stochastic_depth_rate=0.0, 305 | ): 306 | """Construct an EncoderLayer object.""" 307 | super(EncoderLayerSANM, self).__init__() 308 | self.self_attn = self_attn 309 | self.feed_forward = feed_forward 310 | self.norm1 = LayerNorm(in_size) 311 | self.norm2 = LayerNorm(size) 312 | self.dropout = nn.Dropout(dropout_rate) 313 | self.in_size = in_size 314 | self.size = size 315 | self.normalize_before = normalize_before 316 | self.concat_after = concat_after 317 | if self.concat_after: 318 | self.concat_linear = nn.Linear(size + size, size) 319 | self.stochastic_depth_rate = stochastic_depth_rate 320 | self.dropout_rate = dropout_rate 321 | 322 | def forward(self, x, mask, cache=None, mask_shfit_chunk=None, mask_att_chunk_encoder=None): 323 | """Compute encoded features. 324 | 325 | Args: 326 | x_input (torch.Tensor): Input tensor (#batch, time, size). 327 | mask (torch.Tensor): Mask tensor for the input (#batch, time). 328 | cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size). 329 | 330 | Returns: 331 | torch.Tensor: Output tensor (#batch, time, size). 332 | torch.Tensor: Mask tensor (#batch, time). 333 | 334 | """ 335 | skip_layer = False 336 | # with stochastic depth, residual connection `x + f(x)` becomes 337 | # `x <- x + 1 / (1 - p) * f(x)` at training time. 
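        # At eval time (or when stochastic_depth_rate == 0) the coefficient stays 1.0 and
        # the layer is never skipped; the 1 / (1 - p) scaling keeps the expected output
        # of the residual branch unchanged between training and inference.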
338 | stoch_layer_coeff = 1.0 339 | if self.training and self.stochastic_depth_rate > 0: 340 | skip_layer = torch.rand(1).item() < self.stochastic_depth_rate 341 | stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate) 342 | 343 | if skip_layer: 344 | if cache is not None: 345 | x = torch.cat([cache, x], dim=1) 346 | return x, mask 347 | 348 | residual = x 349 | if self.normalize_before: 350 | x = self.norm1(x) 351 | 352 | if self.concat_after: 353 | x_concat = torch.cat( 354 | ( 355 | x, 356 | self.self_attn( 357 | x, 358 | mask, 359 | mask_shfit_chunk=mask_shfit_chunk, 360 | mask_att_chunk_encoder=mask_att_chunk_encoder, 361 | ), 362 | ), 363 | dim=-1, 364 | ) 365 | if self.in_size == self.size: 366 | x = residual + stoch_layer_coeff * self.concat_linear(x_concat) 367 | else: 368 | x = stoch_layer_coeff * self.concat_linear(x_concat) 369 | else: 370 | if self.in_size == self.size: 371 | x = residual + stoch_layer_coeff * self.dropout( 372 | self.self_attn( 373 | x, 374 | mask, 375 | mask_shfit_chunk=mask_shfit_chunk, 376 | mask_att_chunk_encoder=mask_att_chunk_encoder, 377 | ) 378 | ) 379 | else: 380 | x = stoch_layer_coeff * self.dropout( 381 | self.self_attn( 382 | x, 383 | mask, 384 | mask_shfit_chunk=mask_shfit_chunk, 385 | mask_att_chunk_encoder=mask_att_chunk_encoder, 386 | ) 387 | ) 388 | if not self.normalize_before: 389 | x = self.norm1(x) 390 | 391 | residual = x 392 | if self.normalize_before: 393 | x = self.norm2(x) 394 | x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x)) 395 | if not self.normalize_before: 396 | x = self.norm2(x) 397 | 398 | return x, mask, cache, mask_shfit_chunk, mask_att_chunk_encoder 399 | 400 | def forward_chunk(self, x, cache=None, chunk_size=None, look_back=0): 401 | """Compute encoded features. 402 | 403 | Args: 404 | x_input (torch.Tensor): Input tensor (#batch, time, size). 405 | mask (torch.Tensor): Mask tensor for the input (#batch, time). 406 | cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size). 407 | 408 | Returns: 409 | torch.Tensor: Output tensor (#batch, time, size). 410 | torch.Tensor: Mask tensor (#batch, time). 
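            (Note: in this chunk-streaming variant the second return value is the
            updated self-attention cache rather than the mask.)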
411 | 412 | """ 413 | 414 | residual = x 415 | if self.normalize_before: 416 | x = self.norm1(x) 417 | 418 | if self.in_size == self.size: 419 | attn, cache = self.self_attn.forward_chunk(x, cache, chunk_size, look_back) 420 | x = residual + attn 421 | else: 422 | x, cache = self.self_attn.forward_chunk(x, cache, chunk_size, look_back) 423 | 424 | if not self.normalize_before: 425 | x = self.norm1(x) 426 | 427 | residual = x 428 | if self.normalize_before: 429 | x = self.norm2(x) 430 | x = residual + self.feed_forward(x) 431 | if not self.normalize_before: 432 | x = self.norm2(x) 433 | 434 | return x, cache 435 | 436 | 437 | @tables.register("encoder_classes", "SenseVoiceEncoderSmall") 438 | class SenseVoiceEncoderSmall(nn.Module): 439 | """ 440 | Author: Speech Lab of DAMO Academy, Alibaba Group 441 | SCAMA: Streaming chunk-aware multihead attention for online end-to-end speech recognition 442 | https://arxiv.org/abs/2006.01713 443 | """ 444 | 445 | def __init__( 446 | self, 447 | input_size: int, 448 | output_size: int = 256, 449 | attention_heads: int = 4, 450 | linear_units: int = 2048, 451 | num_blocks: int = 6, 452 | tp_blocks: int = 0, 453 | dropout_rate: float = 0.1, 454 | positional_dropout_rate: float = 0.1, 455 | attention_dropout_rate: float = 0.0, 456 | stochastic_depth_rate: float = 0.0, 457 | input_layer: Optional[str] = "conv2d", 458 | pos_enc_class=SinusoidalPositionEncoder, 459 | normalize_before: bool = True, 460 | concat_after: bool = False, 461 | positionwise_layer_type: str = "linear", 462 | positionwise_conv_kernel_size: int = 1, 463 | padding_idx: int = -1, 464 | kernel_size: int = 11, 465 | sanm_shfit: int = 0, 466 | selfattention_layer_type: str = "sanm", 467 | **kwargs, 468 | ): 469 | super().__init__() 470 | self._output_size = output_size 471 | 472 | self.embed = SinusoidalPositionEncoder() 473 | 474 | self.normalize_before = normalize_before 475 | 476 | positionwise_layer = PositionwiseFeedForward 477 | positionwise_layer_args = ( 478 | output_size, 479 | linear_units, 480 | dropout_rate, 481 | ) 482 | 483 | encoder_selfattn_layer = MultiHeadedAttentionSANM 484 | encoder_selfattn_layer_args0 = ( 485 | attention_heads, 486 | input_size, 487 | output_size, 488 | attention_dropout_rate, 489 | kernel_size, 490 | sanm_shfit, 491 | ) 492 | encoder_selfattn_layer_args = ( 493 | attention_heads, 494 | output_size, 495 | output_size, 496 | attention_dropout_rate, 497 | kernel_size, 498 | sanm_shfit, 499 | ) 500 | 501 | self.encoders0 = nn.ModuleList( 502 | [ 503 | EncoderLayerSANM( 504 | input_size, 505 | output_size, 506 | encoder_selfattn_layer(*encoder_selfattn_layer_args0), 507 | positionwise_layer(*positionwise_layer_args), 508 | dropout_rate, 509 | ) 510 | for i in range(1) 511 | ] 512 | ) 513 | self.encoders = nn.ModuleList( 514 | [ 515 | EncoderLayerSANM( 516 | output_size, 517 | output_size, 518 | encoder_selfattn_layer(*encoder_selfattn_layer_args), 519 | positionwise_layer(*positionwise_layer_args), 520 | dropout_rate, 521 | ) 522 | for i in range(num_blocks - 1) 523 | ] 524 | ) 525 | 526 | self.tp_encoders = nn.ModuleList( 527 | [ 528 | EncoderLayerSANM( 529 | output_size, 530 | output_size, 531 | encoder_selfattn_layer(*encoder_selfattn_layer_args), 532 | positionwise_layer(*positionwise_layer_args), 533 | dropout_rate, 534 | ) 535 | for i in range(tp_blocks) 536 | ] 537 | ) 538 | 539 | self.after_norm = LayerNorm(output_size) 540 | 541 | self.tp_norm = LayerNorm(output_size) 542 | 543 | def output_size(self) -> int: 544 | return self._output_size 545 | 
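    # forward(): scale features by sqrt(output_size), add sinusoidal position encoding,
    # run the SANM encoder stack (encoders0 -> encoders -> after_norm), then the
    # tp_encoders stack (-> tp_norm), and return the encoded features with their lengths.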
546 | def forward( 547 | self, 548 | xs_pad: torch.Tensor, 549 | ilens: torch.Tensor, 550 | ): 551 | """Embed positions in tensor.""" 552 | masks = sequence_mask(ilens, device=ilens.device)[:, None, :] 553 | 554 | xs_pad *= self.output_size() ** 0.5 555 | 556 | xs_pad = self.embed(xs_pad) 557 | 558 | # forward encoder1 559 | for layer_idx, encoder_layer in enumerate(self.encoders0): 560 | encoder_outs = encoder_layer(xs_pad, masks) 561 | xs_pad, masks = encoder_outs[0], encoder_outs[1] 562 | 563 | for layer_idx, encoder_layer in enumerate(self.encoders): 564 | encoder_outs = encoder_layer(xs_pad, masks) 565 | xs_pad, masks = encoder_outs[0], encoder_outs[1] 566 | 567 | xs_pad = self.after_norm(xs_pad) 568 | 569 | # forward encoder2 570 | olens = masks.squeeze(1).sum(1).int() 571 | 572 | for layer_idx, encoder_layer in enumerate(self.tp_encoders): 573 | encoder_outs = encoder_layer(xs_pad, masks) 574 | xs_pad, masks = encoder_outs[0], encoder_outs[1] 575 | 576 | xs_pad = self.tp_norm(xs_pad) 577 | return xs_pad, olens 578 | 579 | 580 | @tables.register("model_classes", "SenseVoiceSmall") 581 | class SenseVoiceSmall(nn.Module): 582 | """CTC-attention hybrid Encoder-Decoder model""" 583 | 584 | def __init__( 585 | self, 586 | specaug: str = None, 587 | specaug_conf: dict = None, 588 | normalize: str = None, 589 | normalize_conf: dict = None, 590 | encoder: str = None, 591 | encoder_conf: dict = None, 592 | ctc_conf: dict = None, 593 | input_size: int = 80, 594 | vocab_size: int = -1, 595 | ignore_id: int = -1, 596 | blank_id: int = 0, 597 | sos: int = 1, 598 | eos: int = 2, 599 | length_normalized_loss: bool = False, 600 | **kwargs, 601 | ): 602 | 603 | super().__init__() 604 | 605 | if specaug is not None: 606 | specaug_class = tables.specaug_classes.get(specaug) 607 | specaug = specaug_class(**specaug_conf) 608 | if normalize is not None: 609 | normalize_class = tables.normalize_classes.get(normalize) 610 | normalize = normalize_class(**normalize_conf) 611 | encoder_class = tables.encoder_classes.get(encoder) 612 | encoder = encoder_class(input_size=input_size, **encoder_conf) 613 | encoder_output_size = encoder.output_size() 614 | 615 | if ctc_conf is None: 616 | ctc_conf = {} 617 | ctc = CTC(odim=vocab_size, encoder_output_size=encoder_output_size, **ctc_conf) 618 | 619 | self.blank_id = blank_id 620 | self.sos = sos if sos is not None else vocab_size - 1 621 | self.eos = eos if eos is not None else vocab_size - 1 622 | self.vocab_size = vocab_size 623 | self.ignore_id = ignore_id 624 | self.specaug = specaug 625 | self.normalize = normalize 626 | self.encoder = encoder 627 | self.error_calculator = None 628 | 629 | self.ctc = ctc 630 | 631 | self.length_normalized_loss = length_normalized_loss 632 | self.encoder_output_size = encoder_output_size 633 | 634 | self.lid_dict = {"auto": 0, "zh": 3, "en": 4, "yue": 7, "ja": 11, "ko": 12, "nospeech": 13} 635 | self.lid_int_dict = {24884: 3, 24885: 4, 24888: 7, 24892: 11, 24896: 12, 24992: 13} 636 | self.textnorm_dict = {"withitn": 14, "woitn": 15} 637 | self.textnorm_int_dict = {25016: 14, 25017: 15} 638 | self.embed = torch.nn.Embedding(7 + len(self.lid_dict) + len(self.textnorm_dict), input_size) 639 | self.emo_dict = {"unk": 25009, "happy": 25001, "sad": 25002, "angry": 25003, "neutral": 25004} 640 | 641 | self.criterion_att = LabelSmoothingLoss( 642 | size=self.vocab_size, 643 | padding_idx=self.ignore_id, 644 | smoothing=kwargs.get("lsm_weight", 0.0), 645 | normalize_length=self.length_normalized_loss, 646 | ) 647 | 648 | @staticmethod 649 
| def from_pretrained(model:str=None, **kwargs): 650 | from funasr import AutoModel 651 | model, kwargs = AutoModel.build_model(model=model, trust_remote_code=True, **kwargs) 652 | 653 | return model, kwargs 654 | 655 | def forward( 656 | self, 657 | speech: torch.Tensor, 658 | speech_lengths: torch.Tensor, 659 | text: torch.Tensor, 660 | text_lengths: torch.Tensor, 661 | **kwargs, 662 | ): 663 | """Encoder + Decoder + Calc loss 664 | Args: 665 | speech: (Batch, Length, ...) 666 | speech_lengths: (Batch, ) 667 | text: (Batch, Length) 668 | text_lengths: (Batch,) 669 | """ 670 | # import pdb; 671 | # pdb.set_trace() 672 | if len(text_lengths.size()) > 1: 673 | text_lengths = text_lengths[:, 0] 674 | if len(speech_lengths.size()) > 1: 675 | speech_lengths = speech_lengths[:, 0] 676 | 677 | batch_size = speech.shape[0] 678 | 679 | # 1. Encoder 680 | encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, text) 681 | 682 | loss_ctc, cer_ctc = None, None 683 | loss_rich, acc_rich = None, None 684 | stats = dict() 685 | 686 | loss_ctc, cer_ctc = self._calc_ctc_loss( 687 | encoder_out[:, 4:, :], encoder_out_lens - 4, text[:, 4:], text_lengths - 4 688 | ) 689 | 690 | loss_rich, acc_rich = self._calc_rich_ce_loss( 691 | encoder_out[:, :4, :], text[:, :4] 692 | ) 693 | 694 | loss = loss_ctc + loss_rich 695 | # Collect total loss stats 696 | stats["loss_ctc"] = torch.clone(loss_ctc.detach()) if loss_ctc is not None else None 697 | stats["loss_rich"] = torch.clone(loss_rich.detach()) if loss_rich is not None else None 698 | stats["loss"] = torch.clone(loss.detach()) if loss is not None else None 699 | stats["acc_rich"] = acc_rich 700 | 701 | # force_gatherable: to-device and to-tensor if scalar for DataParallel 702 | if self.length_normalized_loss: 703 | batch_size = int((text_lengths + 1).sum()) 704 | loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device) 705 | return loss, stats, weight 706 | 707 | def encode( 708 | self, 709 | speech: torch.Tensor, 710 | speech_lengths: torch.Tensor, 711 | text: torch.Tensor, 712 | **kwargs, 713 | ): 714 | """Frontend + Encoder. Note that this method is used by asr_inference.py 715 | Args: 716 | speech: (Batch, Length, ...) 717 | speech_lengths: (Batch, ) 718 | ind: int 719 | """ 720 | 721 | # Data augmentation 722 | if self.specaug is not None and self.training: 723 | speech, speech_lengths = self.specaug(speech, speech_lengths) 724 | 725 | # Normalization for feature: e.g. 
Global-CMVN, Utterance-CMVN 726 | if self.normalize is not None: 727 | speech, speech_lengths = self.normalize(speech, speech_lengths) 728 | 729 | 730 | lids = torch.LongTensor([[self.lid_int_dict[int(lid)] if torch.rand(1) > 0.2 and int(lid) in self.lid_int_dict else 0 ] for lid in text[:, 0]]).to(speech.device) 731 | language_query = self.embed(lids) 732 | 733 | styles = torch.LongTensor([[self.textnorm_int_dict[int(style)]] for style in text[:, 3]]).to(speech.device) 734 | style_query = self.embed(styles) 735 | speech = torch.cat((style_query, speech), dim=1) 736 | speech_lengths += 1 737 | 738 | event_emo_query = self.embed(torch.LongTensor([[1, 2]]).to(speech.device)).repeat(speech.size(0), 1, 1) 739 | input_query = torch.cat((language_query, event_emo_query), dim=1) 740 | speech = torch.cat((input_query, speech), dim=1) 741 | speech_lengths += 3 742 | 743 | encoder_out, encoder_out_lens = self.encoder(speech, speech_lengths) 744 | 745 | return encoder_out, encoder_out_lens 746 | 747 | def _calc_ctc_loss( 748 | self, 749 | encoder_out: torch.Tensor, 750 | encoder_out_lens: torch.Tensor, 751 | ys_pad: torch.Tensor, 752 | ys_pad_lens: torch.Tensor, 753 | ): 754 | # Calc CTC loss 755 | loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens) 756 | 757 | # Calc CER using CTC 758 | cer_ctc = None 759 | if not self.training and self.error_calculator is not None: 760 | ys_hat = self.ctc.argmax(encoder_out).data 761 | cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True) 762 | return loss_ctc, cer_ctc 763 | 764 | def _calc_rich_ce_loss( 765 | self, 766 | encoder_out: torch.Tensor, 767 | ys_pad: torch.Tensor, 768 | ): 769 | decoder_out = self.ctc.ctc_lo(encoder_out) 770 | # 2. Compute attention loss 771 | loss_rich = self.criterion_att(decoder_out, ys_pad.contiguous()) 772 | acc_rich = th_accuracy( 773 | decoder_out.view(-1, self.vocab_size), 774 | ys_pad.contiguous(), 775 | ignore_label=self.ignore_id, 776 | ) 777 | 778 | return loss_rich, acc_rich 779 | 780 | 781 | def inference( 782 | self, 783 | data_in, 784 | data_lengths=None, 785 | key: list = ["wav_file_tmp_name"], 786 | tokenizer=None, 787 | frontend=None, 788 | **kwargs, 789 | ): 790 | 791 | 792 | meta_data = {} 793 | if ( 794 | isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank" 795 | ): # fbank 796 | speech, speech_lengths = data_in, data_lengths 797 | if len(speech.shape) < 3: 798 | speech = speech[None, :, :] 799 | if speech_lengths is None: 800 | speech_lengths = speech.shape[1] 801 | else: 802 | # extract fbank feats 803 | time1 = time.perf_counter() 804 | audio_sample_list = load_audio_text_image_video( 805 | data_in, 806 | fs=frontend.fs, 807 | audio_fs=kwargs.get("fs", 16000), 808 | data_type=kwargs.get("data_type", "sound"), 809 | tokenizer=tokenizer, 810 | ) 811 | time2 = time.perf_counter() 812 | meta_data["load_data"] = f"{time2 - time1:0.3f}" 813 | speech, speech_lengths = extract_fbank( 814 | audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=frontend 815 | ) 816 | time3 = time.perf_counter() 817 | meta_data["extract_feat"] = f"{time3 - time2:0.3f}" 818 | meta_data["batch_data_time"] = ( 819 | speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000 820 | ) 821 | 822 | speech = speech.to(device=kwargs["device"]) 823 | speech_lengths = speech_lengths.to(device=kwargs["device"]) 824 | 825 | language = kwargs.get("language", "auto") 826 | language_query = self.embed( 827 | torch.LongTensor( 828 | 
[[self.lid_dict[language] if language in self.lid_dict else 0]] 829 | ).to(speech.device) 830 | ).repeat(speech.size(0), 1, 1) 831 | 832 | use_itn = kwargs.get("use_itn", False) 833 | output_timestamp = kwargs.get("output_timestamp", False) 834 | 835 | textnorm = kwargs.get("text_norm", None) 836 | if textnorm is None: 837 | textnorm = "withitn" if use_itn else "woitn" 838 | textnorm_query = self.embed( 839 | torch.LongTensor([[self.textnorm_dict[textnorm]]]).to(speech.device) 840 | ).repeat(speech.size(0), 1, 1) 841 | speech = torch.cat((textnorm_query, speech), dim=1) 842 | speech_lengths += 1 843 | 844 | event_emo_query = self.embed(torch.LongTensor([[1, 2]]).to(speech.device)).repeat( 845 | speech.size(0), 1, 1 846 | ) 847 | input_query = torch.cat((language_query, event_emo_query), dim=1) 848 | speech = torch.cat((input_query, speech), dim=1) 849 | speech_lengths += 3 850 | 851 | # Encoder 852 | encoder_out, encoder_out_lens = self.encoder(speech, speech_lengths) 853 | if isinstance(encoder_out, tuple): 854 | encoder_out = encoder_out[0] 855 | 856 | # c. Passed the encoder result and the beam search 857 | ctc_logits = self.ctc.log_softmax(encoder_out) 858 | if kwargs.get("ban_emo_unk", False): 859 | ctc_logits[:, :, self.emo_dict["unk"]] = -float("inf") 860 | 861 | results = [] 862 | b, n, d = encoder_out.size() 863 | if isinstance(key[0], (list, tuple)): 864 | key = key[0] 865 | if len(key) < b: 866 | key = key * b 867 | for i in range(b): 868 | x = ctc_logits[i, : encoder_out_lens[i].item(), :] 869 | yseq = x.argmax(dim=-1) 870 | yseq = torch.unique_consecutive(yseq, dim=-1) 871 | 872 | ibest_writer = None 873 | if kwargs.get("output_dir") is not None: 874 | if not hasattr(self, "writer"): 875 | self.writer = DatadirWriter(kwargs.get("output_dir")) 876 | ibest_writer = self.writer[f"1best_recog"] 877 | 878 | mask = yseq != self.blank_id 879 | token_int = yseq[mask].tolist() 880 | 881 | # Change integer-ids to tokens 882 | text = tokenizer.decode(token_int) 883 | if ibest_writer is not None: 884 | ibest_writer["text"][key[i]] = text 885 | 886 | if output_timestamp: 887 | from itertools import groupby 888 | timestamp = [] 889 | tokens = tokenizer.text2tokens(text)[4:] 890 | 891 | logits_speech = self.ctc.softmax(encoder_out)[i, 4:encoder_out_lens[i].item(), :] 892 | 893 | pred = logits_speech.argmax(-1).cpu() 894 | logits_speech[pred==self.blank_id, self.blank_id] = 0 895 | 896 | align = ctc_forced_align( 897 | logits_speech.unsqueeze(0).float(), 898 | torch.Tensor(token_int[4:]).unsqueeze(0).long().to(logits_speech.device), 899 | (encoder_out_lens-4).long(), 900 | torch.tensor(len(token_int)-4).unsqueeze(0).long().to(logits_speech.device), 901 | ignore_id=self.ignore_id, 902 | ) 903 | 904 | pred = groupby(align[0, :encoder_out_lens[0]]) 905 | _start = 0 906 | token_id = 0 907 | ts_max = encoder_out_lens[i] - 4 908 | for pred_token, pred_frame in pred: 909 | _end = _start + len(list(pred_frame)) 910 | if pred_token != 0: 911 | ts_left = max((_start*60-30)/1000, 0) 912 | ts_right = min((_end*60-30)/1000, (ts_max*60-30)/1000) 913 | timestamp.append([tokens[token_id], ts_left, ts_right]) 914 | token_id += 1 915 | _start = _end 916 | 917 | result_i = {"key": key[i], "text": text, "timestamp": timestamp} 918 | results.append(result_i) 919 | else: 920 | result_i = {"key": key[i], "text": text} 921 | results.append(result_i) 922 | return results, meta_data 923 | 924 | def export(self, **kwargs): 925 | from export_meta import export_rebuild_model 926 | 927 | if "max_seq_len" not in kwargs: 
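            # Default maximum sequence length for ONNX export; export_meta.export_rebuild_model
            # uses it to build the padding-mask helper (sequence_mask) on the rebuilt model.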
928 | kwargs["max_seq_len"] = 512 929 | models = export_rebuild_model(model=self, **kwargs) 930 | return models 931 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch<=2.3 2 | torchaudio 3 | modelscope 4 | huggingface 5 | huggingface_hub 6 | funasr>=1.1.3 7 | numpy<=1.26.4 8 | gradio 9 | fastapi>=0.111.1 10 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FunAudioLLM/SenseVoice/3ecc6f6a8f8524a77e04be51846d840cb35da2b4/utils/__init__.py -------------------------------------------------------------------------------- /utils/ctc_alignment.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def ctc_forced_align( 4 | log_probs: torch.Tensor, 5 | targets: torch.Tensor, 6 | input_lengths: torch.Tensor, 7 | target_lengths: torch.Tensor, 8 | blank: int = 0, 9 | ignore_id: int = -1, 10 | ) -> torch.Tensor: 11 | """Align a CTC label sequence to an emission. 12 | 13 | Args: 14 | log_probs (Tensor): log probability of CTC emission output. 15 | Tensor of shape `(B, T, C)`. where `B` is the batch size, `T` is the input length, 16 | `C` is the number of characters in alphabet including blank. 17 | targets (Tensor): Target sequence. Tensor of shape `(B, L)`, 18 | where `L` is the target length. 19 | input_lengths (Tensor): 20 | Lengths of the inputs (max value must each be <= `T`). 1-D Tensor of shape `(B,)`. 21 | target_lengths (Tensor): 22 | Lengths of the targets. 1-D Tensor of shape `(B,)`. 23 | blank_id (int, optional): The index of blank symbol in CTC emission. (Default: 0) 24 | ignore_id (int, optional): The index of ignore symbol in CTC emission. 
(Default: -1) 25 | """ 26 | targets[targets == ignore_id] = blank 27 | 28 | batch_size, input_time_size, _ = log_probs.size() 29 | bsz_indices = torch.arange(batch_size, device=input_lengths.device) 30 | 31 | _t_a_r_g_e_t_s_ = torch.cat( 32 | ( 33 | torch.stack((torch.full_like(targets, blank), targets), dim=-1).flatten(start_dim=1), 34 | torch.full_like(targets[:, :1], blank), 35 | ), 36 | dim=-1, 37 | ) 38 | diff_labels = torch.cat( 39 | ( 40 | torch.as_tensor([[False, False]], device=targets.device).expand(batch_size, -1), 41 | _t_a_r_g_e_t_s_[:, 2:] != _t_a_r_g_e_t_s_[:, :-2], 42 | ), 43 | dim=1, 44 | ) 45 | 46 | neg_inf = torch.tensor(float("-inf"), device=log_probs.device, dtype=log_probs.dtype) 47 | padding_num = 2 48 | padded_t = padding_num + _t_a_r_g_e_t_s_.size(-1) 49 | best_score = torch.full((batch_size, padded_t), neg_inf, device=log_probs.device, dtype=log_probs.dtype) 50 | best_score[:, padding_num + 0] = log_probs[:, 0, blank] 51 | best_score[:, padding_num + 1] = log_probs[bsz_indices, 0, _t_a_r_g_e_t_s_[:, 1]] 52 | 53 | backpointers = torch.zeros((batch_size, input_time_size, padded_t), device=log_probs.device, dtype=targets.dtype) 54 | 55 | for t in range(1, input_time_size): 56 | prev = torch.stack( 57 | (best_score[:, 2:], best_score[:, 1:-1], torch.where(diff_labels, best_score[:, :-2], neg_inf)) 58 | ) 59 | prev_max_value, prev_max_idx = prev.max(dim=0) 60 | best_score[:, padding_num:] = log_probs[:, t].gather(-1, _t_a_r_g_e_t_s_) + prev_max_value 61 | backpointers[:, t, padding_num:] = prev_max_idx 62 | 63 | l1l2 = best_score.gather( 64 | -1, torch.stack((padding_num + target_lengths * 2 - 1, padding_num + target_lengths * 2), dim=-1) 65 | ) 66 | 67 | path = torch.zeros((batch_size, input_time_size), device=best_score.device, dtype=torch.long) 68 | path[bsz_indices, input_lengths - 1] = padding_num + target_lengths * 2 - 1 + l1l2.argmax(dim=-1) 69 | 70 | for t in range(input_time_size - 1, 0, -1): 71 | target_indices = path[:, t] 72 | prev_max_idx = backpointers[bsz_indices, t, target_indices] 73 | path[:, t - 1] += target_indices - prev_max_idx 74 | 75 | alignments = _t_a_r_g_e_t_s_.gather(dim=-1, index=(path - padding_num).clamp(min=0)) 76 | return alignments 77 | -------------------------------------------------------------------------------- /utils/export_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | 5 | def export( 6 | model, quantize: bool = False, opset_version: int = 14, type="onnx", **kwargs 7 | ): 8 | model_scripts = model.export(**kwargs) 9 | export_dir = kwargs.get("output_dir", os.path.dirname(kwargs.get("init_param"))) 10 | os.makedirs(export_dir, exist_ok=True) 11 | 12 | if not isinstance(model_scripts, (list, tuple)): 13 | model_scripts = (model_scripts,) 14 | for m in model_scripts: 15 | m.eval() 16 | if type == "onnx": 17 | _onnx( 18 | m, 19 | quantize=quantize, 20 | opset_version=opset_version, 21 | export_dir=export_dir, 22 | **kwargs, 23 | ) 24 | print("output dir: {}".format(export_dir)) 25 | 26 | return export_dir 27 | 28 | 29 | def _onnx( 30 | model, 31 | quantize: bool = False, 32 | opset_version: int = 14, 33 | export_dir: str = None, 34 | **kwargs, 35 | ): 36 | 37 | dummy_input = model.export_dummy_inputs() 38 | 39 | verbose = kwargs.get("verbose", False) 40 | 41 | export_name = model.export_name() 42 | model_path = os.path.join(export_dir, export_name) 43 | torch.onnx.export( 44 | model, 45 | dummy_input, 46 | model_path, 47 | verbose=verbose, 48 | 
opset_version=opset_version, 49 | input_names=model.export_input_names(), 50 | output_names=model.export_output_names(), 51 | dynamic_axes=model.export_dynamic_axes(), 52 | ) 53 | 54 | if quantize: 55 | from onnxruntime.quantization import QuantType, quantize_dynamic 56 | import onnx 57 | 58 | quant_model_path = model_path.replace(".onnx", "_quant.onnx") 59 | if not os.path.exists(quant_model_path): 60 | onnx_model = onnx.load(model_path) 61 | nodes = [n.name for n in onnx_model.graph.node] 62 | nodes_to_exclude = [ 63 | m for m in nodes if "output" in m or "bias_encoder" in m or "bias_decoder" in m 64 | ] 65 | quantize_dynamic( 66 | model_input=model_path, 67 | model_output=quant_model_path, 68 | op_types_to_quantize=["MatMul"], 69 | per_channel=True, 70 | reduce_range=False, 71 | weight_type=QuantType.QUInt8, 72 | nodes_to_exclude=nodes_to_exclude, 73 | ) 74 | -------------------------------------------------------------------------------- /utils/frontend.py: -------------------------------------------------------------------------------- 1 | # -*- encoding: utf-8 -*- 2 | from pathlib import Path 3 | from typing import Any, Dict, Iterable, List, NamedTuple, Set, Tuple, Union 4 | import copy 5 | 6 | import numpy as np 7 | import kaldi_native_fbank as knf 8 | 9 | root_dir = Path(__file__).resolve().parent 10 | 11 | logger_initialized = {} 12 | 13 | 14 | class WavFrontend: 15 | """Conventional frontend structure for ASR.""" 16 | 17 | def __init__( 18 | self, 19 | cmvn_file: str = None, 20 | fs: int = 16000, 21 | window: str = "hamming", 22 | n_mels: int = 80, 23 | frame_length: int = 25, 24 | frame_shift: int = 10, 25 | lfr_m: int = 1, 26 | lfr_n: int = 1, 27 | dither: float = 1.0, 28 | **kwargs, 29 | ) -> None: 30 | 31 | opts = knf.FbankOptions() 32 | opts.frame_opts.samp_freq = fs 33 | opts.frame_opts.dither = dither 34 | opts.frame_opts.window_type = window 35 | opts.frame_opts.frame_shift_ms = float(frame_shift) 36 | opts.frame_opts.frame_length_ms = float(frame_length) 37 | opts.mel_opts.num_bins = n_mels 38 | opts.energy_floor = 0 39 | opts.frame_opts.snip_edges = True 40 | opts.mel_opts.debug_mel = False 41 | self.opts = opts 42 | 43 | self.lfr_m = lfr_m 44 | self.lfr_n = lfr_n 45 | self.cmvn_file = cmvn_file 46 | 47 | if self.cmvn_file: 48 | self.cmvn = self.load_cmvn() 49 | self.fbank_fn = None 50 | self.fbank_beg_idx = 0 51 | self.reset_status() 52 | 53 | def fbank(self, waveform: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: 54 | waveform = waveform * (1 << 15) 55 | self.fbank_fn = knf.OnlineFbank(self.opts) 56 | self.fbank_fn.accept_waveform(self.opts.frame_opts.samp_freq, waveform.tolist()) 57 | frames = self.fbank_fn.num_frames_ready 58 | mat = np.empty([frames, self.opts.mel_opts.num_bins]) 59 | for i in range(frames): 60 | mat[i, :] = self.fbank_fn.get_frame(i) 61 | feat = mat.astype(np.float32) 62 | feat_len = np.array(mat.shape[0]).astype(np.int32) 63 | return feat, feat_len 64 | 65 | def fbank_online(self, waveform: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: 66 | waveform = waveform * (1 << 15) 67 | # self.fbank_fn = knf.OnlineFbank(self.opts) 68 | self.fbank_fn.accept_waveform(self.opts.frame_opts.samp_freq, waveform.tolist()) 69 | frames = self.fbank_fn.num_frames_ready 70 | mat = np.empty([frames, self.opts.mel_opts.num_bins]) 71 | for i in range(self.fbank_beg_idx, frames): 72 | mat[i, :] = self.fbank_fn.get_frame(i) 73 | # self.fbank_beg_idx += (frames-self.fbank_beg_idx) 74 | feat = mat.astype(np.float32) 75 | feat_len = np.array(mat.shape[0]).astype(np.int32) 76 
| return feat, feat_len 77 | 78 | def reset_status(self): 79 | self.fbank_fn = knf.OnlineFbank(self.opts) 80 | self.fbank_beg_idx = 0 81 | 82 | def lfr_cmvn(self, feat: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: 83 | if self.lfr_m != 1 or self.lfr_n != 1: 84 | feat = self.apply_lfr(feat, self.lfr_m, self.lfr_n) 85 | 86 | if self.cmvn_file: 87 | feat = self.apply_cmvn(feat) 88 | 89 | feat_len = np.array(feat.shape[0]).astype(np.int32) 90 | return feat, feat_len 91 | 92 | @staticmethod 93 | def apply_lfr(inputs: np.ndarray, lfr_m: int, lfr_n: int) -> np.ndarray: 94 | LFR_inputs = [] 95 | 96 | T = inputs.shape[0] 97 | T_lfr = int(np.ceil(T / lfr_n)) 98 | left_padding = np.tile(inputs[0], ((lfr_m - 1) // 2, 1)) 99 | inputs = np.vstack((left_padding, inputs)) 100 | T = T + (lfr_m - 1) // 2 101 | for i in range(T_lfr): 102 | if lfr_m <= T - i * lfr_n: 103 | LFR_inputs.append((inputs[i * lfr_n : i * lfr_n + lfr_m]).reshape(1, -1)) 104 | else: 105 | # process last LFR frame 106 | num_padding = lfr_m - (T - i * lfr_n) 107 | frame = inputs[i * lfr_n :].reshape(-1) 108 | for _ in range(num_padding): 109 | frame = np.hstack((frame, inputs[-1])) 110 | 111 | LFR_inputs.append(frame) 112 | LFR_outputs = np.vstack(LFR_inputs).astype(np.float32) 113 | return LFR_outputs 114 | 115 | def apply_cmvn(self, inputs: np.ndarray) -> np.ndarray: 116 | """ 117 | Apply CMVN with mvn data 118 | """ 119 | frame, dim = inputs.shape 120 | means = np.tile(self.cmvn[0:1, :dim], (frame, 1)) 121 | vars = np.tile(self.cmvn[1:2, :dim], (frame, 1)) 122 | inputs = (inputs + means) * vars 123 | return inputs 124 | 125 | def load_cmvn( 126 | self, 127 | ) -> np.ndarray: 128 | with open(self.cmvn_file, "r", encoding="utf-8") as f: 129 | lines = f.readlines() 130 | 131 | means_list = [] 132 | vars_list = [] 133 | for i in range(len(lines)): 134 | line_item = lines[i].split() 135 | if line_item[0] == "": 136 | line_item = lines[i + 1].split() 137 | if line_item[0] == "": 138 | add_shift_line = line_item[3 : (len(line_item) - 1)] 139 | means_list = list(add_shift_line) 140 | continue 141 | elif line_item[0] == "": 142 | line_item = lines[i + 1].split() 143 | if line_item[0] == "": 144 | rescale_line = line_item[3 : (len(line_item) - 1)] 145 | vars_list = list(rescale_line) 146 | continue 147 | 148 | means = np.array(means_list).astype(np.float64) 149 | vars = np.array(vars_list).astype(np.float64) 150 | cmvn = np.array([means, vars]) 151 | return cmvn 152 | 153 | 154 | class WavFrontendOnline(WavFrontend): 155 | def __init__(self, **kwargs): 156 | super().__init__(**kwargs) 157 | # self.fbank_fn = knf.OnlineFbank(self.opts) 158 | # add variables 159 | self.frame_sample_length = int( 160 | self.opts.frame_opts.frame_length_ms * self.opts.frame_opts.samp_freq / 1000 161 | ) 162 | self.frame_shift_sample_length = int( 163 | self.opts.frame_opts.frame_shift_ms * self.opts.frame_opts.samp_freq / 1000 164 | ) 165 | self.waveform = None 166 | self.reserve_waveforms = None 167 | self.input_cache = None 168 | self.lfr_splice_cache = [] 169 | 170 | @staticmethod 171 | # inputs has catted the cache 172 | def apply_lfr( 173 | inputs: np.ndarray, lfr_m: int, lfr_n: int, is_final: bool = False 174 | ) -> Tuple[np.ndarray, np.ndarray, int]: 175 | """ 176 | Apply lfr with data 177 | """ 178 | 179 | LFR_inputs = [] 180 | T = inputs.shape[0] # include the right context 181 | T_lfr = int( 182 | np.ceil((T - (lfr_m - 1) // 2) / lfr_n) 183 | ) # minus the right context: (lfr_m - 1) // 2 184 | splice_idx = T_lfr 185 | for i in range(T_lfr): 186 | if 
lfr_m <= T - i * lfr_n: 187 | LFR_inputs.append((inputs[i * lfr_n : i * lfr_n + lfr_m]).reshape(1, -1)) 188 | else: # process last LFR frame 189 | if is_final: 190 | num_padding = lfr_m - (T - i * lfr_n) 191 | frame = (inputs[i * lfr_n :]).reshape(-1) 192 | for _ in range(num_padding): 193 | frame = np.hstack((frame, inputs[-1])) 194 | LFR_inputs.append(frame) 195 | else: 196 | # update splice_idx and break the circle 197 | splice_idx = i 198 | break 199 | splice_idx = min(T - 1, splice_idx * lfr_n) 200 | lfr_splice_cache = inputs[splice_idx:, :] 201 | LFR_outputs = np.vstack(LFR_inputs) 202 | return LFR_outputs.astype(np.float32), lfr_splice_cache, splice_idx 203 | 204 | @staticmethod 205 | def compute_frame_num( 206 | sample_length: int, frame_sample_length: int, frame_shift_sample_length: int 207 | ) -> int: 208 | frame_num = int((sample_length - frame_sample_length) / frame_shift_sample_length + 1) 209 | return frame_num if frame_num >= 1 and sample_length >= frame_sample_length else 0 210 | 211 | def fbank( 212 | self, input: np.ndarray, input_lengths: np.ndarray 213 | ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: 214 | self.fbank_fn = knf.OnlineFbank(self.opts) 215 | batch_size = input.shape[0] 216 | if self.input_cache is None: 217 | self.input_cache = np.empty((batch_size, 0), dtype=np.float32) 218 | input = np.concatenate((self.input_cache, input), axis=1) 219 | frame_num = self.compute_frame_num( 220 | input.shape[-1], self.frame_sample_length, self.frame_shift_sample_length 221 | ) 222 | # update self.in_cache 223 | self.input_cache = input[ 224 | :, -(input.shape[-1] - frame_num * self.frame_shift_sample_length) : 225 | ] 226 | waveforms = np.empty(0, dtype=np.float32) 227 | feats_pad = np.empty(0, dtype=np.float32) 228 | feats_lens = np.empty(0, dtype=np.int32) 229 | if frame_num: 230 | waveforms = [] 231 | feats = [] 232 | feats_lens = [] 233 | for i in range(batch_size): 234 | waveform = input[i] 235 | waveforms.append( 236 | waveform[ 237 | : ( 238 | (frame_num - 1) * self.frame_shift_sample_length 239 | + self.frame_sample_length 240 | ) 241 | ] 242 | ) 243 | waveform = waveform * (1 << 15) 244 | 245 | self.fbank_fn.accept_waveform(self.opts.frame_opts.samp_freq, waveform.tolist()) 246 | frames = self.fbank_fn.num_frames_ready 247 | mat = np.empty([frames, self.opts.mel_opts.num_bins]) 248 | for i in range(frames): 249 | mat[i, :] = self.fbank_fn.get_frame(i) 250 | feat = mat.astype(np.float32) 251 | feat_len = np.array(mat.shape[0]).astype(np.int32) 252 | feats.append(feat) 253 | feats_lens.append(feat_len) 254 | 255 | waveforms = np.stack(waveforms) 256 | feats_lens = np.array(feats_lens) 257 | feats_pad = np.array(feats) 258 | self.fbanks = feats_pad 259 | self.fbanks_lens = copy.deepcopy(feats_lens) 260 | return waveforms, feats_pad, feats_lens 261 | 262 | def get_fbank(self) -> Tuple[np.ndarray, np.ndarray]: 263 | return self.fbanks, self.fbanks_lens 264 | 265 | def lfr_cmvn( 266 | self, input: np.ndarray, input_lengths: np.ndarray, is_final: bool = False 267 | ) -> Tuple[np.ndarray, np.ndarray, List[int]]: 268 | batch_size = input.shape[0] 269 | feats = [] 270 | feats_lens = [] 271 | lfr_splice_frame_idxs = [] 272 | for i in range(batch_size): 273 | mat = input[i, : input_lengths[i], :] 274 | lfr_splice_frame_idx = -1 275 | if self.lfr_m != 1 or self.lfr_n != 1: 276 | # update self.lfr_splice_cache in self.apply_lfr 277 | mat, self.lfr_splice_cache[i], lfr_splice_frame_idx = self.apply_lfr( 278 | mat, self.lfr_m, self.lfr_n, is_final 279 | ) 280 | if self.cmvn_file 
is not None: 281 | mat = self.apply_cmvn(mat) 282 | feat_length = mat.shape[0] 283 | feats.append(mat) 284 | feats_lens.append(feat_length) 285 | lfr_splice_frame_idxs.append(lfr_splice_frame_idx) 286 | 287 | feats_lens = np.array(feats_lens) 288 | feats_pad = np.array(feats) 289 | return feats_pad, feats_lens, lfr_splice_frame_idxs 290 | 291 | def extract_fbank( 292 | self, input: np.ndarray, input_lengths: np.ndarray, is_final: bool = False 293 | ) -> Tuple[np.ndarray, np.ndarray]: 294 | batch_size = input.shape[0] 295 | assert ( 296 | batch_size == 1 297 | ), "we support to extract feature online only when the batch size is equal to 1 now" 298 | waveforms, feats, feats_lengths = self.fbank(input, input_lengths) # input shape: B T D 299 | if feats.shape[0]: 300 | self.waveforms = ( 301 | waveforms 302 | if self.reserve_waveforms is None 303 | else np.concatenate((self.reserve_waveforms, waveforms), axis=1) 304 | ) 305 | if not self.lfr_splice_cache: 306 | for i in range(batch_size): 307 | self.lfr_splice_cache.append( 308 | np.expand_dims(feats[i][0, :], axis=0).repeat((self.lfr_m - 1) // 2, axis=0) 309 | ) 310 | 311 | if feats_lengths[0] + self.lfr_splice_cache[0].shape[0] >= self.lfr_m: 312 | lfr_splice_cache_np = np.stack(self.lfr_splice_cache) # B T D 313 | feats = np.concatenate((lfr_splice_cache_np, feats), axis=1) 314 | feats_lengths += lfr_splice_cache_np[0].shape[0] 315 | frame_from_waveforms = int( 316 | (self.waveforms.shape[1] - self.frame_sample_length) 317 | / self.frame_shift_sample_length 318 | + 1 319 | ) 320 | minus_frame = (self.lfr_m - 1) // 2 if self.reserve_waveforms is None else 0 321 | feats, feats_lengths, lfr_splice_frame_idxs = self.lfr_cmvn( 322 | feats, feats_lengths, is_final 323 | ) 324 | if self.lfr_m == 1: 325 | self.reserve_waveforms = None 326 | else: 327 | reserve_frame_idx = lfr_splice_frame_idxs[0] - minus_frame 328 | # print('reserve_frame_idx: ' + str(reserve_frame_idx)) 329 | # print('frame_frame: ' + str(frame_from_waveforms)) 330 | self.reserve_waveforms = self.waveforms[ 331 | :, 332 | reserve_frame_idx 333 | * self.frame_shift_sample_length : frame_from_waveforms 334 | * self.frame_shift_sample_length, 335 | ] 336 | sample_length = ( 337 | frame_from_waveforms - 1 338 | ) * self.frame_shift_sample_length + self.frame_sample_length 339 | self.waveforms = self.waveforms[:, :sample_length] 340 | else: 341 | # update self.reserve_waveforms and self.lfr_splice_cache 342 | self.reserve_waveforms = self.waveforms[ 343 | :, : -(self.frame_sample_length - self.frame_shift_sample_length) 344 | ] 345 | for i in range(batch_size): 346 | self.lfr_splice_cache[i] = np.concatenate( 347 | (self.lfr_splice_cache[i], feats[i]), axis=0 348 | ) 349 | return np.empty(0, dtype=np.float32), feats_lengths 350 | else: 351 | if is_final: 352 | self.waveforms = ( 353 | waveforms if self.reserve_waveforms is None else self.reserve_waveforms 354 | ) 355 | feats = np.stack(self.lfr_splice_cache) 356 | feats_lengths = np.zeros(batch_size, dtype=np.int32) + feats.shape[1] 357 | feats, feats_lengths, _ = self.lfr_cmvn(feats, feats_lengths, is_final) 358 | if is_final: 359 | self.cache_reset() 360 | return feats, feats_lengths 361 | 362 | def get_waveforms(self): 363 | return self.waveforms 364 | 365 | def cache_reset(self): 366 | self.fbank_fn = knf.OnlineFbank(self.opts) 367 | self.reserve_waveforms = None 368 | self.input_cache = None 369 | self.lfr_splice_cache = [] 370 | 371 | 372 | def load_bytes(input): 373 | middle_data = np.frombuffer(input, dtype=np.int16) 374 | 
middle_data = np.asarray(middle_data) 375 | if middle_data.dtype.kind not in "iu": 376 | raise TypeError("'middle_data' must be an array of integers") 377 | dtype = np.dtype("float32") 378 | if dtype.kind != "f": 379 | raise TypeError("'dtype' must be a floating point type") 380 | 381 | i = np.iinfo(middle_data.dtype) 382 | abs_max = 2 ** (i.bits - 1) 383 | offset = i.min + abs_max 384 | array = np.frombuffer((middle_data.astype(dtype) - offset) / abs_max, dtype=np.float32) 385 | return array 386 | 387 | 388 | class SinusoidalPositionEncoderOnline: 389 | """Streaming Positional encoding.""" 390 | 391 | def encode(self, positions: np.ndarray = None, depth: int = None, dtype: np.dtype = np.float32): 392 | batch_size = positions.shape[0] 393 | positions = positions.astype(dtype) 394 | log_timescale_increment = np.log(np.array([10000], dtype=dtype)) / (depth / 2 - 1) 395 | inv_timescales = np.exp(np.arange(depth / 2).astype(dtype) * (-log_timescale_increment)) 396 | inv_timescales = np.reshape(inv_timescales, [batch_size, -1]) 397 | scaled_time = np.reshape(positions, [1, -1, 1]) * np.reshape(inv_timescales, [1, 1, -1]) 398 | encoding = np.concatenate((np.sin(scaled_time), np.cos(scaled_time)), axis=2) 399 | return encoding.astype(dtype) 400 | 401 | def forward(self, x, start_idx=0): 402 | batch_size, timesteps, input_dim = x.shape 403 | positions = np.arange(1, timesteps + 1 + start_idx)[None, :] 404 | position_encoding = self.encode(positions, input_dim, x.dtype) 405 | 406 | return x + position_encoding[:, start_idx : start_idx + timesteps] 407 | 408 | 409 | def test(): 410 | path = "/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/example/asr_example.wav" 411 | import librosa 412 | 413 | cmvn_file = "/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/am.mvn" 414 | config_file = "/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/config.yaml" 415 | from funasr.runtime.python.onnxruntime.rapid_paraformer.utils.utils import read_yaml 416 | 417 | config = read_yaml(config_file) 418 | waveform, _ = librosa.load(path, sr=None) 419 | frontend = WavFrontend( 420 | cmvn_file=cmvn_file, 421 | **config["frontend_conf"], 422 | ) 423 | speech, _ = frontend.fbank_online(waveform) # 1d, (sample,), numpy 424 | feat, feat_len = frontend.lfr_cmvn( 425 | speech 426 | ) # 2d, (frame, 450), np.float32 -> torch, torch.from_numpy(), dtype, (1, frame, 450) 427 | 428 | frontend.reset_status() # clear cache 429 | return feat, feat_len 430 | 431 | 432 | if __name__ == "__main__": 433 | test() 434 | -------------------------------------------------------------------------------- /utils/infer_utils.py: -------------------------------------------------------------------------------- 1 | # -*- encoding: utf-8 -*- 2 | 3 | import functools 4 | import logging 5 | from pathlib import Path 6 | from typing import Any, Dict, Iterable, List, NamedTuple, Set, Tuple, Union 7 | 8 | import re 9 | import numpy as np 10 | import yaml 11 | 12 | try: 13 | from onnxruntime import ( 14 | GraphOptimizationLevel, 15 | InferenceSession, 16 | SessionOptions, 17 | get_available_providers, 18 | get_device, 19 | ) 20 | except: 21 | print("please pip3 install onnxruntime") 22 | import jieba 23 | import warnings 24 | 25 | root_dir = Path(__file__).resolve().parent 26 | 27 | logger_initialized = {} 28 | 29 | 30 | def pad_list(xs, pad_value, max_len=None): 31 | n_batch = len(xs) 32 | if max_len is None: 33 | max_len = 
max(x.size(0) for x in xs) 34 | # pad = xs[0].new(n_batch, max_len, *xs[0].size()[1:]).fill_(pad_value) 35 | # numpy format 36 | pad = (np.zeros((n_batch, max_len)) + pad_value).astype(np.int32) 37 | for i in range(n_batch): 38 | pad[i, : xs[i].shape[0]] = xs[i] 39 | 40 | return pad 41 | 42 | 43 | """ 44 | def make_pad_mask(lengths, xs=None, length_dim=-1, maxlen=None): 45 | if length_dim == 0: 46 | raise ValueError("length_dim cannot be 0: {}".format(length_dim)) 47 | 48 | if not isinstance(lengths, list): 49 | lengths = lengths.tolist() 50 | bs = int(len(lengths)) 51 | if maxlen is None: 52 | if xs is None: 53 | maxlen = int(max(lengths)) 54 | else: 55 | maxlen = xs.size(length_dim) 56 | else: 57 | assert xs is None 58 | assert maxlen >= int(max(lengths)) 59 | 60 | seq_range = torch.arange(0, maxlen, dtype=torch.int64) 61 | seq_range_expand = seq_range.unsqueeze(0).expand(bs, maxlen) 62 | seq_length_expand = seq_range_expand.new(lengths).unsqueeze(-1) 63 | mask = seq_range_expand >= seq_length_expand 64 | 65 | if xs is not None: 66 | assert xs.size(0) == bs, (xs.size(0), bs) 67 | 68 | if length_dim < 0: 69 | length_dim = xs.dim() + length_dim 70 | # ind = (:, None, ..., None, :, , None, ..., None) 71 | ind = tuple( 72 | slice(None) if i in (0, length_dim) else None for i in range(xs.dim()) 73 | ) 74 | mask = mask[ind].expand_as(xs).to(xs.device) 75 | return mask 76 | """ 77 | 78 | 79 | class TokenIDConverter: 80 | def __init__( 81 | self, 82 | token_list: Union[List, str], 83 | ): 84 | 85 | self.token_list = token_list 86 | self.unk_symbol = token_list[-1] 87 | self.token2id = {v: i for i, v in enumerate(self.token_list)} 88 | self.unk_id = self.token2id[self.unk_symbol] 89 | 90 | def get_num_vocabulary_size(self) -> int: 91 | return len(self.token_list) 92 | 93 | def ids2tokens(self, integers: Union[np.ndarray, Iterable[int]]) -> List[str]: 94 | if isinstance(integers, np.ndarray) and integers.ndim != 1: 95 | raise TokenIDConverterError(f"Must be 1 dim ndarray, but got {integers.ndim}") 96 | return [self.token_list[i] for i in integers] 97 | 98 | def tokens2ids(self, tokens: Iterable[str]) -> List[int]: 99 | 100 | return [self.token2id.get(i, self.unk_id) for i in tokens] 101 | 102 | 103 | class CharTokenizer: 104 | def __init__( 105 | self, 106 | symbol_value: Union[Path, str, Iterable[str]] = None, 107 | space_symbol: str = "", 108 | remove_non_linguistic_symbols: bool = False, 109 | ): 110 | 111 | self.space_symbol = space_symbol 112 | self.non_linguistic_symbols = self.load_symbols(symbol_value) 113 | self.remove_non_linguistic_symbols = remove_non_linguistic_symbols 114 | 115 | @staticmethod 116 | def load_symbols(value: Union[Path, str, Iterable[str]] = None) -> Set: 117 | if value is None: 118 | return set() 119 | 120 | if isinstance(value, Iterable[str]): 121 | return set(value) 122 | 123 | file_path = Path(value) 124 | if not file_path.exists(): 125 | logging.warning("%s doesn't exist.", file_path) 126 | return set() 127 | 128 | with file_path.open("r", encoding="utf-8") as f: 129 | return set(line.rstrip() for line in f) 130 | 131 | def text2tokens(self, line: Union[str, list]) -> List[str]: 132 | tokens = [] 133 | while len(line) != 0: 134 | for w in self.non_linguistic_symbols: 135 | if line.startswith(w): 136 | if not self.remove_non_linguistic_symbols: 137 | tokens.append(line[: len(w)]) 138 | line = line[len(w) :] 139 | break 140 | else: 141 | t = line[0] 142 | if t == " ": 143 | t = "" 144 | tokens.append(t) 145 | line = line[1:] 146 | return tokens 147 | 148 | def 
tokens2text(self, tokens: Iterable[str]) -> str: 149 | tokens = [t if t != self.space_symbol else " " for t in tokens] 150 | return "".join(tokens) 151 | 152 | def __repr__(self): 153 | return ( 154 | f"{self.__class__.__name__}(" 155 | f'space_symbol="{self.space_symbol}"' 156 | f'non_linguistic_symbols="{self.non_linguistic_symbols}"' 157 | f")" 158 | ) 159 | 160 | 161 | class Hypothesis(NamedTuple): 162 | """Hypothesis data type.""" 163 | 164 | yseq: np.ndarray 165 | score: Union[float, np.ndarray] = 0 166 | scores: Dict[str, Union[float, np.ndarray]] = dict() 167 | states: Dict[str, Any] = dict() 168 | 169 | def asdict(self) -> dict: 170 | """Convert data to JSON-friendly dict.""" 171 | return self._replace( 172 | yseq=self.yseq.tolist(), 173 | score=float(self.score), 174 | scores={k: float(v) for k, v in self.scores.items()}, 175 | )._asdict() 176 | 177 | 178 | class TokenIDConverterError(Exception): 179 | pass 180 | 181 | 182 | class ONNXRuntimeError(Exception): 183 | pass 184 | 185 | 186 | class OrtInferSession: 187 | def __init__(self, model_file, device_id=-1, intra_op_num_threads=4): 188 | device_id = str(device_id) 189 | sess_opt = SessionOptions() 190 | sess_opt.intra_op_num_threads = intra_op_num_threads 191 | sess_opt.log_severity_level = 4 192 | sess_opt.enable_cpu_mem_arena = False 193 | sess_opt.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL 194 | 195 | cuda_ep = "CUDAExecutionProvider" 196 | cuda_provider_options = { 197 | "device_id": device_id, 198 | "arena_extend_strategy": "kNextPowerOfTwo", 199 | "cudnn_conv_algo_search": "EXHAUSTIVE", 200 | "do_copy_in_default_stream": "true", 201 | } 202 | cpu_ep = "CPUExecutionProvider" 203 | cpu_provider_options = { 204 | "arena_extend_strategy": "kSameAsRequested", 205 | } 206 | 207 | EP_list = [] 208 | if device_id != "-1" and get_device() == "GPU" and cuda_ep in get_available_providers(): 209 | EP_list = [(cuda_ep, cuda_provider_options)] 210 | EP_list.append((cpu_ep, cpu_provider_options)) 211 | 212 | self._verify_model(model_file) 213 | self.session = InferenceSession(model_file, sess_options=sess_opt, providers=EP_list) 214 | 215 | if device_id != "-1" and cuda_ep not in self.session.get_providers(): 216 | warnings.warn( 217 | f"{cuda_ep} is not avaiable for current env, the inference part is automatically shifted to be executed under {cpu_ep}.\n" 218 | "Please ensure the installed onnxruntime-gpu version matches your cuda and cudnn version, " 219 | "you can check their relations from the offical web site: " 220 | "https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html", 221 | RuntimeWarning, 222 | ) 223 | 224 | def __call__(self, input_content: List[Union[np.ndarray, np.ndarray]]) -> np.ndarray: 225 | input_dict = dict(zip(self.get_input_names(), input_content)) 226 | try: 227 | return self.session.run(self.get_output_names(), input_dict) 228 | except Exception as e: 229 | raise ONNXRuntimeError("ONNXRuntime inferece failed.") from e 230 | 231 | def get_input_names( 232 | self, 233 | ): 234 | return [v.name for v in self.session.get_inputs()] 235 | 236 | def get_output_names( 237 | self, 238 | ): 239 | return [v.name for v in self.session.get_outputs()] 240 | 241 | def get_character_list(self, key: str = "character"): 242 | return self.meta_dict[key].splitlines() 243 | 244 | def have_key(self, key: str = "character") -> bool: 245 | self.meta_dict = self.session.get_modelmeta().custom_metadata_map 246 | if key in self.meta_dict.keys(): 247 | return True 248 | return False 249 | 250 | 
@staticmethod 251 | def _verify_model(model_path): 252 | model_path = Path(model_path) 253 | if not model_path.exists(): 254 | raise FileNotFoundError(f"{model_path} does not exists.") 255 | if not model_path.is_file(): 256 | raise FileExistsError(f"{model_path} is not a file.") 257 | 258 | 259 | def split_to_mini_sentence(words: list, word_limit: int = 20): 260 | assert word_limit > 1 261 | if len(words) <= word_limit: 262 | return [words] 263 | sentences = [] 264 | length = len(words) 265 | sentence_len = length // word_limit 266 | for i in range(sentence_len): 267 | sentences.append(words[i * word_limit : (i + 1) * word_limit]) 268 | if length % word_limit > 0: 269 | sentences.append(words[sentence_len * word_limit :]) 270 | return sentences 271 | 272 | 273 | def code_mix_split_words(text: str): 274 | words = [] 275 | segs = text.split() 276 | for seg in segs: 277 | # There is no space in seg. 278 | current_word = "" 279 | for c in seg: 280 | if len(c.encode()) == 1: 281 | # This is an ASCII char. 282 | current_word += c 283 | else: 284 | # This is a Chinese char. 285 | if len(current_word) > 0: 286 | words.append(current_word) 287 | current_word = "" 288 | words.append(c) 289 | if len(current_word) > 0: 290 | words.append(current_word) 291 | return words 292 | 293 | 294 | def isEnglish(text: str): 295 | if re.search("^[a-zA-Z']+$", text): 296 | return True 297 | else: 298 | return False 299 | 300 | 301 | def join_chinese_and_english(input_list): 302 | line = "" 303 | for token in input_list: 304 | if isEnglish(token): 305 | line = line + " " + token 306 | else: 307 | line = line + token 308 | 309 | line = line.strip() 310 | return line 311 | 312 | 313 | def code_mix_split_words_jieba(seg_dict_file: str): 314 | jieba.load_userdict(seg_dict_file) 315 | 316 | def _fn(text: str): 317 | input_list = text.split() 318 | token_list_all = [] 319 | langauge_list = [] 320 | token_list_tmp = [] 321 | language_flag = None 322 | for token in input_list: 323 | if isEnglish(token) and language_flag == "Chinese": 324 | token_list_all.append(token_list_tmp) 325 | langauge_list.append("Chinese") 326 | token_list_tmp = [] 327 | elif not isEnglish(token) and language_flag == "English": 328 | token_list_all.append(token_list_tmp) 329 | langauge_list.append("English") 330 | token_list_tmp = [] 331 | 332 | token_list_tmp.append(token) 333 | 334 | if isEnglish(token): 335 | language_flag = "English" 336 | else: 337 | language_flag = "Chinese" 338 | 339 | if token_list_tmp: 340 | token_list_all.append(token_list_tmp) 341 | langauge_list.append(language_flag) 342 | 343 | result_list = [] 344 | for token_list_tmp, language_flag in zip(token_list_all, langauge_list): 345 | if language_flag == "English": 346 | result_list.extend(token_list_tmp) 347 | else: 348 | seg_list = jieba.cut(join_chinese_and_english(token_list_tmp), HMM=False) 349 | result_list.extend(seg_list) 350 | 351 | return result_list 352 | 353 | return _fn 354 | 355 | 356 | def read_yaml(yaml_path: Union[str, Path]) -> Dict: 357 | if not Path(yaml_path).exists(): 358 | raise FileExistsError(f"The {yaml_path} does not exist.") 359 | 360 | with open(str(yaml_path), "rb") as f: 361 | data = yaml.load(f, Loader=yaml.Loader) 362 | return data 363 | 364 | 365 | @functools.lru_cache() 366 | def get_logger(name="funasr_onnx"): 367 | """Initialize and get a logger by name. 368 | If the logger has not been initialized, this method will initialize the 369 | logger by adding one or two handlers, otherwise the initialized logger will 370 | be directly returned. 
During initialization, a StreamHandler will always be 371 | added. 372 | Args: 373 | name (str): Logger name. 374 | Returns: 375 | logging.Logger: The expected logger. 376 | """ 377 | logger = logging.getLogger(name) 378 | if name in logger_initialized: 379 | return logger 380 | 381 | for logger_name in logger_initialized: 382 | if name.startswith(logger_name): 383 | return logger 384 | 385 | formatter = logging.Formatter( 386 | "[%(asctime)s] %(name)s %(levelname)s: %(message)s", datefmt="%Y/%m/%d %H:%M:%S" 387 | ) 388 | 389 | sh = logging.StreamHandler() 390 | sh.setFormatter(formatter) 391 | logger.addHandler(sh) 392 | logger_initialized[name] = True 393 | logger.propagate = False 394 | logging.basicConfig(level=logging.ERROR) 395 | return logger 396 | -------------------------------------------------------------------------------- /utils/model_bin.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- encoding: utf-8 -*- 3 | # Copyright FunASR (https://github.com/FunAudioLLM/SenseVoice). All Rights Reserved. 4 | # MIT License (https://opensource.org/licenses/MIT) 5 | 6 | import os.path 7 | from pathlib import Path 8 | from typing import List, Union, Tuple 9 | import torch 10 | import librosa 11 | import numpy as np 12 | 13 | from utils.infer_utils import ( 14 | CharTokenizer, 15 | Hypothesis, 16 | ONNXRuntimeError, 17 | OrtInferSession, 18 | TokenIDConverter, 19 | get_logger, 20 | read_yaml, 21 | ) 22 | from utils.frontend import WavFrontend 23 | from utils.infer_utils import pad_list 24 | 25 | logging = get_logger() 26 | 27 | 28 | class SenseVoiceSmallONNX: 29 | """ 30 | Author: Speech Lab of DAMO Academy, Alibaba Group 31 | Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition 32 | https://arxiv.org/abs/2206.08317 33 | """ 34 | 35 | def __init__( 36 | self, 37 | model_dir: Union[str, Path] = None, 38 | batch_size: int = 1, 39 | device_id: Union[str, int] = "-1", 40 | plot_timestamp_to: str = "", 41 | quantize: bool = False, 42 | intra_op_num_threads: int = 4, 43 | cache_dir: str = None, 44 | **kwargs, 45 | ): 46 | if quantize: 47 | model_file = os.path.join(model_dir, "model_quant.onnx") 48 | else: 49 | model_file = os.path.join(model_dir, "model.onnx") 50 | 51 | config_file = os.path.join(model_dir, "config.yaml") 52 | cmvn_file = os.path.join(model_dir, "am.mvn") 53 | config = read_yaml(config_file) 54 | # token_list = os.path.join(model_dir, "tokens.json") 55 | # with open(token_list, "r", encoding="utf-8") as f: 56 | # token_list = json.load(f) 57 | 58 | # self.converter = TokenIDConverter(token_list) 59 | self.tokenizer = CharTokenizer() 60 | config["frontend_conf"]['cmvn_file'] = cmvn_file 61 | self.frontend = WavFrontend(**config["frontend_conf"]) 62 | self.ort_infer = OrtInferSession( 63 | model_file, device_id, intra_op_num_threads=intra_op_num_threads 64 | ) 65 | self.batch_size = batch_size 66 | self.blank_id = 0 67 | 68 | def __call__(self, 69 | wav_content: Union[str, np.ndarray, List[str]], 70 | language: List, 71 | textnorm: List, 72 | tokenizer=None, 73 | **kwargs) -> List: 74 | waveform_list = self.load_data(wav_content, self.frontend.opts.frame_opts.samp_freq) 75 | waveform_nums = len(waveform_list) 76 | asr_res = [] 77 | for beg_idx in range(0, waveform_nums, self.batch_size): 78 | end_idx = min(waveform_nums, beg_idx + self.batch_size) 79 | feats, feats_len = self.extract_feat(waveform_list[beg_idx:end_idx]) 80 | ctc_logits, encoder_out_lens = 
self.infer(feats, 81 | feats_len, 82 | np.array(language, dtype=np.int32), 83 | np.array(textnorm, dtype=np.int32) 84 | ) 85 | # back to torch.Tensor 86 | ctc_logits = torch.from_numpy(ctc_logits).float() 87 | # support batch_size=1 only currently 88 | x = ctc_logits[0, : encoder_out_lens[0].item(), :] 89 | yseq = x.argmax(dim=-1) 90 | yseq = torch.unique_consecutive(yseq, dim=-1) 91 | 92 | mask = yseq != self.blank_id 93 | token_int = yseq[mask].tolist() 94 | 95 | if tokenizer is not None: 96 | asr_res.append(tokenizer.tokens2text(token_int)) 97 | else: 98 | asr_res.append(token_int) 99 | return asr_res 100 | 101 | def load_data(self, wav_content: Union[str, np.ndarray, List[str]], fs: int = None) -> List: 102 | def load_wav(path: str) -> np.ndarray: 103 | waveform, _ = librosa.load(path, sr=fs) 104 | return waveform 105 | 106 | if isinstance(wav_content, np.ndarray): 107 | return [wav_content] 108 | 109 | if isinstance(wav_content, str): 110 | return [load_wav(wav_content)] 111 | 112 | if isinstance(wav_content, list): 113 | return [load_wav(path) for path in wav_content] 114 | 115 | raise TypeError(f"The type of {wav_content} is not in [str, np.ndarray, list]") 116 | 117 | def extract_feat(self, waveform_list: List[np.ndarray]) -> Tuple[np.ndarray, np.ndarray]: 118 | feats, feats_len = [], [] 119 | for waveform in waveform_list: 120 | speech, _ = self.frontend.fbank(waveform) 121 | feat, feat_len = self.frontend.lfr_cmvn(speech) 122 | feats.append(feat) 123 | feats_len.append(feat_len) 124 | 125 | feats = self.pad_feats(feats, np.max(feats_len)) 126 | feats_len = np.array(feats_len).astype(np.int32) 127 | return feats, feats_len 128 | 129 | @staticmethod 130 | def pad_feats(feats: List[np.ndarray], max_feat_len: int) -> np.ndarray: 131 | def pad_feat(feat: np.ndarray, cur_len: int) -> np.ndarray: 132 | pad_width = ((0, max_feat_len - cur_len), (0, 0)) 133 | return np.pad(feat, pad_width, "constant", constant_values=0) 134 | 135 | feat_res = [pad_feat(feat, feat.shape[0]) for feat in feats] 136 | feats = np.array(feat_res).astype(np.float32) 137 | return feats 138 | 139 | def infer(self, 140 | feats: np.ndarray, 141 | feats_len: np.ndarray, 142 | language: np.ndarray, 143 | textnorm: np.ndarray,) -> Tuple[np.ndarray, np.ndarray]: 144 | outputs = self.ort_infer([feats, feats_len, language, textnorm]) 145 | return outputs 146 | -------------------------------------------------------------------------------- /webui.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | import os 4 | import librosa 5 | import base64 6 | import io 7 | import gradio as gr 8 | import re 9 | 10 | import numpy as np 11 | import torch 12 | import torchaudio 13 | 14 | 15 | from funasr import AutoModel 16 | 17 | model = "iic/SenseVoiceSmall" 18 | model = AutoModel(model=model, 19 | vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch", 20 | vad_kwargs={"max_single_segment_time": 30000}, 21 | trust_remote_code=True, 22 | ) 23 | 24 | import re 25 | 26 | emo_dict = { 27 | "<|HAPPY|>": "😊", 28 | "<|SAD|>": "😔", 29 | "<|ANGRY|>": "😡", 30 | "<|NEUTRAL|>": "", 31 | "<|FEARFUL|>": "😰", 32 | "<|DISGUSTED|>": "🤢", 33 | "<|SURPRISED|>": "😮", 34 | } 35 | 36 | event_dict = { 37 | "<|BGM|>": "🎼", 38 | "<|Speech|>": "", 39 | "<|Applause|>": "👏", 40 | "<|Laughter|>": "😀", 41 | "<|Cry|>": "😭", 42 | "<|Sneeze|>": "🤧", 43 | "<|Breath|>": "", 44 | "<|Cough|>": "🤧", 45 | } 46 | 47 | emoji_dict = { 48 | "<|nospeech|><|Event_UNK|>": "❓", 49 | "<|zh|>": "", 50 | "<|en|>": "", 51 | 
"<|yue|>": "", 52 | "<|ja|>": "", 53 | "<|ko|>": "", 54 | "<|nospeech|>": "", 55 | "<|HAPPY|>": "😊", 56 | "<|SAD|>": "😔", 57 | "<|ANGRY|>": "😡", 58 | "<|NEUTRAL|>": "", 59 | "<|BGM|>": "🎼", 60 | "<|Speech|>": "", 61 | "<|Applause|>": "👏", 62 | "<|Laughter|>": "😀", 63 | "<|FEARFUL|>": "😰", 64 | "<|DISGUSTED|>": "🤢", 65 | "<|SURPRISED|>": "😮", 66 | "<|Cry|>": "😭", 67 | "<|EMO_UNKNOWN|>": "", 68 | "<|Sneeze|>": "🤧", 69 | "<|Breath|>": "", 70 | "<|Cough|>": "😷", 71 | "<|Sing|>": "", 72 | "<|Speech_Noise|>": "", 73 | "<|withitn|>": "", 74 | "<|woitn|>": "", 75 | "<|GBG|>": "", 76 | "<|Event_UNK|>": "", 77 | } 78 | 79 | lang_dict = { 80 | "<|zh|>": "<|lang|>", 81 | "<|en|>": "<|lang|>", 82 | "<|yue|>": "<|lang|>", 83 | "<|ja|>": "<|lang|>", 84 | "<|ko|>": "<|lang|>", 85 | "<|nospeech|>": "<|lang|>", 86 | } 87 | 88 | emo_set = {"😊", "😔", "😡", "😰", "🤢", "😮"} 89 | event_set = {"🎼", "👏", "😀", "😭", "🤧", "😷",} 90 | 91 | def format_str(s): 92 | for sptk in emoji_dict: 93 | s = s.replace(sptk, emoji_dict[sptk]) 94 | return s 95 | 96 | 97 | def format_str_v2(s): 98 | sptk_dict = {} 99 | for sptk in emoji_dict: 100 | sptk_dict[sptk] = s.count(sptk) 101 | s = s.replace(sptk, "") 102 | emo = "<|NEUTRAL|>" 103 | for e in emo_dict: 104 | if sptk_dict[e] > sptk_dict[emo]: 105 | emo = e 106 | for e in event_dict: 107 | if sptk_dict[e] > 0: 108 | s = event_dict[e] + s 109 | s = s + emo_dict[emo] 110 | 111 | for emoji in emo_set.union(event_set): 112 | s = s.replace(" " + emoji, emoji) 113 | s = s.replace(emoji + " ", emoji) 114 | return s.strip() 115 | 116 | def format_str_v3(s): 117 | def get_emo(s): 118 | return s[-1] if s[-1] in emo_set else None 119 | def get_event(s): 120 | return s[0] if s[0] in event_set else None 121 | 122 | s = s.replace("<|nospeech|><|Event_UNK|>", "❓") 123 | for lang in lang_dict: 124 | s = s.replace(lang, "<|lang|>") 125 | s_list = [format_str_v2(s_i).strip(" ") for s_i in s.split("<|lang|>")] 126 | new_s = " " + s_list[0] 127 | cur_ent_event = get_event(new_s) 128 | for i in range(1, len(s_list)): 129 | if len(s_list[i]) == 0: 130 | continue 131 | if get_event(s_list[i]) == cur_ent_event and get_event(s_list[i]) != None: 132 | s_list[i] = s_list[i][1:] 133 | #else: 134 | cur_ent_event = get_event(s_list[i]) 135 | if get_emo(s_list[i]) != None and get_emo(s_list[i]) == get_emo(new_s): 136 | new_s = new_s[:-1] 137 | new_s += s_list[i].strip().lstrip() 138 | new_s = new_s.replace("The.", " ") 139 | return new_s.strip() 140 | 141 | def model_inference(input_wav, language, fs=16000): 142 | # task_abbr = {"Speech Recognition": "ASR", "Rich Text Transcription": ("ASR", "AED", "SER")} 143 | language_abbr = {"auto": "auto", "zh": "zh", "en": "en", "yue": "yue", "ja": "ja", "ko": "ko", 144 | "nospeech": "nospeech"} 145 | 146 | # task = "Speech Recognition" if task is None else task 147 | language = "auto" if len(language) < 1 else language 148 | selected_language = language_abbr[language] 149 | # selected_task = task_abbr.get(task) 150 | 151 | # print(f"input_wav: {type(input_wav)}, {input_wav[1].shape}, {input_wav}") 152 | 153 | if isinstance(input_wav, tuple): 154 | fs, input_wav = input_wav 155 | input_wav = input_wav.astype(np.float32) / np.iinfo(np.int16).max 156 | if len(input_wav.shape) > 1: 157 | input_wav = input_wav.mean(-1) 158 | if fs != 16000: 159 | print(f"audio_fs: {fs}") 160 | resampler = torchaudio.transforms.Resample(fs, 16000) 161 | input_wav_t = torch.from_numpy(input_wav).to(torch.float32) 162 | input_wav = resampler(input_wav_t[None, :])[0, :].numpy() 163 | 164 | 165 | 
merge_vad = True #False if selected_task == "ASR" else True 166 | print(f"language: {language}, merge_vad: {merge_vad}") 167 | text = model.generate(input=input_wav, 168 | cache={}, 169 | language=language, 170 | use_itn=True, 171 | batch_size_s=60, merge_vad=merge_vad) 172 | 173 | print(text) 174 | text = text[0]["text"] 175 | text = format_str_v3(text) 176 | 177 | print(text) 178 | 179 | return text 180 | 181 | 182 | audio_examples = [ 183 | ["example/zh.mp3", "zh"], 184 | ["example/yue.mp3", "yue"], 185 | ["example/en.mp3", "en"], 186 | ["example/ja.mp3", "ja"], 187 | ["example/ko.mp3", "ko"], 188 | ["example/emo_1.wav", "auto"], 189 | ["example/emo_2.wav", "auto"], 190 | ["example/emo_3.wav", "auto"], 191 | #["example/emo_4.wav", "auto"], 192 | #["example/event_1.wav", "auto"], 193 | #["example/event_2.wav", "auto"], 194 | #["example/event_3.wav", "auto"], 195 | ["example/rich_1.wav", "auto"], 196 | ["example/rich_2.wav", "auto"], 197 | #["example/rich_3.wav", "auto"], 198 | ["example/longwav_1.wav", "auto"], 199 | ["example/longwav_2.wav", "auto"], 200 | ["example/longwav_3.wav", "auto"], 201 | #["example/longwav_4.wav", "auto"], 202 | ] 203 | 204 | 205 | 206 | html_content = """ 207 |
208 | <h2>Voice Understanding Model: SenseVoice-Small</h2>
209 | <p>SenseVoice-Small is an encoder-only speech foundation model designed for rapid voice understanding. It encompasses a variety of features including automatic speech recognition (ASR), spoken language identification (LID), speech emotion recognition (SER), and acoustic event detection (AED). SenseVoice-Small supports multilingual recognition for Chinese, English, Cantonese, Japanese, and Korean. Additionally, it offers exceptionally low inference latency, performing 7 times faster than Whisper-small and 17 times faster than Whisper-large.</p>
210 | <h2>Usage</h2> <p>Upload an audio file or record through the microphone, then select the language. The audio is transcribed into the corresponding text along with the associated emotions (😊 happy, 😡 angry/excited, 😔 sad) and types of sound events (😀 laughter, 🎼 music, 👏 applause, 🤧 cough&sneeze, 😭 cry). Event labels are placed at the front of the text and emotion labels at the end.</p>
211 | <p>Recommended audio input duration is below 30 seconds. For audio longer than 30 seconds, local deployment is recommended.</p>
212 | <h2>Repo</h2>
213 | <p>SenseVoice: multilingual speech understanding model</p>
214 | <p>FunASR: fundamental speech recognition toolkit</p>
215 | <p>CosyVoice: high-quality multilingual TTS model</p>
216 |
217 | """ 218 | 219 | 220 | def launch(): 221 | with gr.Blocks(theme=gr.themes.Soft()) as demo: 222 | # gr.Markdown(description) 223 | gr.HTML(html_content) 224 | with gr.Row(): 225 | with gr.Column(): 226 | audio_inputs = gr.Audio(label="Upload audio or use the microphone") 227 | 228 | with gr.Accordion("Configuration"): 229 | language_inputs = gr.Dropdown(choices=["auto", "zh", "en", "yue", "ja", "ko", "nospeech"], 230 | value="auto", 231 | label="Language") 232 | fn_button = gr.Button("Start", variant="primary") 233 | text_outputs = gr.Textbox(label="Results") 234 | gr.Examples(examples=audio_examples, inputs=[audio_inputs, language_inputs], examples_per_page=20) 235 | 236 | fn_button.click(model_inference, inputs=[audio_inputs, language_inputs], outputs=text_outputs) 237 | 238 | demo.launch() 239 | 240 | 241 | if __name__ == "__main__": 242 | # iface.launch() 243 | launch() 244 | 245 | 246 | --------------------------------------------------------------------------------
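
For quick local testing outside the Gradio UI, the sketch below distills the inference path used in webui.py into a few lines. It is a minimal sketch, not a file from the repository: the audio path example/zh.mp3 is simply one of the bundled webui examples, and the AutoModel/generate arguments are the same ones webui.py already passes.

# Minimal offline inference sketch (assumes funasr is installed and the
# SenseVoiceSmall and VAD models can be fetched from ModelScope).
from funasr import AutoModel

model = AutoModel(
    model="iic/SenseVoiceSmall",
    vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
    vad_kwargs={"max_single_segment_time": 30000},
    trust_remote_code=True,
)

res = model.generate(
    input="example/zh.mp3",   # any local audio file; this path is one of the webui examples
    cache={},
    language="auto",          # or "zh", "en", "yue", "ja", "ko", "nospeech"
    use_itn=True,
    batch_size_s=60,
    merge_vad=True,
)
# The raw output still contains <|...|> special tokens (language, emotion, event);
# webui.py strips and rewrites them with format_str_v3.
print(res[0]["text"])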