├── LICENSE ├── README.md ├── assets └── FireRedASR_model.png ├── examples ├── fireredasr ├── inference_fireredasr_aed.sh ├── inference_fireredasr_llm.sh ├── pretrained_models └── wav │ ├── BAC009S0764W0121.wav │ ├── IT0011W0001.wav │ ├── TEST_MEETING_T0000000001_S00000.wav │ ├── TEST_NET_Y0000000000_-KTKHdZ2fb8_S00000.wav │ ├── text │ └── wav.scp ├── fireredasr ├── data │ ├── asr_feat.py │ └── token_dict.py ├── models │ ├── fireredasr.py │ ├── fireredasr_aed.py │ ├── fireredasr_llm.py │ └── module │ │ ├── adapter.py │ │ ├── conformer_encoder.py │ │ └── transformer_decoder.py ├── speech2text.py ├── tokenizer │ ├── aed_tokenizer.py │ └── llm_tokenizer.py └── utils │ ├── param.py │ └── wer.py ├── pretrained_models └── README.md └── requirements.txt /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | # FireRedASR: Open-Source Industrial-Grade Automatic Speech Recognition Models 3 | 4 | 5 | 6 |
7 | 8 | [[Paper]](https://arxiv.org/pdf/2501.14350) 9 | [[Model]](https://huggingface.co/fireredteam) 10 | [[Blog]](https://fireredteam.github.io/demos/firered_asr/) 11 | 12 | FireRedASR is a family of open-source industrial-grade automatic speech recognition (ASR) models supporting Mandarin, Chinese dialects and English, achieving a new state-of-the-art (SOTA) on public Mandarin ASR benchmarks, while also offering outstanding singing lyrics recognition capability. 13 | 14 | 15 | ## 🔥 News 16 | - [2025/02/17] We release [FireRedASR-LLM-L](https://huggingface.co/fireredteam/FireRedASR-LLM-L/tree/main) model weights. 17 | - [2025/01/24] We release [technical report](https://arxiv.org/pdf/2501.14350), [blog](https://fireredteam.github.io/demos/firered_asr/), and [FireRedASR-AED-L](https://huggingface.co/fireredteam/FireRedASR-AED-L/tree/main) model weights. 18 | 19 | 20 | ## Method 21 | 22 | FireRedASR is designed to meet diverse requirements in superior performance and optimal efficiency across various applications. It comprises two variants: 23 | - FireRedASR-LLM: Designed to achieve state-of-the-art (SOTA) performance and to enable seamless end-to-end speech interaction. It adopts an Encoder-Adapter-LLM framework leveraging large language model (LLM) capabilities. 24 | - FireRedASR-AED: Designed to balance high performance and computational efficiency and to serve as an effective speech representation module in LLM-based speech models. It utilizes an Attention-based Encoder-Decoder (AED) architecture. 25 | 26 | ![Model](/assets/FireRedASR_model.png) 27 | 28 | 29 | ## Evaluation 30 | Results are reported in Character Error Rate (CER%) for Chinese and Word Error Rate (WER%) for English. 31 | 32 | ### Evaluation on Public Mandarin ASR Benchmarks 33 | | Model | #Params | aishell1 | aishell2 | ws\_net | ws\_meeting | Average-4 | 34 | |:----------------:|:-------:|:--------:|:--------:|:--------:|:-----------:|:---------:| 35 | | FireRedASR-LLM | 8.3B | 0.76 | 2.15 | 4.60 | 4.67 | 3.05 | 36 | | FireRedASR-AED | 1.1B | 0.55 | 2.52 | 4.88 | 4.76 | 3.18 | 37 | | Seed-ASR | 12B+ | 0.68 | 2.27 | 4.66 | 5.69 | 3.33 | 38 | | Qwen-Audio | 8.4B | 1.30 | 3.10 | 9.50 | 10.87 | 6.19 | 39 | | SenseVoice-L | 1.6B | 2.09 | 3.04 | 6.01 | 6.73 | 4.47 | 40 | | Whisper-Large-v3 | 1.6B | 5.14 | 4.96 | 10.48 | 18.87 | 9.86 | 41 | | Paraformer-Large | 0.2B | 1.68 | 2.85 | 6.74 | 6.97 | 4.56 | 42 | 43 | `ws` means WenetSpeech. 44 | 45 | ### Evaluation on Public Chinese Dialect and English ASR Benchmarks 46 | |Test Set | KeSpeech | LibriSpeech test-clean | LibriSpeech test-other | 47 | | :------------:| :------: | :--------------------: | :----------------------:| 48 | |FireRedASR-LLM | 3.56 | 1.73 | 3.67 | 49 | |FireRedASR-AED | 4.48 | 1.93 | 4.44 | 50 | |Previous SOTA Results | 6.70 | 1.82 | 3.50 | 51 | 52 | 53 | ## Usage 54 | Download model files from [huggingface](https://huggingface.co/fireredteam) and place them in the folder `pretrained_models`. 55 | 56 | If you want to use `FireRedASR-LLM-L`, you also need to download [Qwen2-7B-Instruct](https://huggingface.co/Qwen/Qwen2-7B-Instruct) and place it in the folder `pretrained_models`. 
Then, go to folder `FireRedASR-LLM-L` and run `$ ln -s ../Qwen2-7B-Instruct` 57 | 58 | 59 | ### Setup 60 | Create a Python environment and install dependencies 61 | ```bash 62 | $ git clone https://github.com/FireRedTeam/FireRedASR.git 63 | $ conda create --name fireredasr python=3.10 64 | $ pip install -r requirements.txt 65 | ``` 66 | 67 | Set up Linux PATH and PYTHONPATH 68 | ``` 69 | $ export PATH=$PWD/fireredasr/:$PWD/fireredasr/utils/:$PATH 70 | $ export PYTHONPATH=$PWD/:$PYTHONPATH 71 | ``` 72 | 73 | Convert audio to 16kHz 16-bit PCM format 74 | ``` 75 | ffmpeg -i input_audio -ar 16000 -ac 1 -acodec pcm_s16le -f wav output.wav 76 | ``` 77 | 78 | ### Quick Start 79 | ```bash 80 | $ cd examples 81 | $ bash inference_fireredasr_aed.sh 82 | $ bash inference_fireredasr_llm.sh 83 | ``` 84 | 85 | ### Command-line Usage 86 | ```bash 87 | $ speech2text.py --help 88 | $ speech2text.py --wav_path examples/wav/BAC009S0764W0121.wav --asr_type "aed" --model_dir pretrained_models/FireRedASR-AED-L 89 | $ speech2text.py --wav_path examples/wav/BAC009S0764W0121.wav --asr_type "llm" --model_dir pretrained_models/FireRedASR-LLM-L 90 | ``` 91 | 92 | ### Python Usage 93 | ```python 94 | from fireredasr.models.fireredasr import FireRedAsr 95 | 96 | batch_uttid = ["BAC009S0764W0121"] 97 | batch_wav_path = ["examples/wav/BAC009S0764W0121.wav"] 98 | 99 | # FireRedASR-AED 100 | model = FireRedAsr.from_pretrained("aed", "pretrained_models/FireRedASR-AED-L") 101 | results = model.transcribe( 102 | batch_uttid, 103 | batch_wav_path, 104 | { 105 | "use_gpu": 1, 106 | "beam_size": 3, 107 | "nbest": 1, 108 | "decode_max_len": 0, 109 | "softmax_smoothing": 1.25, 110 | "aed_length_penalty": 0.6, 111 | "eos_penalty": 1.0 112 | } 113 | ) 114 | print(results) 115 | 116 | 117 | # FireRedASR-LLM 118 | model = FireRedAsr.from_pretrained("llm", "pretrained_models/FireRedASR-LLM-L") 119 | results = model.transcribe( 120 | batch_uttid, 121 | batch_wav_path, 122 | { 123 | "use_gpu": 1, 124 | "beam_size": 3, 125 | "decode_max_len": 0, 126 | "decode_min_len": 0, 127 | "repetition_penalty": 3.0, 128 | "llm_length_penalty": 1.0, 129 | "temperature": 1.0 130 | } 131 | ) 132 | print(results) 133 | ``` 134 | 135 | ## Usage Tips 136 | ### Batch Beam Search 137 | - When performing batch beam search with FireRedASR-LLM, please ensure that the input lengths of the utterances are similar. If there are significant differences in utterance lengths, shorter utterances may experience repetition issues. You can either sort your dataset by length or set `batch_size` to 1 to avoid the repetition issue. 138 | 139 | ### Input Length Limitations 140 | - FireRedASR-AED supports audio input up to 60s. Input longer than 60s may cause hallucination issues, and input exceeding 200s will trigger positional encoding errors. 141 | - FireRedASR-LLM supports audio input up to 30s. The behavior for longer input is currently unknown. 
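
### Example: Sorting and Filtering Inputs
The two tips above can be combined in a small driver script. The sketch below is not part of the repository; it assumes the Python API shown in the Python Usage section and the bundled example wavs, and the `utts` table, `MAX_DUR`, and `BATCH_SIZE` values are illustrative. It sorts utterances by duration so each batch contains similar-length clips and drops clips longer than the documented limit.

```python
import wave

from fireredasr.models.fireredasr import FireRedAsr


def wav_duration(path):
    # Duration in seconds of a 16kHz 16-bit PCM wav file.
    with wave.open(path, "rb") as f:
        return f.getnframes() / float(f.getframerate())


# uttid -> wav path, e.g. parsed from examples/wav/wav.scp (illustrative values).
utts = {
    "BAC009S0764W0121": "examples/wav/BAC009S0764W0121.wav",
    "IT0011W0001": "examples/wav/IT0011W0001.wav",
}

MAX_DUR = 30.0  # documented limit for FireRedASR-LLM; use 60.0 for FireRedASR-AED
BATCH_SIZE = 2

# Sort by duration so each batch sees similar-length inputs,
# and drop clips that exceed the supported length.
durs = {u: wav_duration(p) for u, p in utts.items()}
kept = sorted((u for u in utts if durs[u] <= MAX_DUR), key=lambda u: durs[u])

model = FireRedAsr.from_pretrained("llm", "pretrained_models/FireRedASR-LLM-L")
for i in range(0, len(kept), BATCH_SIZE):
    uttids = kept[i:i + BATCH_SIZE]
    paths = [utts[u] for u in uttids]
    results = model.transcribe(uttids, paths, {"use_gpu": 1, "beam_size": 3})
    for r in results:
        print(r["uttid"], r["text"])
```

Decoding options omitted from the dictionary fall back to the defaults inside `transcribe()`; to reproduce the settings used above, pass the full option set from the Python Usage section.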
142 | 143 | 144 | ## Acknowledgements 145 | Thanks to the following open-source works: 146 | - [Qwen2-7B-Instruct](https://huggingface.co/Qwen/Qwen2-7B-Instruct) 147 | - [icefall/ASR_LLM](https://github.com/k2-fsa/icefall/tree/master/egs/speech_llm/ASR_LLM) 148 | - [WeNet](https://github.com/wenet-e2e/wenet) 149 | - [Speech-Transformer](https://github.com/kaituoxu/Speech-Transformer) 150 | 151 | 152 | ## Citation 153 | ```bibtex 154 | @article{xu2025fireredasr, 155 | title={FireRedASR: Open-Source Industrial-Grade Mandarin Speech Recognition Models from Encoder-Decoder to LLM Integration}, 156 | author={Xu, Kai-Tuo and Xie, Feng-Long and Tang, Xu and Hu, Yao}, 157 | journal={arXiv preprint arXiv:2501.14350}, 158 | year={2025} 159 | } 160 | ``` 161 | -------------------------------------------------------------------------------- /assets/FireRedASR_model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FireRedTeam/FireRedASR/1eadb81b66eca948cd492bc0aeedd786333c049d/assets/FireRedASR_model.png -------------------------------------------------------------------------------- /examples/fireredasr: -------------------------------------------------------------------------------- 1 | ../fireredasr -------------------------------------------------------------------------------- /examples/inference_fireredasr_aed.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export PATH=$PWD/fireredasr/:$PWD/fireredasr/utils/:$PATH 4 | export PYTHONPATH=$PWD/:$PYTHONPATH 5 | 6 | # model_dir includes model.pth.tar, cmvn.ark, dict.txt 7 | model_dir=$PWD/pretrained_models/FireRedASR-AED-L 8 | 9 | # Support several input format 10 | wavs="--wav_path wav/BAC009S0764W0121.wav" 11 | wavs="--wav_paths wav/BAC009S0764W0121.wav wav/IT0011W0001.wav wav/TEST_NET_Y0000000000_-KTKHdZ2fb8_S00000.wav wav/TEST_MEETING_T0000000001_S00000.wav" 12 | wavs="--wav_dir wav/" 13 | wavs="--wav_scp wav/wav.scp" 14 | 15 | out="out/aed-l-asr.txt" 16 | 17 | decode_args=" 18 | --batch_size 2 --beam_size 3 --nbest 1 19 | --decode_max_len 0 --softmax_smoothing 1.25 --aed_length_penalty 0.6 20 | --eos_penalty 1.0 21 | " 22 | 23 | mkdir -p $(dirname $out) 24 | set -x 25 | 26 | 27 | CUDA_VISIBLE_DEVICES=0 \ 28 | speech2text.py --asr_type "aed" --model_dir $model_dir $decode_args $wavs --output $out 29 | 30 | 31 | ref="wav/text" 32 | wer.py --print_sentence_wer 1 --do_tn 0 --rm_special 0 --ref $ref --hyp $out > $out.wer 2>&1 33 | tail -n8 $out.wer 34 | -------------------------------------------------------------------------------- /examples/inference_fireredasr_llm.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export PATH=$PWD/fireredasr/:$PWD/fireredasr/utils/:$PATH 4 | export PYTHONPATH=$PWD/:$PYTHONPATH 5 | 6 | # model_dir includes model.pth.tar, asr_encoder.pth.tar, cmvn.ark, Qwen2-7B-Instruct 7 | model_dir=$PWD/pretrained_models/FireRedASR-LLM-L 8 | 9 | # Support several input format 10 | wavs="--wav_path wav/BAC009S0764W0121.wav" 11 | wavs="--wav_paths wav/BAC009S0764W0121.wav wav/IT0011W0001.wav wav/TEST_NET_Y0000000000_-KTKHdZ2fb8_S00000.wav wav/TEST_MEETING_T0000000001_S00000.wav" 12 | wavs="--wav_dir wav/" 13 | wavs="--wav_scp wav/wav.scp" 14 | 15 | out="out/llm-l-asr.txt" 16 | 17 | decode_args=" 18 | --batch_size 1 --beam_size 3 --decode_max_len 0 --decode_min_len 0 19 | --repetition_penalty 3.0 --llm_length_penalty 1.0 --temperature 1.0 20 | " 21 
| 22 | mkdir -p $(dirname $out) 23 | set -x 24 | 25 | 26 | CUDA_VISIBLE_DEVICES=0 \ 27 | speech2text.py --asr_type "llm" --model_dir $model_dir $decode_args $wavs --output $out 28 | 29 | 30 | ref="wav/text" 31 | wer.py --print_sentence_wer 1 --do_tn 0 --rm_special 1 --ref $ref --hyp $out > $out.wer 2>&1 32 | tail -n8 $out.wer 33 | -------------------------------------------------------------------------------- /examples/pretrained_models: -------------------------------------------------------------------------------- 1 | ../pretrained_models -------------------------------------------------------------------------------- /examples/wav/BAC009S0764W0121.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FireRedTeam/FireRedASR/1eadb81b66eca948cd492bc0aeedd786333c049d/examples/wav/BAC009S0764W0121.wav -------------------------------------------------------------------------------- /examples/wav/IT0011W0001.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FireRedTeam/FireRedASR/1eadb81b66eca948cd492bc0aeedd786333c049d/examples/wav/IT0011W0001.wav -------------------------------------------------------------------------------- /examples/wav/TEST_MEETING_T0000000001_S00000.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FireRedTeam/FireRedASR/1eadb81b66eca948cd492bc0aeedd786333c049d/examples/wav/TEST_MEETING_T0000000001_S00000.wav -------------------------------------------------------------------------------- /examples/wav/TEST_NET_Y0000000000_-KTKHdZ2fb8_S00000.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FireRedTeam/FireRedASR/1eadb81b66eca948cd492bc0aeedd786333c049d/examples/wav/TEST_NET_Y0000000000_-KTKHdZ2fb8_S00000.wav -------------------------------------------------------------------------------- /examples/wav/text: -------------------------------------------------------------------------------- 1 | BAC009S0764W0121 甚至 出现 交易 几乎 停滞 的 情况 2 | IT0011W0001 换一首歌 3 | TEST_NET_Y0000000000_-KTKHdZ2fb8_S00000 我有的时候说不清楚你们知道吗 4 | TEST_MEETING_T0000000001_S00000 好首先说一下刚才这个经理说完的这个销售问题咱再说一下咱们的商场问题首先咱们商场上半年业这个先各部门儿汇报一下就是业绩 5 | -------------------------------------------------------------------------------- /examples/wav/wav.scp: -------------------------------------------------------------------------------- 1 | BAC009S0764W0121 wav/BAC009S0764W0121.wav 2 | IT0011W0001 wav/IT0011W0001.wav 3 | TEST_NET_Y0000000000_-KTKHdZ2fb8_S00000 wav/TEST_NET_Y0000000000_-KTKHdZ2fb8_S00000.wav 4 | TEST_MEETING_T0000000001_S00000 wav/TEST_MEETING_T0000000001_S00000.wav 5 | -------------------------------------------------------------------------------- /fireredasr/data/asr_feat.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | 4 | import kaldiio 5 | import kaldi_native_fbank as knf 6 | import numpy as np 7 | import torch 8 | 9 | 10 | class ASRFeatExtractor: 11 | def __init__(self, kaldi_cmvn_file): 12 | self.cmvn = CMVN(kaldi_cmvn_file) if kaldi_cmvn_file != "" else None 13 | self.fbank = KaldifeatFbank(num_mel_bins=80, frame_length=25, 14 | frame_shift=10, dither=0.0) 15 | 16 | def __call__(self, wav_paths): 17 | feats = [] 18 | durs = [] 19 | for wav_path in wav_paths: 20 | sample_rate, wav_np = kaldiio.load_mat(wav_path) 21 | dur = wav_np.shape[0] / sample_rate 22 | 
fbank = self.fbank((sample_rate, wav_np)) 23 | if self.cmvn is not None: 24 | fbank = self.cmvn(fbank) 25 | fbank = torch.from_numpy(fbank).float() 26 | feats.append(fbank) 27 | durs.append(dur) 28 | lengths = torch.tensor([feat.size(0) for feat in feats]).long() 29 | feats_pad = self.pad_feat(feats, 0.0) 30 | return feats_pad, lengths, durs 31 | 32 | def pad_feat(self, xs, pad_value): 33 | # type: (List[Tensor], int) -> Tensor 34 | n_batch = len(xs) 35 | max_len = max([xs[i].size(0) for i in range(n_batch)]) 36 | pad = torch.ones(n_batch, max_len, *xs[0].size()[1:]).to(xs[0].device).to(xs[0].dtype).fill_(pad_value) 37 | for i in range(n_batch): 38 | pad[i, :xs[i].size(0)] = xs[i] 39 | return pad 40 | 41 | 42 | 43 | 44 | class CMVN: 45 | def __init__(self, kaldi_cmvn_file): 46 | self.dim, self.means, self.inverse_std_variences = \ 47 | self.read_kaldi_cmvn(kaldi_cmvn_file) 48 | 49 | def __call__(self, x, is_train=False): 50 | assert x.shape[-1] == self.dim, "CMVN dim mismatch" 51 | out = x - self.means 52 | out = out * self.inverse_std_variences 53 | return out 54 | 55 | def read_kaldi_cmvn(self, kaldi_cmvn_file): 56 | assert os.path.exists(kaldi_cmvn_file) 57 | stats = kaldiio.load_mat(kaldi_cmvn_file) 58 | assert stats.shape[0] == 2 59 | dim = stats.shape[-1] - 1 60 | count = stats[0, dim] 61 | assert count >= 1 62 | floor = 1e-20 63 | means = [] 64 | inverse_std_variences = [] 65 | for d in range(dim): 66 | mean = stats[0, d] / count 67 | means.append(mean.item()) 68 | varience = (stats[1, d] / count) - mean*mean 69 | if varience < floor: 70 | varience = floor 71 | istd = 1.0 / math.sqrt(varience) 72 | inverse_std_variences.append(istd) 73 | return dim, np.array(means), np.array(inverse_std_variences) 74 | 75 | 76 | 77 | class KaldifeatFbank: 78 | def __init__(self, num_mel_bins=80, frame_length=25, frame_shift=10, 79 | dither=1.0): 80 | self.dither = dither 81 | opts = knf.FbankOptions() 82 | opts.frame_opts.dither = dither 83 | opts.mel_opts.num_bins = num_mel_bins 84 | opts.frame_opts.snip_edges = True 85 | opts.mel_opts.debug_mel = False 86 | self.opts = opts 87 | 88 | def __call__(self, wav, is_train=False): 89 | if type(wav) is str: 90 | sample_rate, wav_np = kaldiio.load_mat(wav) 91 | elif type(wav) in [tuple, list] and len(wav) == 2: 92 | sample_rate, wav_np = wav 93 | assert len(wav_np.shape) == 1 94 | 95 | dither = self.dither if is_train else 0.0 96 | self.opts.frame_opts.dither = dither 97 | fbank = knf.OnlineFbank(self.opts) 98 | 99 | fbank.accept_waveform(sample_rate, wav_np.tolist()) 100 | feat = [] 101 | for i in range(fbank.num_frames_ready): 102 | feat.append(fbank.get_frame(i)) 103 | if len(feat) == 0: 104 | print("Check data, len(feat) == 0", wav, flush=True) 105 | return np.zeros((0, self.opts.mel_opts.num_bins)) 106 | feat = np.vstack(feat) 107 | return feat 108 | -------------------------------------------------------------------------------- /fireredasr/data/token_dict.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | 4 | class TokenDict: 5 | def __init__(self, dict_path, unk=""): 6 | assert dict_path != "" 7 | self.id2word, self.word2id = self.read_dict(dict_path) 8 | self.unk = unk 9 | assert unk == "" or unk in self.word2id 10 | self.unkid = self.word2id[unk] if unk else -1 11 | 12 | def get(self, key, default): 13 | if type(default) == str: 14 | default = self.word2id[default] 15 | return self.word2id.get(key, default) 16 | 17 | def __getitem__(self, key): 18 | if type(key) == str: 19 | if 
self.unk: 20 | return self.word2id.get(key, self.word2id[self.unk]) 21 | else: 22 | return self.word2id[key] 23 | elif type(key) == int: 24 | return self.id2word[key] 25 | else: 26 | raise TypeError("Key should be str or int") 27 | 28 | def __len__(self): 29 | return len(self.id2word) 30 | 31 | def __contains__(self, query): 32 | if type(query) == str: 33 | return query in self.word2id 34 | elif type(query) == int: 35 | return query in self.id2word 36 | else: 37 | raise TypeError("query should be str or int") 38 | 39 | def read_dict(self, dict_path): 40 | id2word, word2id = [], {} 41 | with open(dict_path, encoding='utf8') as f: 42 | for i, line in enumerate(f): 43 | tokens = line.strip().split() 44 | if len(tokens) >= 2: 45 | word, index = tokens[0], int(tokens[1]) 46 | elif len(tokens) == 1: 47 | word, index = tokens[0], i 48 | else: # empty line or space 49 | logging.info(f"Find empty line or space '{line.strip()}' in {dict_path}:L{i}, set to ' '") 50 | word, index = " ", i 51 | assert len(id2word) == index 52 | assert len(word2id) == index 53 | if word == "": 54 | logging.info(f"NOTE: Find in {dict_path}:L{i} and convert it to ' '") 55 | word = " " 56 | word2id[word] = index 57 | id2word.append(word) 58 | assert len(id2word) == len(word2id) 59 | return id2word, word2id 60 | -------------------------------------------------------------------------------- /fireredasr/models/fireredasr.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | 4 | import torch 5 | 6 | from fireredasr.data.asr_feat import ASRFeatExtractor 7 | from fireredasr.models.fireredasr_aed import FireRedAsrAed 8 | from fireredasr.models.fireredasr_llm import FireRedAsrLlm 9 | from fireredasr.tokenizer.aed_tokenizer import ChineseCharEnglishSpmTokenizer 10 | from fireredasr.tokenizer.llm_tokenizer import LlmTokenizerWrapper 11 | 12 | 13 | class FireRedAsr: 14 | @classmethod 15 | def from_pretrained(cls, asr_type, model_dir): 16 | assert asr_type in ["aed", "llm"] 17 | 18 | cmvn_path = os.path.join(model_dir, "cmvn.ark") 19 | feat_extractor = ASRFeatExtractor(cmvn_path) 20 | 21 | if asr_type == "aed": 22 | model_path = os.path.join(model_dir, "model.pth.tar") 23 | dict_path =os.path.join(model_dir, "dict.txt") 24 | spm_model = os.path.join(model_dir, "train_bpe1000.model") 25 | model = load_fireredasr_aed_model(model_path) 26 | tokenizer = ChineseCharEnglishSpmTokenizer(dict_path, spm_model) 27 | elif asr_type == "llm": 28 | model_path = os.path.join(model_dir, "model.pth.tar") 29 | encoder_path = os.path.join(model_dir, "asr_encoder.pth.tar") 30 | llm_dir = os.path.join(model_dir, "Qwen2-7B-Instruct") 31 | model, tokenizer = load_firered_llm_model_and_tokenizer( 32 | model_path, encoder_path, llm_dir) 33 | model.eval() 34 | return cls(asr_type, feat_extractor, model, tokenizer) 35 | 36 | def __init__(self, asr_type, feat_extractor, model, tokenizer): 37 | self.asr_type = asr_type 38 | self.feat_extractor = feat_extractor 39 | self.model = model 40 | self.tokenizer = tokenizer 41 | 42 | @torch.no_grad() 43 | def transcribe(self, batch_uttid, batch_wav_path, args={}): 44 | feats, lengths, durs = self.feat_extractor(batch_wav_path) 45 | total_dur = sum(durs) 46 | if args.get("use_gpu", False): 47 | feats, lengths = feats.cuda(), lengths.cuda() 48 | self.model.cuda() 49 | else: 50 | self.model.cpu() 51 | 52 | if self.asr_type == "aed": 53 | start_time = time.time() 54 | 55 | hyps = self.model.transcribe( 56 | feats, lengths, 57 | args.get("beam_size", 1), 58 | 
args.get("nbest", 1), 59 | args.get("decode_max_len", 0), 60 | args.get("softmax_smoothing", 1.0), 61 | args.get("aed_length_penalty", 0.0), 62 | args.get("eos_penalty", 1.0) 63 | ) 64 | 65 | elapsed = time.time() - start_time 66 | rtf= elapsed / total_dur if total_dur > 0 else 0 67 | 68 | results = [] 69 | for uttid, wav, hyp in zip(batch_uttid, batch_wav_path, hyps): 70 | hyp = hyp[0] # only return 1-best 71 | hyp_ids = [int(id) for id in hyp["yseq"].cpu()] 72 | text = self.tokenizer.detokenize(hyp_ids) 73 | results.append({"uttid": uttid, "text": text, "wav": wav, 74 | "rtf": f"{rtf:.4f}"}) 75 | return results 76 | 77 | elif self.asr_type == "llm": 78 | input_ids, attention_mask, _, _ = \ 79 | LlmTokenizerWrapper.preprocess_texts( 80 | origin_texts=[""]*feats.size(0), tokenizer=self.tokenizer, 81 | max_len=128, decode=True) 82 | if args.get("use_gpu", False): 83 | input_ids = input_ids.cuda() 84 | attention_mask = attention_mask.cuda() 85 | start_time = time.time() 86 | 87 | generated_ids = self.model.transcribe( 88 | feats, lengths, input_ids, attention_mask, 89 | args.get("beam_size", 1), 90 | args.get("decode_max_len", 0), 91 | args.get("decode_min_len", 0), 92 | args.get("repetition_penalty", 1.0), 93 | args.get("llm_length_penalty", 0.0), 94 | args.get("temperature", 1.0) 95 | ) 96 | 97 | elapsed = time.time() - start_time 98 | rtf= elapsed / total_dur if total_dur > 0 else 0 99 | texts = self.tokenizer.batch_decode(generated_ids, 100 | skip_special_tokens=True) 101 | results = [] 102 | for uttid, wav, text in zip(batch_uttid, batch_wav_path, texts): 103 | results.append({"uttid": uttid, "text": text, "wav": wav, 104 | "rtf": f"{rtf:.4f}"}) 105 | return results 106 | 107 | 108 | 109 | def load_fireredasr_aed_model(model_path): 110 | package = torch.load(model_path, map_location=lambda storage, loc: storage) 111 | print("model args:", package["args"]) 112 | model = FireRedAsrAed.from_args(package["args"]) 113 | model.load_state_dict(package["model_state_dict"], strict=True) 114 | return model 115 | 116 | 117 | def load_firered_llm_model_and_tokenizer(model_path, encoder_path, llm_dir): 118 | package = torch.load(model_path, map_location=lambda storage, loc: storage) 119 | package["args"].encoder_path = encoder_path 120 | package["args"].llm_dir = llm_dir 121 | print("model args:", package["args"]) 122 | model = FireRedAsrLlm.from_args(package["args"]) 123 | model.load_state_dict(package["model_state_dict"], strict=False) 124 | tokenizer = LlmTokenizerWrapper.build_llm_tokenizer(llm_dir) 125 | return model, tokenizer 126 | -------------------------------------------------------------------------------- /fireredasr/models/fireredasr_aed.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from fireredasr.models.module.conformer_encoder import ConformerEncoder 4 | from fireredasr.models.module.transformer_decoder import TransformerDecoder 5 | 6 | 7 | class FireRedAsrAed(torch.nn.Module): 8 | @classmethod 9 | def from_args(cls, args): 10 | return cls(args) 11 | 12 | def __init__(self, args): 13 | super().__init__() 14 | self.sos_id = args.sos_id 15 | self.eos_id = args.eos_id 16 | 17 | self.encoder = ConformerEncoder( 18 | args.idim, args.n_layers_enc, args.n_head, args.d_model, 19 | args.residual_dropout, args.dropout_rate, 20 | args.kernel_size, args.pe_maxlen) 21 | 22 | self.decoder = TransformerDecoder( 23 | args.sos_id, args.eos_id, args.pad_id, args.odim, 24 | args.n_layers_dec, args.n_head, args.d_model, 25 | 
args.residual_dropout, args.pe_maxlen) 26 | 27 | def transcribe(self, padded_input, input_lengths, 28 | beam_size=1, nbest=1, decode_max_len=0, 29 | softmax_smoothing=1.0, length_penalty=0.0, eos_penalty=1.0): 30 | enc_outputs, _, enc_mask = self.encoder(padded_input, input_lengths) 31 | nbest_hyps = self.decoder.batch_beam_search( 32 | enc_outputs, enc_mask, 33 | beam_size, nbest, decode_max_len, 34 | softmax_smoothing, length_penalty, eos_penalty) 35 | return nbest_hyps 36 | -------------------------------------------------------------------------------- /fireredasr/models/fireredasr_llm.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import random 4 | import re 5 | 6 | import torch 7 | import torch.nn as nn 8 | from transformers import AutoModelForCausalLM 9 | 10 | from fireredasr.models.fireredasr_aed import FireRedAsrAed 11 | from fireredasr.models.module.adapter import Adapter 12 | from fireredasr.tokenizer.llm_tokenizer import DEFAULT_SPEECH_TOKEN, IGNORE_TOKEN_ID 13 | from fireredasr.tokenizer.llm_tokenizer import LlmTokenizerWrapper 14 | from fireredasr.utils.param import count_model_parameters 15 | 16 | 17 | class FireRedAsrLlm(nn.Module): 18 | @classmethod 19 | def load_encoder(cls, model_path): 20 | assert os.path.exists(model_path) 21 | package = torch.load(model_path, map_location=lambda storage, loc: storage) 22 | model = FireRedAsrAed.from_args(package["args"]) 23 | if "model_state_dict" in package: 24 | model.load_state_dict(package["model_state_dict"], strict=False) 25 | encoder = model.encoder 26 | encoder_dim = encoder.odim 27 | return encoder, encoder_dim 28 | 29 | @classmethod 30 | def from_args(cls, args): 31 | logging.info(args) 32 | logging.info("Build FireRedAsrLlm") 33 | # Build Speech Encoder 34 | encoder, encoder_dim = cls.load_encoder(args.encoder_path) 35 | count_model_parameters(encoder) 36 | if args.freeze_encoder: 37 | logging.info(f"Frezee encoder") 38 | for name, param in encoder.named_parameters(): 39 | param.requires_grad = False 40 | encoder.eval() 41 | 42 | if args.use_flash_attn: 43 | attn_implementation = "flash_attention_2" 44 | if args.use_fp16: 45 | torch_dtype = torch.float16 46 | else: 47 | torch_dtype = torch.float32 48 | else: 49 | attn_implementation = "eager" 50 | if args.use_fp16: 51 | torch_dtype = torch.float16 52 | else: 53 | torch_dtype = torch.float32 54 | 55 | # Build LLM 56 | llm = AutoModelForCausalLM.from_pretrained( 57 | args.llm_dir, 58 | attn_implementation=attn_implementation, 59 | torch_dtype=torch_dtype, 60 | ) 61 | count_model_parameters(llm) 62 | 63 | # LLM Freeze or LoRA 64 | llm_dim = llm.config.hidden_size 65 | if args.freeze_llm: 66 | logging.info(f"Frezee LLM") 67 | for name, param in llm.named_parameters(): 68 | param.requires_grad = False 69 | llm.eval() 70 | else: 71 | if args.use_lora: 72 | from peft import LoraConfig, get_peft_model 73 | lora_config = LoraConfig( 74 | r=64, 75 | lora_alpha=16, 76 | target_modules=[ 77 | "q_proj", 78 | "k_proj", 79 | "v_proj", 80 | "o_proj", 81 | "up_proj", 82 | "gate_proj", 83 | "down_proj", 84 | ], 85 | lora_dropout=0.05, 86 | task_type="CAUSAL_LM", 87 | ) 88 | llm = get_peft_model(llm, lora_config) 89 | llm.print_trainable_parameters() 90 | 91 | tokenizer = LlmTokenizerWrapper.build_llm_tokenizer(args.llm_dir) 92 | assert tokenizer.pad_token_id == tokenizer.convert_tokens_to_ids("<|endoftext|>") 93 | llm.config.pad_token_id = tokenizer.pad_token_id 94 | llm.config.bos_token_id = 
tokenizer.convert_tokens_to_ids("<|im_start|>") 95 | llm.config.eos_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>") 96 | llm.config.default_speech_token_id = tokenizer.convert_tokens_to_ids( 97 | DEFAULT_SPEECH_TOKEN 98 | ) 99 | 100 | # Build projector 101 | encoder_projector = Adapter( 102 | encoder_dim, llm_dim, args.encoder_downsample_rate) 103 | count_model_parameters(encoder_projector) 104 | 105 | return cls(encoder, llm, encoder_projector, 106 | args.freeze_encoder, args.freeze_llm) 107 | 108 | def __init__(self, encoder, llm, encoder_projector, 109 | freeze_encoder, freeze_llm): 110 | super().__init__() 111 | self.encoder = encoder 112 | self.llm = llm 113 | self.encoder_projector = encoder_projector 114 | # args 115 | self.freeze_encoder = freeze_encoder 116 | self.freeze_llm = freeze_llm 117 | self.llm_config = llm.config 118 | 119 | def transcribe(self, padded_feat, feat_lengths, padded_input_ids, attention_mask, 120 | beam_size=1, decode_max_len=0, decode_min_len=0, 121 | repetition_penalty=1.0, llm_length_penalty=1.0, temperature=1.0): 122 | encoder_outs, enc_lengths, enc_mask = self.encoder(padded_feat, feat_lengths) 123 | speech_features, speech_lens = self.encoder_projector(encoder_outs, enc_lengths) 124 | inputs_embeds = self.llm.get_input_embeddings()(padded_input_ids) 125 | 126 | inputs_embeds, attention_mask, _ = \ 127 | self._merge_input_ids_with_speech_features( 128 | speech_features.to(inputs_embeds.dtype), inputs_embeds, padded_input_ids, attention_mask, 129 | speech_lens=speech_lens 130 | ) 131 | 132 | max_new_tokens = speech_features.size(1) if decode_max_len < 1 else decode_max_len 133 | max_new_tokens = max(1, max_new_tokens) 134 | 135 | generated_ids = self.llm.generate( 136 | inputs_embeds=inputs_embeds, 137 | max_new_tokens=max_new_tokens, 138 | num_beams=beam_size, 139 | do_sample=False, 140 | min_length=decode_min_len, 141 | top_p=1.0, 142 | repetition_penalty=repetition_penalty, 143 | length_penalty=llm_length_penalty, 144 | temperature=temperature, 145 | bos_token_id=self.llm.config.bos_token_id, 146 | eos_token_id=self.llm.config.eos_token_id, 147 | pad_token_id=self.llm.config.pad_token_id, 148 | ) 149 | 150 | return generated_ids 151 | 152 | 153 | def _merge_input_ids_with_speech_features( 154 | self, speech_features, inputs_embeds, input_ids, attention_mask, labels=None, 155 | speech_lens=None 156 | ): 157 | """ 158 | Modified from: https://github.com/k2-fsa/icefall/blob/master/egs/speech_llm/ASR_LLM/whisper_llm_zh/model.py 159 | """ 160 | speech_lens = None 161 | num_speechs, speech_len, embed_dim = speech_features.shape 162 | batch_size, sequence_length = input_ids.shape 163 | left_padding = not torch.sum( 164 | input_ids[:, -1] == torch.tensor(self.llm.config.pad_token_id) 165 | ) 166 | # 1. Create a mask to know where special speech tokens are 167 | special_speech_token_mask = input_ids == self.llm.config.default_speech_token_id 168 | num_special_speech_tokens = torch.sum(special_speech_token_mask, dim=-1) 169 | # Compute the maximum embed dimension 170 | max_embed_dim = ( 171 | num_special_speech_tokens.max() * (speech_len - 1) 172 | ) + sequence_length 173 | batch_indices, non_speech_indices = torch.where( 174 | input_ids != self.llm.config.default_speech_token_id 175 | ) 176 | 177 | # 2. Compute the positions where text should be written 178 | # Calculate new positions for text tokens in merged speech-text sequence. 179 | # `special_speech_token_mask` identifies speech tokens. 
Each speech token will be replaced by `nb_text_tokens_per_speechs - 1` text tokens. 180 | # `torch.cumsum` computes how each speech token shifts subsequent text token positions. 181 | # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one. 182 | new_token_positions = ( 183 | torch.cumsum((special_speech_token_mask * (speech_len - 1) + 1), -1) - 1 184 | ) # (N,U) 185 | nb_speech_pad = max_embed_dim - 1 - new_token_positions[:, -1] 186 | if left_padding: 187 | new_token_positions += nb_speech_pad[:, None] # offset for left padding 188 | text_to_overwrite = new_token_positions[batch_indices, non_speech_indices] 189 | 190 | # 3. Create the full embedding, already padded to the maximum position 191 | final_embedding = torch.zeros( 192 | batch_size, 193 | max_embed_dim, 194 | embed_dim, 195 | dtype=inputs_embeds.dtype, 196 | device=inputs_embeds.device, 197 | ) 198 | final_attention_mask = torch.zeros( 199 | batch_size, 200 | max_embed_dim, 201 | dtype=attention_mask.dtype, 202 | device=inputs_embeds.device, 203 | ) 204 | if labels is not None: 205 | final_labels = torch.full( 206 | (batch_size, max_embed_dim), 207 | IGNORE_TOKEN_ID, 208 | dtype=input_ids.dtype, 209 | device=input_ids.device, 210 | ) 211 | # In case the Vision model or the Language model has been offloaded to CPU, we need to manually 212 | # set the corresponding tensors into their correct target device. 213 | target_device = inputs_embeds.device 214 | batch_indices, non_speech_indices, text_to_overwrite = ( 215 | batch_indices.to(target_device), 216 | non_speech_indices.to(target_device), 217 | text_to_overwrite.to(target_device), 218 | ) 219 | attention_mask = attention_mask.to(target_device) 220 | 221 | # 4. Fill the embeddings based on the mask. If we have ["hey" "", "how", "are"] 222 | # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the speech features 223 | final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[ 224 | batch_indices, non_speech_indices 225 | ] 226 | final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[ 227 | batch_indices, non_speech_indices 228 | ] 229 | if labels is not None: 230 | final_labels[batch_indices, text_to_overwrite] = labels[ 231 | batch_indices, non_speech_indices 232 | ] 233 | 234 | # 5. Fill the embeddings corresponding to the speechs. Anything that is not `text_positions` needs filling (#29835) 235 | speech_to_overwrite = torch.full( 236 | (batch_size, max_embed_dim), 237 | True, 238 | dtype=torch.bool, 239 | device=inputs_embeds.device, 240 | ) 241 | speech_to_overwrite[batch_indices, text_to_overwrite] = False 242 | if speech_lens is not None: 243 | speech_pad_position = speech_to_overwrite.cumsum(-1) <= speech_lens[:, None] 244 | speech_to_overwrite &= speech_to_overwrite.cumsum(-1) - 1 >= nb_speech_pad[ 245 | :, None 246 | ].to(target_device) 247 | 248 | if speech_to_overwrite.sum() != speech_features.shape[:-1].numel(): 249 | raise ValueError( 250 | f"The input provided to the model are wrong. The number of speech tokens is {torch.sum(special_speech_token_mask)} while" 251 | f" the number of speech given to the model is {num_speechs}. This prevents correct indexing and breaks batch generation." 252 | ) 253 | 254 | final_embedding[speech_to_overwrite] = ( 255 | speech_features.contiguous().reshape(-1, embed_dim).to(target_device) 256 | ) 257 | if speech_lens is not None: 258 | speech_to_overwrite &= speech_pad_position 259 | final_attention_mask |= speech_to_overwrite 260 | 261 | # 6. 
Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens. 262 | batch_indices, pad_indices = torch.where( 263 | input_ids == self.llm.config.pad_token_id 264 | ) 265 | indices_to_mask = new_token_positions[batch_indices, pad_indices] 266 | 267 | final_embedding[batch_indices, indices_to_mask] = 0 268 | 269 | if labels is None: 270 | final_labels = None 271 | 272 | return final_embedding, final_attention_mask, final_labels #, position_ids 273 | -------------------------------------------------------------------------------- /fireredasr/models/module/adapter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class Adapter(nn.Module): 6 | def __init__(self, encoder_dim, llm_dim, downsample_rate=2): 7 | super().__init__() 8 | self.ds = downsample_rate 9 | self.linear1 = nn.Linear(encoder_dim * downsample_rate, llm_dim) 10 | self.relu = nn.ReLU() 11 | self.linear2 = nn.Linear(llm_dim, llm_dim) 12 | 13 | def forward(self, x, x_lens): 14 | batch_size, seq_len, feat_dim = x.size() 15 | num_frames_to_discard = seq_len % self.ds 16 | if num_frames_to_discard > 0: 17 | x = x[:, :-num_frames_to_discard, :] 18 | seq_len = x.size(1) 19 | 20 | x = x.contiguous() 21 | x = x.view( 22 | batch_size, seq_len // self.ds, feat_dim * self.ds 23 | ) 24 | 25 | x = self.linear1(x) 26 | x = self.relu(x) 27 | x = self.linear2(x) 28 | 29 | new_x_lens = torch.clamp(x_lens, max=seq_len) // self.ds 30 | return x, new_x_lens 31 | -------------------------------------------------------------------------------- /fireredasr/models/module/conformer_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class ConformerEncoder(nn.Module): 7 | def __init__(self, idim, n_layers, n_head, d_model, 8 | residual_dropout=0.1, dropout_rate=0.1, kernel_size=33, 9 | pe_maxlen=5000): 10 | super().__init__() 11 | self.odim = d_model 12 | 13 | self.input_preprocessor = Conv2dSubsampling(idim, d_model) 14 | self.positional_encoding = RelPositionalEncoding(d_model) 15 | self.dropout = nn.Dropout(residual_dropout) 16 | 17 | self.layer_stack = nn.ModuleList() 18 | for l in range(n_layers): 19 | block = RelPosEmbConformerBlock(d_model, n_head, 20 | residual_dropout, 21 | dropout_rate, kernel_size) 22 | self.layer_stack.append(block) 23 | 24 | def forward(self, padded_input, input_lengths, pad=True): 25 | if pad: 26 | padded_input = F.pad(padded_input, 27 | (0, 0, 0, self.input_preprocessor.context - 1), 'constant', 0.0) 28 | src_mask = self.padding_position_is_0(padded_input, input_lengths) 29 | 30 | embed_output, input_lengths, src_mask = self.input_preprocessor(padded_input, src_mask) 31 | enc_output = self.dropout(embed_output) 32 | 33 | pos_emb = self.dropout(self.positional_encoding(embed_output)) 34 | 35 | enc_outputs = [] 36 | for enc_layer in self.layer_stack: 37 | enc_output = enc_layer(enc_output, pos_emb, slf_attn_mask=src_mask, 38 | pad_mask=src_mask) 39 | enc_outputs.append(enc_output) 40 | 41 | return enc_output, input_lengths, src_mask 42 | 43 | def padding_position_is_0(self, padded_input, input_lengths): 44 | N, T = padded_input.size()[:2] 45 | mask = torch.ones((N, T)).to(padded_input.device) 46 | for i in range(N): 47 | mask[i, input_lengths[i]:] = 0 48 | mask = mask.unsqueeze(dim=1) 49 | return mask.to(torch.uint8) 50 | 51 | 52 | class 
RelPosEmbConformerBlock(nn.Module): 53 | def __init__(self, d_model, n_head, 54 | residual_dropout=0.1, 55 | dropout_rate=0.1, kernel_size=33): 56 | super().__init__() 57 | self.ffn1 = ConformerFeedForward(d_model, dropout_rate) 58 | self.mhsa = RelPosMultiHeadAttention(n_head, d_model, 59 | residual_dropout) 60 | self.conv = ConformerConvolution(d_model, kernel_size, 61 | dropout_rate) 62 | self.ffn2 = ConformerFeedForward(d_model, dropout_rate) 63 | self.layer_norm = nn.LayerNorm(d_model) 64 | 65 | def forward(self, x, pos_emb, slf_attn_mask=None, pad_mask=None): 66 | out = 0.5 * x + 0.5 * self.ffn1(x) 67 | out = self.mhsa(out, out, out, pos_emb, mask=slf_attn_mask)[0] 68 | out = self.conv(out, pad_mask) 69 | out = 0.5 * out + 0.5 * self.ffn2(out) 70 | out = self.layer_norm(out) 71 | return out 72 | 73 | 74 | class Swish(nn.Module): 75 | def forward(self, x): 76 | return x * torch.sigmoid(x) 77 | 78 | 79 | class Conv2dSubsampling(nn.Module): 80 | def __init__(self, idim, d_model, out_channels=32): 81 | super().__init__() 82 | self.conv = nn.Sequential( 83 | nn.Conv2d(1, out_channels, 3, 2), 84 | nn.ReLU(), 85 | nn.Conv2d(out_channels, out_channels, 3, 2), 86 | nn.ReLU(), 87 | ) 88 | subsample_idim = ((idim - 1) // 2 - 1) // 2 89 | self.out = nn.Linear(out_channels * subsample_idim, d_model) 90 | 91 | self.subsampling = 4 92 | left_context = right_context = 3 # both exclude currect frame 93 | self.context = left_context + 1 + right_context # 7 94 | 95 | def forward(self, x, x_mask): 96 | x = x.unsqueeze(1) 97 | x = self.conv(x) 98 | N, C, T, D = x.size() 99 | x = self.out(x.transpose(1, 2).contiguous().view(N, T, C * D)) 100 | mask = x_mask[:, :, :-2:2][:, :, :-2:2] 101 | input_lengths = mask[:, -1, :].sum(dim=-1) 102 | return x, input_lengths, mask 103 | 104 | 105 | class RelPositionalEncoding(torch.nn.Module): 106 | def __init__(self, d_model, max_len=5000): 107 | super().__init__() 108 | pe_positive = torch.zeros(max_len, d_model, requires_grad=False) 109 | pe_negative = torch.zeros(max_len, d_model, requires_grad=False) 110 | position = torch.arange(0, max_len).unsqueeze(1).float() 111 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * 112 | -(torch.log(torch.tensor(10000.0)).item()/d_model)) 113 | pe_positive[:, 0::2] = torch.sin(position * div_term) 114 | pe_positive[:, 1::2] = torch.cos(position * div_term) 115 | pe_negative[:, 0::2] = torch.sin(-1 * position * div_term) 116 | pe_negative[:, 1::2] = torch.cos(-1 * position * div_term) 117 | 118 | pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0) 119 | pe_negative = pe_negative[1:].unsqueeze(0) 120 | pe = torch.cat([pe_positive, pe_negative], dim=1) 121 | self.register_buffer('pe', pe) 122 | 123 | def forward(self, x): 124 | # Tmax = 2 * max_len - 1 125 | Tmax, T = self.pe.size(1), x.size(1) 126 | pos_emb = self.pe[:, Tmax // 2 - T + 1 : Tmax // 2 + T].clone().detach() 127 | return pos_emb 128 | 129 | 130 | class ConformerFeedForward(nn.Module): 131 | def __init__(self, d_model, dropout_rate=0.1): 132 | super().__init__() 133 | pre_layer_norm = nn.LayerNorm(d_model) 134 | linear_expand = nn.Linear(d_model, d_model*4) 135 | nonlinear = Swish() 136 | dropout_pre = nn.Dropout(dropout_rate) 137 | linear_project = nn.Linear(d_model*4, d_model) 138 | dropout_post = nn.Dropout(dropout_rate) 139 | self.net = nn.Sequential(pre_layer_norm, 140 | linear_expand, 141 | nonlinear, 142 | dropout_pre, 143 | linear_project, 144 | dropout_post) 145 | 146 | def forward(self, x): 147 | residual = x 148 | output = self.net(x) 149 | 
output = output + residual 150 | return output 151 | 152 | 153 | class ConformerConvolution(nn.Module): 154 | def __init__(self, d_model, kernel_size=33, dropout_rate=0.1): 155 | super().__init__() 156 | assert kernel_size % 2 == 1 157 | self.pre_layer_norm = nn.LayerNorm(d_model) 158 | self.pointwise_conv1 = nn.Conv1d(d_model, d_model*4, kernel_size=1, bias=False) 159 | self.glu = F.glu 160 | self.padding = (kernel_size - 1) // 2 161 | self.depthwise_conv = nn.Conv1d(d_model*2, d_model*2, 162 | kernel_size, stride=1, 163 | padding=self.padding, 164 | groups=d_model*2, bias=False) 165 | self.batch_norm = nn.LayerNorm(d_model*2) 166 | self.swish = Swish() 167 | self.pointwise_conv2 = nn.Conv1d(d_model*2, d_model, kernel_size=1, bias=False) 168 | self.dropout = nn.Dropout(dropout_rate) 169 | 170 | def forward(self, x, mask=None): 171 | residual = x 172 | out = self.pre_layer_norm(x) 173 | out = out.transpose(1, 2) 174 | if mask is not None: 175 | out.masked_fill_(mask.ne(1), 0.0) 176 | out = self.pointwise_conv1(out) 177 | out = F.glu(out, dim=1) 178 | out = self.depthwise_conv(out) 179 | 180 | out = out.transpose(1, 2) 181 | out = self.swish(self.batch_norm(out)) 182 | out = out.transpose(1, 2) 183 | 184 | out = self.dropout(self.pointwise_conv2(out)) 185 | if mask is not None: 186 | out.masked_fill_(mask.ne(1), 0.0) 187 | out = out.transpose(1, 2) 188 | return out + residual 189 | 190 | 191 | class EncoderMultiHeadAttention(nn.Module): 192 | def __init__(self, n_head, d_model, 193 | residual_dropout=0.1): 194 | super().__init__() 195 | assert d_model % n_head == 0 196 | self.n_head = n_head 197 | self.d_k = d_model // n_head 198 | self.d_v = self.d_k 199 | 200 | self.w_qs = nn.Linear(d_model, n_head * self.d_k, bias=False) 201 | self.w_ks = nn.Linear(d_model, n_head * self.d_k, bias=False) 202 | self.w_vs = nn.Linear(d_model, n_head * self.d_v, bias=False) 203 | 204 | self.layer_norm_q = nn.LayerNorm(d_model) 205 | self.layer_norm_k = nn.LayerNorm(d_model) 206 | self.layer_norm_v = nn.LayerNorm(d_model) 207 | 208 | self.attention = ScaledDotProductAttention(temperature=self.d_k ** 0.5) 209 | self.fc = nn.Linear(n_head * self.d_v, d_model, bias=False) 210 | self.dropout = nn.Dropout(residual_dropout) 211 | 212 | def forward(self, q, k, v, mask=None): 213 | sz_b, len_q = q.size(0), q.size(1) 214 | 215 | residual = q 216 | q, k, v = self.forward_qkv(q, k, v) 217 | 218 | output, attn = self.attention(q, k, v, mask=mask) 219 | 220 | output = self.forward_output(output, residual, sz_b, len_q) 221 | return output, attn 222 | 223 | def forward_qkv(self, q, k, v): 224 | d_k, d_v, n_head = self.d_k, self.d_v, self.n_head 225 | sz_b, len_q, len_k, len_v = q.size(0), q.size(1), k.size(1), v.size(1) 226 | 227 | q = self.layer_norm_q(q) 228 | k = self.layer_norm_k(k) 229 | v = self.layer_norm_v(v) 230 | 231 | q = self.w_qs(q).view(sz_b, len_q, n_head, d_k) 232 | k = self.w_ks(k).view(sz_b, len_k, n_head, d_k) 233 | v = self.w_vs(v).view(sz_b, len_v, n_head, d_v) 234 | q = q.transpose(1, 2) 235 | k = k.transpose(1, 2) 236 | v = v.transpose(1, 2) 237 | return q, k, v 238 | 239 | def forward_output(self, output, residual, sz_b, len_q): 240 | output = output.transpose(1, 2).contiguous().view(sz_b, len_q, -1) 241 | fc_out = self.fc(output) 242 | output = self.dropout(fc_out) 243 | output = output + residual 244 | return output 245 | 246 | 247 | class ScaledDotProductAttention(nn.Module): 248 | def __init__(self, temperature): 249 | super().__init__() 250 | self.temperature = temperature 251 | self.dropout = 
nn.Dropout(0.0) 252 | self.INF = float('inf') 253 | 254 | def forward(self, q, k, v, mask=None): 255 | attn = torch.matmul(q, k.transpose(2, 3)) / self.temperature 256 | output, attn = self.forward_attention(attn, v, mask) 257 | return output, attn 258 | 259 | def forward_attention(self, attn, v, mask=None): 260 | if mask is not None: 261 | mask = mask.unsqueeze(1) 262 | mask = mask.eq(0) 263 | attn = attn.masked_fill(mask, -self.INF) 264 | attn = torch.softmax(attn, dim=-1).masked_fill(mask, 0.0) 265 | else: 266 | attn = torch.softmax(attn, dim=-1) 267 | 268 | d_attn = self.dropout(attn) 269 | output = torch.matmul(d_attn, v) 270 | 271 | return output, attn 272 | 273 | 274 | class RelPosMultiHeadAttention(EncoderMultiHeadAttention): 275 | def __init__(self, n_head, d_model, 276 | residual_dropout=0.1): 277 | super().__init__(n_head, d_model, 278 | residual_dropout) 279 | d_k = d_model // n_head 280 | self.scale = 1.0 / (d_k ** 0.5) 281 | self.linear_pos = nn.Linear(d_model, n_head * d_k, bias=False) 282 | self.pos_bias_u = nn.Parameter(torch.FloatTensor(n_head, d_k)) 283 | self.pos_bias_v = nn.Parameter(torch.FloatTensor(n_head, d_k)) 284 | torch.nn.init.xavier_uniform_(self.pos_bias_u) 285 | torch.nn.init.xavier_uniform_(self.pos_bias_v) 286 | 287 | def _rel_shift(self, x): 288 | N, H, T1, T2 = x.size() 289 | zero_pad = torch.zeros((N, H, T1, 1), device=x.device, dtype=x.dtype) 290 | x_padded = torch.cat([zero_pad, x], dim=-1) 291 | 292 | x_padded = x_padded.view(N, H, T2 + 1, T1) 293 | x = x_padded[:, :, 1:].view_as(x) 294 | x = x[:, :, :, : x.size(-1) // 2 + 1] 295 | return x 296 | 297 | def forward(self, q, k, v, pos_emb, mask=None): 298 | sz_b, len_q = q.size(0), q.size(1) 299 | 300 | residual = q 301 | q, k, v = self.forward_qkv(q, k, v) 302 | 303 | q = q.transpose(1, 2) 304 | n_batch_pos = pos_emb.size(0) 305 | p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.n_head, self.d_k) 306 | p = p.transpose(1, 2) 307 | 308 | q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2) 309 | q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2) 310 | 311 | matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1)) 312 | 313 | matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1)) 314 | matrix_bd = self._rel_shift(matrix_bd) 315 | 316 | attn_scores = matrix_ac + matrix_bd 317 | attn_scores.mul_(self.scale) 318 | 319 | output, attn = self.attention.forward_attention(attn_scores, v, mask=mask) 320 | 321 | output = self.forward_output(output, residual, sz_b, len_q) 322 | return output, attn 323 | -------------------------------------------------------------------------------- /fireredasr/models/module/transformer_decoder.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Dict 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch import Tensor 7 | 8 | 9 | class TransformerDecoder(nn.Module): 10 | def __init__( 11 | self, sos_id, eos_id, pad_id, odim, 12 | n_layers, n_head, d_model, 13 | residual_dropout=0.1, pe_maxlen=5000): 14 | super().__init__() 15 | self.INF = 1e10 16 | # parameters 17 | self.pad_id = pad_id 18 | self.sos_id = sos_id 19 | self.eos_id = eos_id 20 | self.n_layers = n_layers 21 | 22 | # Components 23 | self.tgt_word_emb = nn.Embedding(odim, d_model, padding_idx=self.pad_id) 24 | self.positional_encoding = PositionalEncoding(d_model, max_len=pe_maxlen) 25 | self.dropout = nn.Dropout(residual_dropout) 26 | 27 | self.layer_stack = nn.ModuleList() 28 | for l 
in range(n_layers): 29 | block = DecoderLayer(d_model, n_head, residual_dropout) 30 | self.layer_stack.append(block) 31 | 32 | self.tgt_word_prj = nn.Linear(d_model, odim, bias=False) 33 | self.layer_norm_out = nn.LayerNorm(d_model) 34 | 35 | self.tgt_word_prj.weight = self.tgt_word_emb.weight 36 | self.scale = (d_model ** 0.5) 37 | 38 | def batch_beam_search(self, encoder_outputs, src_masks, 39 | beam_size=1, nbest=1, decode_max_len=0, 40 | softmax_smoothing=1.0, length_penalty=0.0, eos_penalty=1.0): 41 | B = beam_size 42 | N, Ti, H = encoder_outputs.size() 43 | device = encoder_outputs.device 44 | maxlen = decode_max_len if decode_max_len > 0 else Ti 45 | assert eos_penalty > 0.0 and eos_penalty <= 1.0 46 | 47 | # Init 48 | encoder_outputs = encoder_outputs.unsqueeze(1).repeat(1, B, 1, 1).view(N*B, Ti, H) 49 | src_mask = src_masks.unsqueeze(1).repeat(1, B, 1, 1).view(N*B, -1, Ti) 50 | ys = torch.ones(N*B, 1).fill_(self.sos_id).long().to(device) 51 | caches: List[Optional[Tensor]] = [] 52 | for _ in range(self.n_layers): 53 | caches.append(None) 54 | scores = torch.tensor([0.0] + [-self.INF]*(B-1)).float().to(device) 55 | scores = scores.repeat(N).view(N*B, 1) 56 | is_finished = torch.zeros_like(scores) 57 | 58 | # Autoregressive Prediction 59 | for t in range(maxlen): 60 | tgt_mask = self.ignored_target_position_is_0(ys, self.pad_id) 61 | 62 | dec_output = self.dropout( 63 | self.tgt_word_emb(ys) * self.scale + 64 | self.positional_encoding(ys)) 65 | 66 | i = 0 67 | for dec_layer in self.layer_stack: 68 | dec_output = dec_layer.forward( 69 | dec_output, encoder_outputs, 70 | tgt_mask, src_mask, 71 | cache=caches[i]) 72 | caches[i] = dec_output 73 | i += 1 74 | 75 | dec_output = self.layer_norm_out(dec_output) 76 | 77 | t_logit = self.tgt_word_prj(dec_output[:, -1]) 78 | t_scores = F.log_softmax(t_logit / softmax_smoothing, dim=-1) 79 | 80 | if eos_penalty != 1.0: 81 | t_scores[:, self.eos_id] *= eos_penalty 82 | 83 | t_topB_scores, t_topB_ys = torch.topk(t_scores, k=B, dim=1) 84 | t_topB_scores = self.set_finished_beam_score_to_zero(t_topB_scores, is_finished) 85 | t_topB_ys = self.set_finished_beam_y_to_eos(t_topB_ys, is_finished) 86 | 87 | # Accumulated 88 | scores = scores + t_topB_scores 89 | 90 | # Pruning 91 | scores = scores.view(N, B*B) 92 | scores, topB_score_ids = torch.topk(scores, k=B, dim=1) 93 | scores = scores.view(-1, 1) 94 | 95 | topB_row_number_in_each_B_rows_of_ys = torch.div(topB_score_ids, B).view(N*B) 96 | stride = B * torch.arange(N).view(N, 1).repeat(1, B).view(N*B).to(device) 97 | topB_row_number_in_ys = topB_row_number_in_each_B_rows_of_ys.long() + stride.long() 98 | 99 | # Update ys 100 | ys = ys[topB_row_number_in_ys] 101 | t_ys = torch.gather(t_topB_ys.view(N, B*B), dim=1, index=topB_score_ids).view(N*B, 1) 102 | ys = torch.cat((ys, t_ys), dim=1) 103 | 104 | # Update caches 105 | new_caches: List[Optional[Tensor]] = [] 106 | for cache in caches: 107 | if cache is not None: 108 | new_caches.append(cache[topB_row_number_in_ys]) 109 | caches = new_caches 110 | 111 | # Update finished state 112 | is_finished = t_ys.eq(self.eos_id) 113 | if is_finished.sum().item() == N*B: 114 | break 115 | 116 | # Length penalty (follow GNMT) 117 | scores = scores.view(N, B) 118 | ys = ys.view(N, B, -1) 119 | ys_lengths = self.get_ys_lengths(ys) 120 | if length_penalty > 0.0: 121 | penalty = torch.pow((5+ys_lengths.float())/(5.0+1), length_penalty) 122 | scores /= penalty 123 | nbest_scores, nbest_ids = torch.topk(scores, k=int(nbest), dim=1) 124 | nbest_scores = -1.0 * 
nbest_scores 125 | index = nbest_ids + B * torch.arange(N).view(N, 1).to(device).long() 126 | nbest_ys = ys.view(N*B, -1)[index.view(-1)] 127 | nbest_ys = nbest_ys.view(N, nbest_ids.size(1), -1) 128 | nbest_ys_lengths = ys_lengths.view(N*B)[index.view(-1)].view(N, -1) 129 | 130 | # result 131 | nbest_hyps: List[List[Dict[str, Tensor]]] = [] 132 | for n in range(N): 133 | n_nbest_hyps: List[Dict[str, Tensor]] = [] 134 | for i, score in enumerate(nbest_scores[n]): 135 | new_hyp = { 136 | "yseq": nbest_ys[n, i, 1:nbest_ys_lengths[n, i]] 137 | } 138 | n_nbest_hyps.append(new_hyp) 139 | nbest_hyps.append(n_nbest_hyps) 140 | return nbest_hyps 141 | 142 | def ignored_target_position_is_0(self, padded_targets, ignore_id): 143 | mask = torch.ne(padded_targets, ignore_id) 144 | mask = mask.unsqueeze(dim=1) 145 | T = padded_targets.size(-1) 146 | upper_tri_0_mask = self.upper_triangular_is_0(T).unsqueeze(0).to(mask.dtype) 147 | upper_tri_0_mask = upper_tri_0_mask.to(mask.dtype).to(mask.device) 148 | return mask.to(torch.uint8) & upper_tri_0_mask.to(torch.uint8) 149 | 150 | def upper_triangular_is_0(self, size): 151 | ones = torch.ones(size, size) 152 | tri_left_ones = torch.tril(ones) 153 | return tri_left_ones.to(torch.uint8) 154 | 155 | def set_finished_beam_score_to_zero(self, scores, is_finished): 156 | NB, B = scores.size() 157 | is_finished = is_finished.float() 158 | mask_score = torch.tensor([0.0] + [-self.INF]*(B-1)).float().to(scores.device) 159 | mask_score = mask_score.view(1, B).repeat(NB, 1) 160 | return scores * (1 - is_finished) + mask_score * is_finished 161 | 162 | def set_finished_beam_y_to_eos(self, ys, is_finished): 163 | is_finished = is_finished.long() 164 | return ys * (1 - is_finished) + self.eos_id * is_finished 165 | 166 | def get_ys_lengths(self, ys): 167 | N, B, Tmax = ys.size() 168 | ys_lengths = torch.sum(torch.ne(ys, self.eos_id), dim=-1) 169 | return ys_lengths.int() 170 | 171 | 172 | 173 | class DecoderLayer(nn.Module): 174 | def __init__(self, d_model, n_head, dropout): 175 | super().__init__() 176 | self.self_attn_norm = nn.LayerNorm(d_model) 177 | self.self_attn = DecoderMultiHeadAttention(d_model, n_head, dropout) 178 | 179 | self.cross_attn_norm = nn.LayerNorm(d_model) 180 | self.cross_attn = DecoderMultiHeadAttention(d_model, n_head, dropout) 181 | 182 | self.mlp_norm = nn.LayerNorm(d_model) 183 | self.mlp = PositionwiseFeedForward(d_model, d_model*4, dropout) 184 | 185 | def forward(self, dec_input, enc_output, self_attn_mask, cross_attn_mask, 186 | cache=None): 187 | x = dec_input 188 | residual = x 189 | x = self.self_attn_norm(x) 190 | if cache is not None: 191 | xq = x[:, -1:, :] 192 | residual = residual[:, -1:, :] 193 | self_attn_mask = self_attn_mask[:, -1:, :] 194 | else: 195 | xq = x 196 | x = self.self_attn(xq, x, x, mask=self_attn_mask) 197 | x = residual + x 198 | 199 | residual = x 200 | x = self.cross_attn_norm(x) 201 | x = self.cross_attn(x, enc_output, enc_output, mask=cross_attn_mask) 202 | x = residual + x 203 | 204 | residual = x 205 | x = self.mlp_norm(x) 206 | x = residual + self.mlp(x) 207 | 208 | if cache is not None: 209 | x = torch.cat([cache, x], dim=1) 210 | 211 | return x 212 | 213 | 214 | class DecoderMultiHeadAttention(nn.Module): 215 | def __init__(self, d_model, n_head, dropout=0.1): 216 | super().__init__() 217 | self.d_model = d_model 218 | self.n_head = n_head 219 | self.d_k = d_model // n_head 220 | 221 | self.w_qs = nn.Linear(d_model, n_head * self.d_k) 222 | self.w_ks = nn.Linear(d_model, n_head * self.d_k, bias=False) 223 
| self.w_vs = nn.Linear(d_model, n_head * self.d_k) 224 | 225 | self.attention = DecoderScaledDotProductAttention( 226 | temperature=self.d_k ** 0.5) 227 | self.fc = nn.Linear(n_head * self.d_k, d_model) 228 | self.dropout = nn.Dropout(dropout) 229 | 230 | def forward(self, q, k, v, mask=None): 231 | bs = q.size(0) 232 | 233 | q = self.w_qs(q).view(bs, -1, self.n_head, self.d_k) 234 | k = self.w_ks(k).view(bs, -1, self.n_head, self.d_k) 235 | v = self.w_vs(v).view(bs, -1, self.n_head, self.d_k) 236 | q = q.transpose(1, 2) 237 | k = k.transpose(1, 2) 238 | v = v.transpose(1, 2) 239 | 240 | if mask is not None: 241 | mask = mask.unsqueeze(1) 242 | 243 | output = self.attention(q, k, v, mask=mask) 244 | 245 | output = output.transpose(1, 2).contiguous().view(bs, -1, self.d_model) 246 | output = self.fc(output) 247 | output = self.dropout(output) 248 | 249 | return output 250 | 251 | 252 | class DecoderScaledDotProductAttention(nn.Module): 253 | def __init__(self, temperature): 254 | super().__init__() 255 | self.temperature = temperature 256 | self.INF = float("inf") 257 | 258 | def forward(self, q, k, v, mask=None): 259 | attn = torch.matmul(q, k.transpose(2, 3)) / self.temperature 260 | if mask is not None: 261 | mask = mask.eq(0) 262 | attn = attn.masked_fill(mask, -self.INF) 263 | attn = torch.softmax(attn, dim=-1).masked_fill(mask, 0.0) 264 | else: 265 | attn = torch.softmax(attn, dim=-1) 266 | output = torch.matmul(attn, v) 267 | return output 268 | 269 | 270 | class PositionwiseFeedForward(nn.Module): 271 | def __init__(self, d_model, d_ff, dropout=0.1): 272 | super().__init__() 273 | self.w_1 = nn.Linear(d_model, d_ff) 274 | self.act = nn.GELU() 275 | self.w_2 = nn.Linear(d_ff, d_model) 276 | self.dropout = nn.Dropout(dropout) 277 | 278 | def forward(self, x): 279 | output = self.w_2(self.act(self.w_1(x))) 280 | output = self.dropout(output) 281 | return output 282 | 283 | 284 | class PositionalEncoding(nn.Module): 285 | def __init__(self, d_model, max_len=5000): 286 | super().__init__() 287 | assert d_model % 2 == 0 288 | pe = torch.zeros(max_len, d_model, requires_grad=False) 289 | position = torch.arange(0, max_len).unsqueeze(1).float() 290 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * 291 | -(torch.log(torch.tensor(10000.0)).item()/d_model)) 292 | pe[:, 0::2] = torch.sin(position * div_term) 293 | pe[:, 1::2] = torch.cos(position * div_term) 294 | pe = pe.unsqueeze(0) 295 | self.register_buffer('pe', pe) 296 | 297 | def forward(self, x): 298 | length = x.size(1) 299 | return self.pe[:, :length].clone().detach() 300 | -------------------------------------------------------------------------------- /fireredasr/speech2text.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import argparse 4 | import glob 5 | import os 6 | import sys 7 | 8 | from fireredasr.models.fireredasr import FireRedAsr 9 | 10 | 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument('--asr_type', type=str, required=True, choices=["aed", "llm"]) 13 | parser.add_argument('--model_dir', type=str, required=True) 14 | 15 | # Input / Output 16 | parser.add_argument("--wav_path", type=str) 17 | parser.add_argument("--wav_paths", type=str, nargs="*") 18 | parser.add_argument("--wav_dir", type=str) 19 | parser.add_argument("--wav_scp", type=str) 20 | parser.add_argument("--output", type=str) 21 | 22 | # Decode Options 23 | parser.add_argument('--use_gpu', type=int, default=1) 24 | parser.add_argument("--batch_size", type=int, 
default=1) 25 | parser.add_argument("--beam_size", type=int, default=1) 26 | parser.add_argument("--decode_max_len", type=int, default=0) 27 | # FireRedASR-AED 28 | parser.add_argument("--nbest", type=int, default=1) 29 | parser.add_argument("--softmax_smoothing", type=float, default=1.0) 30 | parser.add_argument("--aed_length_penalty", type=float, default=0.0) 31 | parser.add_argument("--eos_penalty", type=float, default=1.0) 32 | # FireRedASR-LLM 33 | parser.add_argument("--decode_min_len", type=int, default=0) 34 | parser.add_argument("--repetition_penalty", type=float, default=1.0) 35 | parser.add_argument("--llm_length_penalty", type=float, default=0.0) 36 | parser.add_argument("--temperature", type=float, default=1.0) 37 | 38 | 39 | def main(args): 40 | wavs = get_wav_info(args) 41 | fout = open(args.output, "w") if args.output else None 42 | 43 | model = FireRedAsr.from_pretrained(args.asr_type, args.model_dir) 44 | 45 | batch_uttid = [] 46 | batch_wav_path = [] 47 | for i, wav in enumerate(wavs): 48 | uttid, wav_path = wav 49 | batch_uttid.append(uttid) 50 | batch_wav_path.append(wav_path) 51 | if len(batch_wav_path) < args.batch_size and i != len(wavs) - 1: 52 | continue 53 | 54 | results = model.transcribe( 55 | batch_uttid, 56 | batch_wav_path, 57 | { 58 | "use_gpu": args.use_gpu, 59 | "beam_size": args.beam_size, 60 | "nbest": args.nbest, 61 | "decode_max_len": args.decode_max_len, 62 | "softmax_smoothing": args.softmax_smoothing, 63 | "aed_length_penalty": args.aed_length_penalty, 64 | "eos_penalty": args.eos_penalty, 65 | "decode_min_len": args.decode_min_len, 66 | "repetition_penalty": args.repetition_penalty, 67 | "llm_length_penalty": args.llm_length_penalty, 68 | "temperature": args.temperature 69 | } 70 | ) 71 | 72 | for result in results: 73 | print(result) 74 | if fout is not None: 75 | fout.write(f"{result['uttid']}\t{result['text']}\n") 76 | 77 | batch_uttid = [] 78 | batch_wav_path = [] 79 | 80 | 81 | def get_wav_info(args): 82 | """ 83 | Returns: 84 | wavs: list of (uttid, wav_path) 85 | """ 86 | base = lambda p: os.path.basename(p).replace(".wav", "") 87 | if args.wav_path: 88 | wavs = [(base(args.wav_path), args.wav_path)] 89 | elif args.wav_paths and len(args.wav_paths) >= 1: 90 | wavs = [(base(p), p) for p in sorted(args.wav_paths)] 91 | elif args.wav_scp: 92 | wavs = [line.strip().split() for line in open(args.wav_scp)] 93 | elif args.wav_dir: 94 | wavs = glob.glob(f"{args.wav_dir}/**/*.wav", recursive=True) 95 | wavs = [(base(p), p) for p in sorted(wavs)] 96 | else: 97 | raise ValueError("Please provide valid wav info") 98 | print(f"#wavs={len(wavs)}") 99 | return wavs 100 | 101 | 102 | if __name__ == "__main__": 103 | args = parser.parse_args() 104 | print(args) 105 | main(args) 106 | -------------------------------------------------------------------------------- /fireredasr/tokenizer/aed_tokenizer.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import re 3 | 4 | import sentencepiece as spm 5 | 6 | from fireredasr.data.token_dict import TokenDict 7 | 8 | 9 | class ChineseCharEnglishSpmTokenizer: 10 | """ 11 | - One Chinese char is a token. 12 | - Split English word into SPM and one piece is a token. 
13 | - Ignore ' ' between Chinese char 14 | - Replace ' ' between English word with "▁" by spm_model 15 | - Need to put SPM piece into dict file 16 | - If not set spm_model, will use English char and <space> 17 | """ 18 | SPM_SPACE = "▁" 19 | 20 | def __init__(self, dict_path, spm_model, unk="<unk>", space="<space>"): 21 | self.dict = TokenDict(dict_path, unk=unk) 22 | self.space = space 23 | if spm_model: 24 | self.sp = spm.SentencePieceProcessor() 25 | self.sp.Load(spm_model) 26 | else: 27 | self.sp = None 28 | print("[WARN] Not set spm_model, will use English char") 29 | print("[WARN] Please check how to deal with ' '(space)") 30 | if self.space not in self.dict: 31 | print("Please add <space> to your dict, or it will be <unk>") 32 | 33 | def tokenize(self, text, replace_punc=True): 34 | #if text == "": 35 | # logging.info(f"empty text") 36 | text = text.upper() 37 | tokens = [] 38 | if replace_punc: 39 | text = re.sub("[,。?!,\.?!]", " ", text) 40 | pattern = re.compile(r'([\u3400-\u4dbf\u4e00-\u9fff])') 41 | parts = pattern.split(text.strip()) 42 | parts = [p for p in parts if len(p.strip()) > 0] 43 | for part in parts: 44 | if pattern.fullmatch(part) is not None: 45 | tokens.append(part) 46 | else: 47 | if self.sp: 48 | for piece in self.sp.EncodeAsPieces(part.strip()): 49 | tokens.append(piece) 50 | else: 51 | for char in part.strip(): 52 | tokens.append(char if char != " " else self.space) 53 | tokens_id = [] 54 | for token in tokens: 55 | tokens_id.append(self.dict.get(token, self.dict.unk)) 56 | return tokens, tokens_id 57 | 58 | def detokenize(self, inputs, join_symbol="", replace_spm_space=True): 59 | """inputs is ids or tokens, do not need self.sp""" 60 | if len(inputs) > 0 and type(inputs[0]) == int: 61 | tokens = [self.dict[id] for id in inputs] 62 | else: 63 | tokens = inputs 64 | s = f"{join_symbol}".join(tokens) 65 | if replace_spm_space: 66 | s = s.replace(self.SPM_SPACE, ' ').strip() 67 | return s 68 | -------------------------------------------------------------------------------- /fireredasr/tokenizer/llm_tokenizer.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | import torch 4 | from transformers import AutoTokenizer 5 | from transformers.trainer_pt_utils import LabelSmoother 6 | 7 | DEFAULT_SPEECH_TOKEN = "<speech>" 8 | IGNORE_TOKEN_ID = LabelSmoother.ignore_index 9 | 10 | 11 | class LlmTokenizerWrapper: 12 | @classmethod 13 | def build_llm_tokenizer(cls, llm_path, use_flash_attn=False): 14 | tokenizer = AutoTokenizer.from_pretrained(llm_path) 15 | if use_flash_attn: 16 | tokenizer.padding_side = "left" 17 | else: 18 | tokenizer.padding_side = "right" 19 | special_tokens_dict = {"additional_special_tokens": [DEFAULT_SPEECH_TOKEN]} 20 | tokenizer.add_special_tokens(special_tokens_dict) 21 | return tokenizer 22 | 23 | @classmethod 24 | def clean_text(cls, origin_text): 25 | """remove punc, remove space between Chinese and keep space between English""" 26 | # remove punc 27 | text = re.sub("[,。?!,\.!?《》()\·“”、\\/]", "", origin_text) 28 | # merge space 29 | text = re.sub("\s+", " ", text) 30 | 31 | # remove space between Chinese and keep space between English 32 | pattern = re.compile(r'([\u3400-\u4dbf\u4e00-\u9fff])') # Chinese 33 | parts = pattern.split(text.strip()) 34 | parts = [p for p in parts if len(p.strip()) > 0] 35 | text = "".join(parts) 36 | text = text.strip() 37 | 38 | text = text.lower() 39 | return text 40 | 41 | @classmethod 42 | def preprocess_texts(cls, origin_texts, tokenizer, max_len, decode=False): 43 | messages = [] 44 |
clean_texts = [] 45 | for i, origin_text in enumerate(origin_texts): 46 | text = cls.clean_text(origin_text) 47 | clean_texts.append(text) 48 | text = text if not decode else "" 49 | message = [ 50 | {"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}请转写音频为文字"}, 51 | {"role": "assistant", "content": text}, 52 | ] 53 | messages.append(message) 54 | 55 | texts = [] 56 | if not decode: 57 | TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{ '<|im_end|>'}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}" 58 | else: 59 | TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{''}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}" 60 | for i, msg in enumerate(messages): 61 | texts.append( 62 | tokenizer.apply_chat_template( 63 | msg, 64 | tokenize=True, 65 | chat_template=TEMPLATE, 66 | add_generation_prompt=False, 67 | padding="longest", 68 | max_length=max_len, 69 | truncation=True, 70 | ) 71 | ) 72 | 73 | # Padding texts 74 | max_len_texts = max([len(text) for text in texts]) 75 | if tokenizer.padding_side == "right": 76 | texts = [ 77 | text + [tokenizer.pad_token_id] * (max_len_texts - len(text)) 78 | for text in texts 79 | ] 80 | else: 81 | texts = [ 82 | [tokenizer.pad_token_id] * (max_len_texts - len(text)) + text 83 | for text in texts 84 | ] 85 | input_ids = torch.tensor(texts, dtype=torch.int) 86 | 87 | target_ids = input_ids.clone() 88 | target_ids[target_ids == tokenizer.pad_token_id] = IGNORE_TOKEN_ID 89 | 90 | # first get the indices of the tokens 91 | mask_prompt = True 92 | if mask_prompt: 93 | mask_indices = torch.where( 94 | input_ids == tokenizer.convert_tokens_to_ids("assistant") 95 | ) 96 | for i in range(mask_indices[0].size(0)): 97 | row = mask_indices[0][i] 98 | col = mask_indices[1][i] 99 | target_ids[row, : col + 2] = IGNORE_TOKEN_ID 100 | 101 | attention_mask = input_ids.ne(tokenizer.pad_token_id) 102 | 103 | target_ids = target_ids.type(torch.LongTensor) 104 | input_ids = input_ids.type(torch.LongTensor) 105 | return input_ids, attention_mask, target_ids, clean_texts 106 | -------------------------------------------------------------------------------- /fireredasr/utils/param.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch 4 | 5 | 6 | def count_model_parameters(model): 7 | if not isinstance(model, torch.nn.Module): 8 | return 0, 0 9 | name = f"{model.__class__.__name__} {model.__class__}" 10 | num = sum(p.numel() for p in model.parameters() if p.requires_grad) 11 | size = num * 4.0 / 1024.0 / 1024.0 # float32, MB 12 | logging.info(f"#param of {name} is {num} = {size:.1f} MB (float32)") 13 | return num, size 14 | -------------------------------------------------------------------------------- /fireredasr/utils/wer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import argparse 4 | import re 5 | from collections import OrderedDict 6 | 7 | 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument("--ref", type=str, required=True) 10 | parser.add_argument("--hyp", type=str, required=True) 11 | parser.add_argument("--print_sentence_wer", type=int, default=0) 12 | parser.add_argument("--do_tn", type=int, default=0, help="simple tn by cn2an") 13 | parser.add_argument("--rm_special", type=int, default=0, help="remove <\|.*?\|>") 14 | 15 | 16 | def main(args): 17 | 
uttid2refs = read_uttid2tokens(args.ref, args.do_tn, args.rm_special) 18 | uttid2hyps = read_uttid2tokens(args.hyp, args.do_tn, args.rm_special) 19 | uttid2wer_info, wer_stat, en_dig_stat = compute_uttid2wer_info( 20 | uttid2refs, uttid2hyps, args.print_sentence_wer) 21 | wer_stat.print() 22 | en_dig_stat.print() 23 | 24 | 25 | def read_uttid2tokens(filename, do_tn=False, rm_special=False): 26 | print(f">>> Read uttid to tokens: {filename}", flush=True) 27 | uttid2tokens = OrderedDict() 28 | uttid2text = read_uttid2text(filename, do_tn, rm_special) 29 | for uttid, text in uttid2text.items(): 30 | tokens = text2tokens(text) 31 | uttid2tokens[uttid] = tokens 32 | return uttid2tokens 33 | 34 | 35 | def read_uttid2text(filename, do_tn=False, rm_special=False): 36 | uttid2text = OrderedDict() 37 | with open(filename, "r", encoding="utf8") as fin: 38 | for i, line in enumerate(fin): 39 | cols = line.split() 40 | if len(cols) == 0: 41 | print("[WARN] empty line, continue", i, flush=True) 42 | continue 43 | assert cols[0] not in uttid2text, f"repeated uttid: {line}" 44 | if len(cols) == 1: 45 | uttid2text[cols[0]] = "" 46 | continue 47 | txt = " ".join(cols[1:]) 48 | if rm_special: 49 | txt = " ".join([t for t in re.split("<\|.*?\|>", txt) if t.strip() != ""]) 50 | if do_tn: 51 | import cn2an 52 | txt = cn2an.transform(txt, "an2cn") 53 | uttid2text[cols[0]] = txt 54 | return uttid2text 55 | 56 | 57 | def text2tokens(text): 58 | PUNCTUATIONS = ",。?!,\.?!"#$%&'()*+-/:;<=>@[\]^_`{|}~⦅⦆「」、 、〃〈〉《》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏﹑﹔·。\":" + "()\[\]{}/;`|=+" 59 | if text == "": 60 | return [] 61 | tokens = [] 62 | 63 | text = re.sub("", "", text) 64 | text = re.sub(r"[%s]+" % PUNCTUATIONS, " ", text) 65 | 66 | pattern = re.compile(r'([\u4e00-\u9fff])') 67 | parts = pattern.split(text.strip().upper()) 68 | parts = [p for p in parts if len(p.strip()) > 0] 69 | for part in parts: 70 | if pattern.fullmatch(part) is not None: 71 | tokens.append(part) 72 | else: 73 | for word in part.strip().split(): 74 | tokens.append(word) 75 | return tokens 76 | 77 | 78 | def compute_uttid2wer_info(refs, hyps, print_sentence_wer=False): 79 | print(f">>> Compute uttid to wer info", flush=True) 80 | 81 | uttid2wer_info = OrderedDict() 82 | wer_stat = WerStats() 83 | en_dig_stat = EnDigStats() 84 | 85 | for uttid, ref in refs.items(): 86 | if uttid not in hyps: 87 | print(f"[WARN] No hyp for {uttid}", flush=True) 88 | continue 89 | hyp = hyps[uttid] 90 | 91 | if len(hyp) - len(ref) >= 8: 92 | print(f"[BidLengthDiff]: {uttid} {len(ref)} {len(hyp)}#{' '.join(ref)}#{' '.join(hyp)}") 93 | #continue 94 | 95 | wer_info = compute_one_wer_info(ref, hyp) 96 | uttid2wer_info[uttid] = wer_info 97 | ns = count_english_ditgit(ref, hyp, wer_info) 98 | wer_stat.add(wer_info) 99 | en_dig_stat.add(*ns) 100 | if print_sentence_wer: 101 | print(f"{uttid} {wer_info}") 102 | 103 | return uttid2wer_info, wer_stat, en_dig_stat 104 | 105 | 106 | COST_SUB = 3 107 | COST_DEL = 3 108 | COST_INS = 3 109 | 110 | ALIGN_CRT = 0 111 | ALIGN_SUB = 1 112 | ALIGN_DEL = 2 113 | ALIGN_INS = 3 114 | ALIGN_END = 4 115 | 116 | 117 | def compute_one_wer_info(ref, hyp): 118 | """Impl minimum edit distance and backtrace. 
119 | Args: 120 | ref, hyp: List[str] 121 | Returns: 122 | WerInfo 123 | """ 124 | ref_len = len(ref) 125 | hyp_len = len(hyp) 126 | 127 | class _DpPoint: 128 | def __init__(self, cost, align): 129 | self.cost = cost 130 | self.align = align 131 | 132 | dp = [] 133 | for i in range(0, ref_len + 1): 134 | dp.append([]) 135 | for j in range(0, hyp_len + 1): 136 | dp[-1].append(_DpPoint(i * j, ALIGN_CRT)) 137 | 138 | # Initialize 139 | for i in range(1, hyp_len + 1): 140 | dp[0][i].cost = dp[0][i - 1].cost + COST_INS; 141 | dp[0][i].align = ALIGN_INS 142 | for i in range(1, ref_len + 1): 143 | dp[i][0].cost = dp[i - 1][0].cost + COST_DEL 144 | dp[i][0].align = ALIGN_DEL 145 | 146 | # DP 147 | for i in range(1, ref_len + 1): 148 | for j in range(1, hyp_len + 1): 149 | min_cost = 0 150 | min_align = ALIGN_CRT 151 | if hyp[j - 1] == ref[i - 1]: 152 | min_cost = dp[i - 1][j - 1].cost 153 | min_align = ALIGN_CRT 154 | else: 155 | min_cost = dp[i - 1][j - 1].cost + COST_SUB 156 | min_align = ALIGN_SUB 157 | 158 | del_cost = dp[i - 1][j].cost + COST_DEL 159 | if del_cost < min_cost: 160 | min_cost = del_cost 161 | min_align = ALIGN_DEL 162 | 163 | ins_cost = dp[i][j - 1].cost + COST_INS 164 | if ins_cost < min_cost: 165 | min_cost = ins_cost 166 | min_align = ALIGN_INS 167 | 168 | dp[i][j].cost = min_cost 169 | dp[i][j].align = min_align 170 | 171 | # Backtrace 172 | crt = sub = ins = det = 0 173 | i = ref_len 174 | j = hyp_len 175 | align = [] 176 | while i > 0 or j > 0: 177 | if dp[i][j].align == ALIGN_CRT: 178 | align.append((i, j, ALIGN_CRT)) 179 | i -= 1 180 | j -= 1 181 | crt += 1 182 | elif dp[i][j].align == ALIGN_SUB: 183 | align.append((i, j, ALIGN_SUB)) 184 | i -= 1 185 | j -= 1 186 | sub += 1 187 | elif dp[i][j].align == ALIGN_DEL: 188 | align.append((i, j, ALIGN_DEL)) 189 | i -= 1 190 | det += 1 191 | elif dp[i][j].align == ALIGN_INS: 192 | align.append((i, j, ALIGN_INS)) 193 | j -= 1 194 | ins += 1 195 | 196 | err = sub + det + ins 197 | align.reverse() 198 | wer_info = WerInfo(ref_len, err, crt, sub, det, ins, align) 199 | return wer_info 200 | 201 | 202 | 203 | class WerInfo: 204 | def __init__(self, ref, err, crt, sub, dele, ins, ali): 205 | self.r = ref 206 | self.e = err 207 | self.c = crt 208 | self.s = sub 209 | self.d = dele 210 | self.i = ins 211 | self.ali = ali 212 | r = max(self.r, 1) 213 | self.wer = 100.0 * (self.s + self.d + self.i) / r 214 | 215 | def __repr__(self): 216 | s = f"wer {self.wer:.2f} ref {self.r:2d} sub {self.s:2d} del {self.d:2d} ins {self.i:2d}" 217 | return s 218 | 219 | 220 | class WerStats: 221 | def __init__(self): 222 | self.infos = [] 223 | 224 | def add(self, wer_info): 225 | self.infos.append(wer_info) 226 | 227 | def print(self): 228 | r = sum(info.r for info in self.infos) 229 | if r <= 0: 230 | print(f"REF len is {r}, check") 231 | r = 1 232 | s = sum(info.s for info in self.infos) 233 | d = sum(info.d for info in self.infos) 234 | i = sum(info.i for info in self.infos) 235 | se = 100.0 * s / r 236 | de = 100.0 * d / r 237 | ie = 100.0 * i / r 238 | wer = 100.0 * (s + d + i) / r 239 | sen = max(len(self.infos), 1) 240 | errsen = sum(info.e > 0 for info in self.infos) 241 | ser = 100.0 * errsen / sen 242 | print("-"*80) 243 | print(f"ref{r:6d} sub{s:6d} del{d:6d} ins{i:6d}") 244 | print(f"WER{wer:6.2f} sub{se:6.2f} del{de:6.2f} ins{ie:6.2f}") 245 | print(f"SER{ser:6.2f} = {errsen} / {sen}") 246 | print("-"*80) 247 | 248 | 249 | class EnDigStats: 250 | def __init__(self): 251 | self.n_en_word = 0 252 | self.n_en_correct = 0 253 | self.n_dig_word 
= 0 254 | self.n_dig_correct = 0 255 | 256 | def add(self, n_en_word, n_en_correct, n_dig_word, n_dig_correct): 257 | self.n_en_word += n_en_word 258 | self.n_en_correct += n_en_correct 259 | self.n_dig_word += n_dig_word 260 | self.n_dig_correct += n_dig_correct 261 | 262 | def print(self): 263 | print(f"English #word={self.n_en_word}, #correct={self.n_en_correct}\n" 264 | f"Digit #word={self.n_dig_word}, #correct={self.n_dig_correct}") 265 | print("-"*80) 266 | 267 | 268 | 269 | def count_english_ditgit(ref, hyp, wer_info): 270 | patt_en = "[a-zA-Z\.\-\']+" 271 | patt_dig = "[0-9]+" 272 | patt_cjk = re.compile(r'([\u4e00-\u9fff])') 273 | n_en_word = 0 274 | n_en_correct = 0 275 | n_dig_word = 0 276 | n_dig_correct = 0 277 | ali = wer_info.ali 278 | for i, token in enumerate(ref): 279 | if re.match(patt_en, token): 280 | n_en_word += 1 281 | for y in ali: 282 | if y[0] == i+1 and y[2] == ALIGN_CRT: 283 | j = y[1] - 1 284 | n_en_correct += 1 285 | break 286 | if re.match(patt_dig, token): 287 | n_dig_word += 1 288 | for y in ali: 289 | if y[0] == i+1 and y[2] == ALIGN_CRT: 290 | j = y[1] - 1 291 | n_dig_correct += 1 292 | break 293 | if not re.match(patt_cjk, token) and not re.match(patt_en, token) \ 294 | and not re.match(patt_dig, token): 295 | print("[WiredChar]:", token) 296 | return n_en_word, n_en_correct, n_dig_word, n_dig_correct 297 | 298 | 299 | 300 | if __name__ == "__main__": 301 | args = parser.parse_args() 302 | print(args, flush=True) 303 | main(args) 304 | -------------------------------------------------------------------------------- /pretrained_models/README.md: -------------------------------------------------------------------------------- 1 | Put pretrained models here. 2 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | cn2an>=0.5.23 2 | kaldiio>=2.18.0 3 | kaldi_native_fbank>=1.15 4 | numpy>=1.26.1 5 | peft>=0.13.2 6 | sentencepiece 7 | torch>=2.0.0 8 | transformers>=4.46.3 9 | --------------------------------------------------------------------------------
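Usage note (not part of the repository files above): a minimal sketch of calling the Python API that fireredasr/speech2text.py wraps, assuming an AED model has been placed under pretrained_models/. The model directory name, utterance id, and wav path below are illustrative placeholders rather than files shipped with the repo; the options dictionary mirrors the one speech2text.py builds from its command-line flags.

from fireredasr.models.fireredasr import FireRedAsr

# Load an AED model from a local directory (directory name is an assumption).
model = FireRedAsr.from_pretrained("aed", "pretrained_models/FireRedASR-AED-L")

# Transcribe one utterance; the dict carries the same keys that
# speech2text.py forwards from its argparse options.
results = model.transcribe(
    ["utt1"],                    # utterance ids (placeholder)
    ["examples/my_audio.wav"],   # wav paths (placeholder)
    {
        "use_gpu": 1,
        "beam_size": 1,
        "nbest": 1,
        "decode_max_len": 0,
        "softmax_smoothing": 1.0,
        "aed_length_penalty": 0.0,
        "eos_penalty": 1.0,
        "decode_min_len": 0,
        "repetition_penalty": 1.0,
        "llm_length_penalty": 0.0,
        "temperature": 1.0,
    },
)
for result in results:
    print(result["uttid"], result["text"])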