├── LICENSE
├── README.md
├── assets
│   └── FireRedASR_model.png
├── examples
│   ├── fireredasr
│   ├── inference_fireredasr_aed.sh
│   ├── inference_fireredasr_llm.sh
│   ├── pretrained_models
│   └── wav
│       ├── BAC009S0764W0121.wav
│       ├── IT0011W0001.wav
│       ├── TEST_MEETING_T0000000001_S00000.wav
│       ├── TEST_NET_Y0000000000_-KTKHdZ2fb8_S00000.wav
│       ├── text
│       └── wav.scp
├── fireredasr
│   ├── data
│   │   ├── asr_feat.py
│   │   └── token_dict.py
│   ├── models
│   │   ├── fireredasr.py
│   │   ├── fireredasr_aed.py
│   │   ├── fireredasr_llm.py
│   │   └── module
│   │       ├── adapter.py
│   │       ├── conformer_encoder.py
│   │       └── transformer_decoder.py
│   ├── speech2text.py
│   ├── tokenizer
│   │   ├── aed_tokenizer.py
│   │   └── llm_tokenizer.py
│   └── utils
│       ├── param.py
│       └── wer.py
├── pretrained_models
│   └── README.md
└── requirements.txt
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # FireRedASR: Open-Source Industrial-Grade Automatic Speech Recognition Models
3 |
4 |
5 |
6 |
7 |
8 | [[Paper]](https://arxiv.org/pdf/2501.14350)
9 | [[Model]](https://huggingface.co/fireredteam)
10 | [[Blog]](https://fireredteam.github.io/demos/firered_asr/)
11 |
12 | FireRedASR is a family of open-source, industrial-grade automatic speech recognition (ASR) models supporting Mandarin, Chinese dialects, and English. It achieves new state-of-the-art (SOTA) results on public Mandarin ASR benchmarks and also offers outstanding singing-lyrics recognition capability.
13 |
14 |
15 | ## 🔥 News
16 | - [2025/02/17] We release [FireRedASR-LLM-L](https://huggingface.co/fireredteam/FireRedASR-LLM-L/tree/main) model weights.
17 | - [2025/01/24] We release [technical report](https://arxiv.org/pdf/2501.14350), [blog](https://fireredteam.github.io/demos/firered_asr/), and [FireRedASR-AED-L](https://huggingface.co/fireredteam/FireRedASR-AED-L/tree/main) model weights.
18 |
19 |
20 | ## Method
21 |
22 | FireRedASR is designed to meet diverse requirements for superior performance and optimal efficiency across a range of applications. It comprises two variants:
23 | - FireRedASR-LLM: Designed to achieve state-of-the-art (SOTA) performance and to enable seamless end-to-end speech interaction. It adopts an Encoder-Adapter-LLM framework leveraging large language model (LLM) capabilities.
24 | - FireRedASR-AED: Designed to balance high performance and computational efficiency and to serve as an effective speech representation module in LLM-based speech models. It utilizes an Attention-based Encoder-Decoder (AED) architecture.
25 |
26 | ![FireRedASR model architecture](assets/FireRedASR_model.png)
27 |
28 |
29 | ## Evaluation
30 | Results are reported in Character Error Rate (CER%) for Chinese and Word Error Rate (WER%) for English.
31 |
32 | ### Evaluation on Public Mandarin ASR Benchmarks
33 | | Model | #Params | aishell1 | aishell2 | ws\_net | ws\_meeting | Average-4 |
34 | |:----------------:|:-------:|:--------:|:--------:|:--------:|:-----------:|:---------:|
35 | | FireRedASR-LLM | 8.3B | 0.76 | 2.15 | 4.60 | 4.67 | 3.05 |
36 | | FireRedASR-AED | 1.1B | 0.55 | 2.52 | 4.88 | 4.76 | 3.18 |
37 | | Seed-ASR | 12B+ | 0.68 | 2.27 | 4.66 | 5.69 | 3.33 |
38 | | Qwen-Audio | 8.4B | 1.30 | 3.10 | 9.50 | 10.87 | 6.19 |
39 | | SenseVoice-L | 1.6B | 2.09 | 3.04 | 6.01 | 6.73 | 4.47 |
40 | | Whisper-Large-v3 | 1.6B | 5.14 | 4.96 | 10.48 | 18.87 | 9.86 |
41 | | Paraformer-Large | 0.2B | 1.68 | 2.85 | 6.74 | 6.97 | 4.56 |
42 |
43 | `ws` refers to WenetSpeech: `ws_net` is its test_net set and `ws_meeting` its test_meeting set.
44 |
45 | ### Evaluation on Public Chinese Dialect and English ASR Benchmarks
46 | | Model | KeSpeech | LibriSpeech test-clean | LibriSpeech test-other |
47 | | :------------:| :------: | :--------------------: | :----------------------:|
48 | |FireRedASR-LLM | 3.56 | 1.73 | 3.67 |
49 | |FireRedASR-AED | 4.48 | 1.93 | 4.44 |
50 | |Previous SOTA Results | 6.70 | 1.82 | 3.50 |
51 |
52 |
53 | ## Usage
54 | Download model files from [huggingface](https://huggingface.co/fireredteam) and place them in the folder `pretrained_models`.
55 |
56 | If you want to use `FireRedASR-LLM-L`, you also need to download [Qwen2-7B-Instruct](https://huggingface.co/Qwen/Qwen2-7B-Instruct) and place it in the folder `pretrained_models`. Then go to the folder `FireRedASR-LLM-L` and run `$ ln -s ../Qwen2-7B-Instruct`.
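The symlink makes `pretrained_models/FireRedASR-LLM-L/Qwen2-7B-Instruct` resolve to the downloaded LLM, which is the `llm_dir` that `fireredasr/models/fireredasr.py` looks for when loading the LLM variant.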
57 |
58 |
59 | ### Setup
60 | Create a Python environment and install the dependencies:
61 | ```bash
62 | $ git clone https://github.com/FireRedTeam/FireRedASR.git && cd FireRedASR
63 | $ conda create --name fireredasr python=3.10 && conda activate fireredasr
64 | $ pip install -r requirements.txt
65 | ```
66 |
67 | Set up the Linux PATH and PYTHONPATH (run from the repository root):
68 | ```bash
69 | $ export PATH=$PWD/fireredasr/:$PWD/fireredasr/utils/:$PATH
70 | $ export PYTHONPATH=$PWD/:$PYTHONPATH
71 | ```
72 |
73 | Convert audio to 16 kHz, 16-bit mono PCM WAV format:
74 | ```bash
75 | $ ffmpeg -i input_audio -ar 16000 -ac 1 -acodec pcm_s16le -f wav output.wav
76 | ```
77 |
78 | ### Quick Start
79 | ```bash
80 | $ cd examples
81 | $ bash inference_fireredasr_aed.sh
82 | $ bash inference_fireredasr_llm.sh
83 | ```
84 |
85 | ### Command-line Usage
86 | ```bash
87 | $ speech2text.py --help
88 | $ speech2text.py --wav_path examples/wav/BAC009S0764W0121.wav --asr_type "aed" --model_dir pretrained_models/FireRedASR-AED-L
89 | $ speech2text.py --wav_path examples/wav/BAC009S0764W0121.wav --asr_type "llm" --model_dir pretrained_models/FireRedASR-LLM-L
90 | ```
91 |
92 | ### Python Usage
93 | ```python
94 | from fireredasr.models.fireredasr import FireRedAsr
95 |
96 | batch_uttid = ["BAC009S0764W0121"]
97 | batch_wav_path = ["examples/wav/BAC009S0764W0121.wav"]
98 |
99 | # FireRedASR-AED
100 | model = FireRedAsr.from_pretrained("aed", "pretrained_models/FireRedASR-AED-L")
101 | results = model.transcribe(
102 | batch_uttid,
103 | batch_wav_path,
104 | {
105 | "use_gpu": 1,
106 | "beam_size": 3,
107 | "nbest": 1,
108 | "decode_max_len": 0,
109 | "softmax_smoothing": 1.25,
110 | "aed_length_penalty": 0.6,
111 | "eos_penalty": 1.0
112 | }
113 | )
114 | print(results)
115 |
116 |
117 | # FireRedASR-LLM
118 | model = FireRedAsr.from_pretrained("llm", "pretrained_models/FireRedASR-LLM-L")
119 | results = model.transcribe(
120 | batch_uttid,
121 | batch_wav_path,
122 | {
123 | "use_gpu": 1,
124 | "beam_size": 3,
125 | "decode_max_len": 0,
126 | "decode_min_len": 0,
127 | "repetition_penalty": 3.0,
128 | "llm_length_penalty": 1.0,
129 | "temperature": 1.0
130 | }
131 | )
132 | print(results)
133 | ```
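Each element of `results` is a dict with `uttid`, `text`, `wav`, and `rtf` (real-time factor) keys, as returned by `FireRedAsr.transcribe` in `fireredasr/models/fireredasr.py`.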
134 |
135 | ## Usage Tips
136 | ### Batch Beam Search
137 | - When performing batch beam search with FireRedASR-LLM, make sure the utterances in a batch have similar input lengths. If the lengths differ significantly, shorter utterances may suffer from repetition issues. To avoid this, either sort your dataset by length (as sketched below) or set `batch_size` to 1.
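
A minimal length-sorting sketch (the helper names are illustrative; it reuses `kaldiio`, which `fireredasr/data/asr_feat.py` already depends on, and an arbitrary batch size of 2):

```python
import kaldiio

def wav_duration(wav_path):
    # Same duration computation as ASRFeatExtractor in fireredasr/data/asr_feat.py
    sample_rate, wav_np = kaldiio.load_mat(wav_path)
    return wav_np.shape[0] / sample_rate

def make_length_sorted_batches(wav_paths, batch_size=2):
    """Group wav paths so each batch contains utterances of similar duration."""
    sorted_paths = sorted(wav_paths, key=wav_duration)
    return [sorted_paths[i:i + batch_size]
            for i in range(0, len(sorted_paths), batch_size)]
```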
138 |
139 | ### Input Length Limitations
140 | - FireRedASR-AED supports audio input up to 60s. Input longer than 60s may cause hallucination issues, and input exceeding 200s will trigger positional encoding errors.
141 | - FireRedASR-LLM supports audio input up to 30s. The behavior for longer input is currently unknown.
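
If your recordings exceed these limits, a simple workaround (not provided by this repo) is to split them into chunks below the limit, transcribe each chunk, and concatenate the texts; words cut at chunk boundaries may be misrecognized. A rough sketch using only the standard library, assuming the audio is already 16 kHz 16-bit mono PCM WAV as described in Setup:

```python
import wave

def split_wav(in_path, out_prefix, max_seconds=30):
    """Split a PCM WAV file into chunks of at most max_seconds each."""
    with wave.open(in_path, "rb") as wf:
        frames_per_chunk = int(max_seconds * wf.getframerate())
        params = wf.getparams()
        idx = 0
        while True:
            frames = wf.readframes(frames_per_chunk)
            if not frames:
                break
            with wave.open(f"{out_prefix}_{idx:03d}.wav", "wb") as out:
                out.setparams(params)      # nframes is patched on close
                out.writeframes(frames)
            idx += 1
```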
142 |
143 |
144 | ## Acknowledgements
145 | Thanks to the following open-source works:
146 | - [Qwen2-7B-Instruct](https://huggingface.co/Qwen/Qwen2-7B-Instruct)
147 | - [icefall/ASR_LLM](https://github.com/k2-fsa/icefall/tree/master/egs/speech_llm/ASR_LLM)
148 | - [WeNet](https://github.com/wenet-e2e/wenet)
149 | - [Speech-Transformer](https://github.com/kaituoxu/Speech-Transformer)
150 |
151 |
152 | ## Citation
153 | ```bibtex
154 | @article{xu2025fireredasr,
155 | title={FireRedASR: Open-Source Industrial-Grade Mandarin Speech Recognition Models from Encoder-Decoder to LLM Integration},
156 | author={Xu, Kai-Tuo and Xie, Feng-Long and Tang, Xu and Hu, Yao},
157 | journal={arXiv preprint arXiv:2501.14350},
158 | year={2025}
159 | }
160 | ```
161 |
--------------------------------------------------------------------------------
/assets/FireRedASR_model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FireRedTeam/FireRedASR/1eadb81b66eca948cd492bc0aeedd786333c049d/assets/FireRedASR_model.png
--------------------------------------------------------------------------------
/examples/fireredasr:
--------------------------------------------------------------------------------
1 | ../fireredasr
--------------------------------------------------------------------------------
/examples/inference_fireredasr_aed.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | export PATH=$PWD/fireredasr/:$PWD/fireredasr/utils/:$PATH
4 | export PYTHONPATH=$PWD/:$PYTHONPATH
5 |
6 | # model_dir includes model.pth.tar, cmvn.ark, dict.txt
7 | model_dir=$PWD/pretrained_models/FireRedASR-AED-L
8 |
9 | # Several input formats are supported; the last `wavs=` assignment below takes effect
10 | wavs="--wav_path wav/BAC009S0764W0121.wav"
11 | wavs="--wav_paths wav/BAC009S0764W0121.wav wav/IT0011W0001.wav wav/TEST_NET_Y0000000000_-KTKHdZ2fb8_S00000.wav wav/TEST_MEETING_T0000000001_S00000.wav"
12 | wavs="--wav_dir wav/"
13 | wavs="--wav_scp wav/wav.scp"
14 |
15 | out="out/aed-l-asr.txt"
16 |
17 | decode_args="
18 | --batch_size 2 --beam_size 3 --nbest 1
19 | --decode_max_len 0 --softmax_smoothing 1.25 --aed_length_penalty 0.6
20 | --eos_penalty 1.0
21 | "
22 |
23 | mkdir -p $(dirname $out)
24 | set -x
25 |
26 |
27 | CUDA_VISIBLE_DEVICES=0 \
28 | speech2text.py --asr_type "aed" --model_dir $model_dir $decode_args $wavs --output $out
29 |
30 |
31 | ref="wav/text"
32 | wer.py --print_sentence_wer 1 --do_tn 0 --rm_special 0 --ref $ref --hyp $out > $out.wer 2>&1
33 | tail -n8 $out.wer
34 |
--------------------------------------------------------------------------------
/examples/inference_fireredasr_llm.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | export PATH=$PWD/fireredasr/:$PWD/fireredasr/utils/:$PATH
4 | export PYTHONPATH=$PWD/:$PYTHONPATH
5 |
6 | # model_dir includes model.pth.tar, asr_encoder.pth.tar, cmvn.ark, Qwen2-7B-Instruct
7 | model_dir=$PWD/pretrained_models/FireRedASR-LLM-L
8 |
9 | # Several input formats are supported; the last `wavs=` assignment below takes effect
10 | wavs="--wav_path wav/BAC009S0764W0121.wav"
11 | wavs="--wav_paths wav/BAC009S0764W0121.wav wav/IT0011W0001.wav wav/TEST_NET_Y0000000000_-KTKHdZ2fb8_S00000.wav wav/TEST_MEETING_T0000000001_S00000.wav"
12 | wavs="--wav_dir wav/"
13 | wavs="--wav_scp wav/wav.scp"
14 |
15 | out="out/llm-l-asr.txt"
16 |
17 | decode_args="
18 | --batch_size 1 --beam_size 3 --decode_max_len 0 --decode_min_len 0
19 | --repetition_penalty 3.0 --llm_length_penalty 1.0 --temperature 1.0
20 | "
21 |
22 | mkdir -p $(dirname $out)
23 | set -x
24 |
25 |
26 | CUDA_VISIBLE_DEVICES=0 \
27 | speech2text.py --asr_type "llm" --model_dir $model_dir $decode_args $wavs --output $out
28 |
29 |
30 | ref="wav/text"
31 | wer.py --print_sentence_wer 1 --do_tn 0 --rm_special 1 --ref $ref --hyp $out > $out.wer 2>&1
32 | tail -n8 $out.wer
33 |
--------------------------------------------------------------------------------
/examples/pretrained_models:
--------------------------------------------------------------------------------
1 | ../pretrained_models
--------------------------------------------------------------------------------
/examples/wav/BAC009S0764W0121.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FireRedTeam/FireRedASR/1eadb81b66eca948cd492bc0aeedd786333c049d/examples/wav/BAC009S0764W0121.wav
--------------------------------------------------------------------------------
/examples/wav/IT0011W0001.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FireRedTeam/FireRedASR/1eadb81b66eca948cd492bc0aeedd786333c049d/examples/wav/IT0011W0001.wav
--------------------------------------------------------------------------------
/examples/wav/TEST_MEETING_T0000000001_S00000.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FireRedTeam/FireRedASR/1eadb81b66eca948cd492bc0aeedd786333c049d/examples/wav/TEST_MEETING_T0000000001_S00000.wav
--------------------------------------------------------------------------------
/examples/wav/TEST_NET_Y0000000000_-KTKHdZ2fb8_S00000.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FireRedTeam/FireRedASR/1eadb81b66eca948cd492bc0aeedd786333c049d/examples/wav/TEST_NET_Y0000000000_-KTKHdZ2fb8_S00000.wav
--------------------------------------------------------------------------------
/examples/wav/text:
--------------------------------------------------------------------------------
1 | BAC009S0764W0121 甚至 出现 交易 几乎 停滞 的 情况
2 | IT0011W0001 换一首歌
3 | TEST_NET_Y0000000000_-KTKHdZ2fb8_S00000 我有的时候说不清楚你们知道吗
4 | TEST_MEETING_T0000000001_S00000 好首先说一下刚才这个经理说完的这个销售问题咱再说一下咱们的商场问题首先咱们商场上半年业这个先各部门儿汇报一下就是业绩
5 |
--------------------------------------------------------------------------------
/examples/wav/wav.scp:
--------------------------------------------------------------------------------
1 | BAC009S0764W0121 wav/BAC009S0764W0121.wav
2 | IT0011W0001 wav/IT0011W0001.wav
3 | TEST_NET_Y0000000000_-KTKHdZ2fb8_S00000 wav/TEST_NET_Y0000000000_-KTKHdZ2fb8_S00000.wav
4 | TEST_MEETING_T0000000001_S00000 wav/TEST_MEETING_T0000000001_S00000.wav
5 |
--------------------------------------------------------------------------------
/fireredasr/data/asr_feat.py:
--------------------------------------------------------------------------------
1 | import math
2 | import os
3 |
4 | import kaldiio
5 | import kaldi_native_fbank as knf
6 | import numpy as np
7 | import torch
8 |
9 |
10 | class ASRFeatExtractor:
11 | def __init__(self, kaldi_cmvn_file):
12 | self.cmvn = CMVN(kaldi_cmvn_file) if kaldi_cmvn_file != "" else None
13 | self.fbank = KaldifeatFbank(num_mel_bins=80, frame_length=25,
14 | frame_shift=10, dither=0.0)
15 |
16 | def __call__(self, wav_paths):
17 | feats = []
18 | durs = []
19 | for wav_path in wav_paths:
20 | sample_rate, wav_np = kaldiio.load_mat(wav_path)
21 | dur = wav_np.shape[0] / sample_rate
22 | fbank = self.fbank((sample_rate, wav_np))
23 | if self.cmvn is not None:
24 | fbank = self.cmvn(fbank)
25 | fbank = torch.from_numpy(fbank).float()
26 | feats.append(fbank)
27 | durs.append(dur)
28 | lengths = torch.tensor([feat.size(0) for feat in feats]).long()
29 | feats_pad = self.pad_feat(feats, 0.0)
30 | return feats_pad, lengths, durs
31 |
32 | def pad_feat(self, xs, pad_value):
33 | # type: (List[Tensor], int) -> Tensor
34 | n_batch = len(xs)
35 | max_len = max([xs[i].size(0) for i in range(n_batch)])
36 | pad = torch.ones(n_batch, max_len, *xs[0].size()[1:]).to(xs[0].device).to(xs[0].dtype).fill_(pad_value)
37 | for i in range(n_batch):
38 | pad[i, :xs[i].size(0)] = xs[i]
39 | return pad
40 |
41 |
42 |
43 |
44 | class CMVN:
45 | def __init__(self, kaldi_cmvn_file):
46 |         self.dim, self.means, self.inverse_std_variances = \
47 | self.read_kaldi_cmvn(kaldi_cmvn_file)
48 |
49 | def __call__(self, x, is_train=False):
50 | assert x.shape[-1] == self.dim, "CMVN dim mismatch"
51 | out = x - self.means
52 |         out = out * self.inverse_std_variances
53 | return out
54 |
55 | def read_kaldi_cmvn(self, kaldi_cmvn_file):
56 | assert os.path.exists(kaldi_cmvn_file)
57 | stats = kaldiio.load_mat(kaldi_cmvn_file)
58 | assert stats.shape[0] == 2
59 | dim = stats.shape[-1] - 1
60 | count = stats[0, dim]
61 | assert count >= 1
62 | floor = 1e-20
63 | means = []
64 |         inverse_std_variances = []
65 | for d in range(dim):
66 | mean = stats[0, d] / count
67 | means.append(mean.item())
68 |             variance = (stats[1, d] / count) - mean*mean
69 |             if variance < floor:
70 |                 variance = floor
71 |             istd = 1.0 / math.sqrt(variance)
72 |             inverse_std_variances.append(istd)
73 |         return dim, np.array(means), np.array(inverse_std_variances)
74 |
75 |
76 |
77 | class KaldifeatFbank:
78 | def __init__(self, num_mel_bins=80, frame_length=25, frame_shift=10,
79 | dither=1.0):
80 | self.dither = dither
81 | opts = knf.FbankOptions()
82 | opts.frame_opts.dither = dither
83 | opts.mel_opts.num_bins = num_mel_bins
84 | opts.frame_opts.snip_edges = True
85 | opts.mel_opts.debug_mel = False
86 | self.opts = opts
87 |
88 | def __call__(self, wav, is_train=False):
89 | if type(wav) is str:
90 | sample_rate, wav_np = kaldiio.load_mat(wav)
91 | elif type(wav) in [tuple, list] and len(wav) == 2:
92 | sample_rate, wav_np = wav
93 | assert len(wav_np.shape) == 1
94 |
95 | dither = self.dither if is_train else 0.0
96 | self.opts.frame_opts.dither = dither
97 | fbank = knf.OnlineFbank(self.opts)
98 |
99 | fbank.accept_waveform(sample_rate, wav_np.tolist())
100 | feat = []
101 | for i in range(fbank.num_frames_ready):
102 | feat.append(fbank.get_frame(i))
103 | if len(feat) == 0:
104 | print("Check data, len(feat) == 0", wav, flush=True)
105 | return np.zeros((0, self.opts.mel_opts.num_bins))
106 | feat = np.vstack(feat)
107 | return feat
108 |
--------------------------------------------------------------------------------
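
A brief usage sketch of `ASRFeatExtractor` above; `cmvn.ark` ships with the pretrained model directories (see the comments in the example inference scripts):

```python
from fireredasr.data.asr_feat import ASRFeatExtractor

extractor = ASRFeatExtractor("pretrained_models/FireRedASR-AED-L/cmvn.ark")
feats, lengths, durs = extractor(["examples/wav/BAC009S0764W0121.wav"])
# feats: (batch, max_frames, 80) float tensor, zero-padded per utterance
# lengths: per-utterance frame counts; durs: durations in seconds
print(feats.shape, lengths, durs)
```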
/fireredasr/data/token_dict.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 |
4 | class TokenDict:
5 | def __init__(self, dict_path, unk=""):
6 | assert dict_path != ""
7 | self.id2word, self.word2id = self.read_dict(dict_path)
8 | self.unk = unk
9 | assert unk == "" or unk in self.word2id
10 | self.unkid = self.word2id[unk] if unk else -1
11 |
12 | def get(self, key, default):
13 | if type(default) == str:
14 | default = self.word2id[default]
15 | return self.word2id.get(key, default)
16 |
17 | def __getitem__(self, key):
18 | if type(key) == str:
19 | if self.unk:
20 | return self.word2id.get(key, self.word2id[self.unk])
21 | else:
22 | return self.word2id[key]
23 | elif type(key) == int:
24 | return self.id2word[key]
25 | else:
26 | raise TypeError("Key should be str or int")
27 |
28 | def __len__(self):
29 | return len(self.id2word)
30 |
31 | def __contains__(self, query):
32 | if type(query) == str:
33 | return query in self.word2id
34 | elif type(query) == int:
35 | return query in self.id2word
36 | else:
37 | raise TypeError("query should be str or int")
38 |
39 | def read_dict(self, dict_path):
40 | id2word, word2id = [], {}
41 | with open(dict_path, encoding='utf8') as f:
42 | for i, line in enumerate(f):
43 | tokens = line.strip().split()
44 | if len(tokens) >= 2:
45 | word, index = tokens[0], int(tokens[1])
46 | elif len(tokens) == 1:
47 | word, index = tokens[0], i
48 | else: # empty line or space
49 |                     logging.info(f"Found empty or whitespace-only line '{line.strip()}' in {dict_path}:L{i}, setting token to ' '")
50 | word, index = " ", i
51 | assert len(id2word) == index
52 | assert len(word2id) == index
53 | if word == "":
54 |                     logging.info(f"NOTE: Found empty token in {dict_path}:L{i}, converting it to ' '")
55 | word = " "
56 | word2id[word] = index
57 | id2word.append(word)
58 | assert len(id2word) == len(word2id)
59 | return id2word, word2id
60 |
--------------------------------------------------------------------------------
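
A quick sketch of how `TokenDict` above behaves; the dict path is the AED model's `dict.txt` mentioned in the example scripts, and the `<unk>` token name is only an assumption:

```python
from fireredasr.data.token_dict import TokenDict

d = TokenDict("pretrained_models/FireRedASR-AED-L/dict.txt", unk="<unk>")
idx = d["某"]            # str -> id, falling back to the unk id for unseen tokens
tok = d[idx]             # int -> str
print(len(d), tok in d)  # __len__; __contains__ accepts either str or int
```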
/fireredasr/models/fireredasr.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 |
4 | import torch
5 |
6 | from fireredasr.data.asr_feat import ASRFeatExtractor
7 | from fireredasr.models.fireredasr_aed import FireRedAsrAed
8 | from fireredasr.models.fireredasr_llm import FireRedAsrLlm
9 | from fireredasr.tokenizer.aed_tokenizer import ChineseCharEnglishSpmTokenizer
10 | from fireredasr.tokenizer.llm_tokenizer import LlmTokenizerWrapper
11 |
12 |
13 | class FireRedAsr:
14 | @classmethod
15 | def from_pretrained(cls, asr_type, model_dir):
16 | assert asr_type in ["aed", "llm"]
17 |
18 | cmvn_path = os.path.join(model_dir, "cmvn.ark")
19 | feat_extractor = ASRFeatExtractor(cmvn_path)
20 |
21 | if asr_type == "aed":
22 | model_path = os.path.join(model_dir, "model.pth.tar")
23 |             dict_path = os.path.join(model_dir, "dict.txt")
24 | spm_model = os.path.join(model_dir, "train_bpe1000.model")
25 | model = load_fireredasr_aed_model(model_path)
26 | tokenizer = ChineseCharEnglishSpmTokenizer(dict_path, spm_model)
27 | elif asr_type == "llm":
28 | model_path = os.path.join(model_dir, "model.pth.tar")
29 | encoder_path = os.path.join(model_dir, "asr_encoder.pth.tar")
30 | llm_dir = os.path.join(model_dir, "Qwen2-7B-Instruct")
31 | model, tokenizer = load_firered_llm_model_and_tokenizer(
32 | model_path, encoder_path, llm_dir)
33 | model.eval()
34 | return cls(asr_type, feat_extractor, model, tokenizer)
35 |
36 | def __init__(self, asr_type, feat_extractor, model, tokenizer):
37 | self.asr_type = asr_type
38 | self.feat_extractor = feat_extractor
39 | self.model = model
40 | self.tokenizer = tokenizer
41 |
42 | @torch.no_grad()
43 | def transcribe(self, batch_uttid, batch_wav_path, args={}):
44 | feats, lengths, durs = self.feat_extractor(batch_wav_path)
45 | total_dur = sum(durs)
46 | if args.get("use_gpu", False):
47 | feats, lengths = feats.cuda(), lengths.cuda()
48 | self.model.cuda()
49 | else:
50 | self.model.cpu()
51 |
52 | if self.asr_type == "aed":
53 | start_time = time.time()
54 |
55 | hyps = self.model.transcribe(
56 | feats, lengths,
57 | args.get("beam_size", 1),
58 | args.get("nbest", 1),
59 | args.get("decode_max_len", 0),
60 | args.get("softmax_smoothing", 1.0),
61 | args.get("aed_length_penalty", 0.0),
62 | args.get("eos_penalty", 1.0)
63 | )
64 |
65 | elapsed = time.time() - start_time
66 |             rtf = elapsed / total_dur if total_dur > 0 else 0
67 |
68 | results = []
69 | for uttid, wav, hyp in zip(batch_uttid, batch_wav_path, hyps):
70 | hyp = hyp[0] # only return 1-best
71 | hyp_ids = [int(id) for id in hyp["yseq"].cpu()]
72 | text = self.tokenizer.detokenize(hyp_ids)
73 | results.append({"uttid": uttid, "text": text, "wav": wav,
74 | "rtf": f"{rtf:.4f}"})
75 | return results
76 |
77 | elif self.asr_type == "llm":
78 | input_ids, attention_mask, _, _ = \
79 | LlmTokenizerWrapper.preprocess_texts(
80 | origin_texts=[""]*feats.size(0), tokenizer=self.tokenizer,
81 | max_len=128, decode=True)
82 | if args.get("use_gpu", False):
83 | input_ids = input_ids.cuda()
84 | attention_mask = attention_mask.cuda()
85 | start_time = time.time()
86 |
87 | generated_ids = self.model.transcribe(
88 | feats, lengths, input_ids, attention_mask,
89 | args.get("beam_size", 1),
90 | args.get("decode_max_len", 0),
91 | args.get("decode_min_len", 0),
92 | args.get("repetition_penalty", 1.0),
93 | args.get("llm_length_penalty", 0.0),
94 | args.get("temperature", 1.0)
95 | )
96 |
97 | elapsed = time.time() - start_time
98 |             rtf = elapsed / total_dur if total_dur > 0 else 0
99 | texts = self.tokenizer.batch_decode(generated_ids,
100 | skip_special_tokens=True)
101 | results = []
102 | for uttid, wav, text in zip(batch_uttid, batch_wav_path, texts):
103 | results.append({"uttid": uttid, "text": text, "wav": wav,
104 | "rtf": f"{rtf:.4f}"})
105 | return results
106 |
107 |
108 |
109 | def load_fireredasr_aed_model(model_path):
110 | package = torch.load(model_path, map_location=lambda storage, loc: storage)
111 | print("model args:", package["args"])
112 | model = FireRedAsrAed.from_args(package["args"])
113 | model.load_state_dict(package["model_state_dict"], strict=True)
114 | return model
115 |
116 |
117 | def load_firered_llm_model_and_tokenizer(model_path, encoder_path, llm_dir):
118 | package = torch.load(model_path, map_location=lambda storage, loc: storage)
119 | package["args"].encoder_path = encoder_path
120 | package["args"].llm_dir = llm_dir
121 | print("model args:", package["args"])
122 | model = FireRedAsrLlm.from_args(package["args"])
123 | model.load_state_dict(package["model_state_dict"], strict=False)
124 | tokenizer = LlmTokenizerWrapper.build_llm_tokenizer(llm_dir)
125 | return model, tokenizer
126 |
--------------------------------------------------------------------------------
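
A detail worth noting from `FireRedAsr.transcribe` above: every decoding option is read with `args.get(...)`, so omitted keys fall back to the defaults visible in the code, and omitting `use_gpu` (or setting it to 0) keeps the model on CPU. A minimal CPU-only sketch reusing the README example paths:

```python
from fireredasr.models.fireredasr import FireRedAsr

model = FireRedAsr.from_pretrained("aed", "pretrained_models/FireRedASR-AED-L")
results = model.transcribe(
    ["BAC009S0764W0121"],
    ["examples/wav/BAC009S0764W0121.wav"],
    {"use_gpu": 0, "beam_size": 1},   # unspecified options use transcribe()'s defaults
)
print(results[0]["text"])
```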
/fireredasr/models/fireredasr_aed.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from fireredasr.models.module.conformer_encoder import ConformerEncoder
4 | from fireredasr.models.module.transformer_decoder import TransformerDecoder
5 |
6 |
7 | class FireRedAsrAed(torch.nn.Module):
8 | @classmethod
9 | def from_args(cls, args):
10 | return cls(args)
11 |
12 | def __init__(self, args):
13 | super().__init__()
14 | self.sos_id = args.sos_id
15 | self.eos_id = args.eos_id
16 |
17 | self.encoder = ConformerEncoder(
18 | args.idim, args.n_layers_enc, args.n_head, args.d_model,
19 | args.residual_dropout, args.dropout_rate,
20 | args.kernel_size, args.pe_maxlen)
21 |
22 | self.decoder = TransformerDecoder(
23 | args.sos_id, args.eos_id, args.pad_id, args.odim,
24 | args.n_layers_dec, args.n_head, args.d_model,
25 | args.residual_dropout, args.pe_maxlen)
26 |
27 | def transcribe(self, padded_input, input_lengths,
28 | beam_size=1, nbest=1, decode_max_len=0,
29 | softmax_smoothing=1.0, length_penalty=0.0, eos_penalty=1.0):
30 | enc_outputs, _, enc_mask = self.encoder(padded_input, input_lengths)
31 | nbest_hyps = self.decoder.batch_beam_search(
32 | enc_outputs, enc_mask,
33 | beam_size, nbest, decode_max_len,
34 | softmax_smoothing, length_penalty, eos_penalty)
35 | return nbest_hyps
36 |
--------------------------------------------------------------------------------
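
The return value of `transcribe` above can be inferred from how `FireRedAsr.transcribe` consumes it in `fireredasr/models/fireredasr.py`: `batch_beam_search` yields one n-best list per utterance, each hypothesis being a dict whose `"yseq"` tensor holds the decoded token ids that `ChineseCharEnglishSpmTokenizer.detokenize` turns into text.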
/fireredasr/models/fireredasr_llm.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import random
4 | import re
5 |
6 | import torch
7 | import torch.nn as nn
8 | from transformers import AutoModelForCausalLM
9 |
10 | from fireredasr.models.fireredasr_aed import FireRedAsrAed
11 | from fireredasr.models.module.adapter import Adapter
12 | from fireredasr.tokenizer.llm_tokenizer import DEFAULT_SPEECH_TOKEN, IGNORE_TOKEN_ID
13 | from fireredasr.tokenizer.llm_tokenizer import LlmTokenizerWrapper
14 | from fireredasr.utils.param import count_model_parameters
15 |
16 |
17 | class FireRedAsrLlm(nn.Module):
18 | @classmethod
19 | def load_encoder(cls, model_path):
20 | assert os.path.exists(model_path)
21 | package = torch.load(model_path, map_location=lambda storage, loc: storage)
22 | model = FireRedAsrAed.from_args(package["args"])
23 | if "model_state_dict" in package:
24 | model.load_state_dict(package["model_state_dict"], strict=False)
25 | encoder = model.encoder
26 | encoder_dim = encoder.odim
27 | return encoder, encoder_dim
28 |
29 | @classmethod
30 | def from_args(cls, args):
31 | logging.info(args)
32 | logging.info("Build FireRedAsrLlm")
33 | # Build Speech Encoder
34 | encoder, encoder_dim = cls.load_encoder(args.encoder_path)
35 | count_model_parameters(encoder)
36 | if args.freeze_encoder:
37 |             logging.info("Freeze encoder")
38 | for name, param in encoder.named_parameters():
39 | param.requires_grad = False
40 | encoder.eval()
41 |
42 | if args.use_flash_attn:
43 | attn_implementation = "flash_attention_2"
44 | if args.use_fp16:
45 | torch_dtype = torch.float16
46 | else:
47 | torch_dtype = torch.float32
48 | else:
49 | attn_implementation = "eager"
50 | if args.use_fp16:
51 | torch_dtype = torch.float16
52 | else:
53 | torch_dtype = torch.float32
54 |
55 | # Build LLM
56 | llm = AutoModelForCausalLM.from_pretrained(
57 | args.llm_dir,
58 | attn_implementation=attn_implementation,
59 | torch_dtype=torch_dtype,
60 | )
61 | count_model_parameters(llm)
62 |
63 | # LLM Freeze or LoRA
64 | llm_dim = llm.config.hidden_size
65 | if args.freeze_llm:
66 |             logging.info("Freeze LLM")
67 | for name, param in llm.named_parameters():
68 | param.requires_grad = False
69 | llm.eval()
70 | else:
71 | if args.use_lora:
72 | from peft import LoraConfig, get_peft_model
73 | lora_config = LoraConfig(
74 | r=64,
75 | lora_alpha=16,
76 | target_modules=[
77 | "q_proj",
78 | "k_proj",
79 | "v_proj",
80 | "o_proj",
81 | "up_proj",
82 | "gate_proj",
83 | "down_proj",
84 | ],
85 | lora_dropout=0.05,
86 | task_type="CAUSAL_LM",
87 | )
88 | llm = get_peft_model(llm, lora_config)
89 | llm.print_trainable_parameters()
90 |
91 | tokenizer = LlmTokenizerWrapper.build_llm_tokenizer(args.llm_dir)
92 | assert tokenizer.pad_token_id == tokenizer.convert_tokens_to_ids("<|endoftext|>")
93 | llm.config.pad_token_id = tokenizer.pad_token_id
94 | llm.config.bos_token_id = tokenizer.convert_tokens_to_ids("<|im_start|>")
95 | llm.config.eos_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
96 | llm.config.default_speech_token_id = tokenizer.convert_tokens_to_ids(
97 | DEFAULT_SPEECH_TOKEN
98 | )
99 |
100 | # Build projector
101 | encoder_projector = Adapter(
102 | encoder_dim, llm_dim, args.encoder_downsample_rate)
103 | count_model_parameters(encoder_projector)
104 |
105 | return cls(encoder, llm, encoder_projector,
106 | args.freeze_encoder, args.freeze_llm)
107 |
108 | def __init__(self, encoder, llm, encoder_projector,
109 | freeze_encoder, freeze_llm):
110 | super().__init__()
111 | self.encoder = encoder
112 | self.llm = llm
113 | self.encoder_projector = encoder_projector
114 | # args
115 | self.freeze_encoder = freeze_encoder
116 | self.freeze_llm = freeze_llm
117 | self.llm_config = llm.config
118 |
119 | def transcribe(self, padded_feat, feat_lengths, padded_input_ids, attention_mask,
120 | beam_size=1, decode_max_len=0, decode_min_len=0,
121 | repetition_penalty=1.0, llm_length_penalty=1.0, temperature=1.0):
122 | encoder_outs, enc_lengths, enc_mask = self.encoder(padded_feat, feat_lengths)
123 | speech_features, speech_lens = self.encoder_projector(encoder_outs, enc_lengths)
124 | inputs_embeds = self.llm.get_input_embeddings()(padded_input_ids)
125 |
126 | inputs_embeds, attention_mask, _ = \
127 | self._merge_input_ids_with_speech_features(
128 | speech_features.to(inputs_embeds.dtype), inputs_embeds, padded_input_ids, attention_mask,
129 | speech_lens=speech_lens
130 | )
131 |
132 | max_new_tokens = speech_features.size(1) if decode_max_len < 1 else decode_max_len
133 | max_new_tokens = max(1, max_new_tokens)
134 |
135 | generated_ids = self.llm.generate(
136 | inputs_embeds=inputs_embeds,
137 | max_new_tokens=max_new_tokens,
138 | num_beams=beam_size,
139 | do_sample=False,
140 | min_length=decode_min_len,
141 | top_p=1.0,
142 | repetition_penalty=repetition_penalty,
143 | length_penalty=llm_length_penalty,
144 | temperature=temperature,
145 | bos_token_id=self.llm.config.bos_token_id,
146 | eos_token_id=self.llm.config.eos_token_id,
147 | pad_token_id=self.llm.config.pad_token_id,
148 | )
149 |
150 | return generated_ids
151 |
152 |
153 | def _merge_input_ids_with_speech_features(
154 | self, speech_features, inputs_embeds, input_ids, attention_mask, labels=None,
155 | speech_lens=None
156 | ):
157 | """
158 | Modified from: https://github.com/k2-fsa/icefall/blob/master/egs/speech_llm/ASR_LLM/whisper_llm_zh/model.py
159 | """
160 | speech_lens = None
161 | num_speechs, speech_len, embed_dim = speech_features.shape
162 | batch_size, sequence_length = input_ids.shape
163 | left_padding = not torch.sum(
164 | input_ids[:, -1] == torch.tensor(self.llm.config.pad_token_id)
165 | )
166 | # 1. Create a mask to know where special speech tokens are
167 | special_speech_token_mask = input_ids == self.llm.config.default_speech_token_id
168 | num_special_speech_tokens = torch.sum(special_speech_token_mask, dim=-1)
169 | # Compute the maximum embed dimension
170 | max_embed_dim = (
171 | num_special_speech_tokens.max() * (speech_len - 1)
172 | ) + sequence_length
173 | batch_indices, non_speech_indices = torch.where(
174 | input_ids != self.llm.config.default_speech_token_id
175 | )
176 |
177 | # 2. Compute the positions where text should be written
178 | # Calculate new positions for text tokens in merged speech-text sequence.
179 | # `special_speech_token_mask` identifies speech tokens. Each speech token will be replaced by `nb_text_tokens_per_speechs - 1` text tokens.
180 | # `torch.cumsum` computes how each speech token shifts subsequent text token positions.
181 | # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
182 | new_token_positions = (
183 | torch.cumsum((special_speech_token_mask * (speech_len - 1) + 1), -1) - 1
184 | ) # (N,U)
185 | nb_speech_pad = max_embed_dim - 1 - new_token_positions[:, -1]
186 | if left_padding:
187 | new_token_positions += nb_speech_pad[:, None] # offset for left padding
188 | text_to_overwrite = new_token_positions[batch_indices, non_speech_indices]
189 |
190 | # 3. Create the full embedding, already padded to the maximum position
191 | final_embedding = torch.zeros(
192 | batch_size,
193 | max_embed_dim,
194 | embed_dim,
195 | dtype=inputs_embeds.dtype,
196 | device=inputs_embeds.device,
197 | )
198 | final_attention_mask = torch.zeros(
199 | batch_size,
200 | max_embed_dim,
201 | dtype=attention_mask.dtype,
202 | device=inputs_embeds.device,
203 | )
204 | if labels is not None:
205 | final_labels = torch.full(
206 | (batch_size, max_embed_dim),
207 | IGNORE_TOKEN_ID,
208 | dtype=input_ids.dtype,
209 | device=input_ids.device,
210 | )
211 |         # In case the speech encoder or the language model has been offloaded to CPU, we need to manually
212 | # set the corresponding tensors into their correct target device.
213 | target_device = inputs_embeds.device
214 | batch_indices, non_speech_indices, text_to_overwrite = (
215 | batch_indices.to(target_device),
216 | non_speech_indices.to(target_device),
217 | text_to_overwrite.to(target_device),
218 | )
219 | attention_mask = attention_mask.to(target_device)
220 |
221 |         # 4. Fill the embeddings based on the mask. If we have ["hey", <speech token>, "how", "are"]
222 | # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the speech features
223 | final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[
224 | batch_indices, non_speech_indices
225 | ]
226 | final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[
227 | batch_indices, non_speech_indices
228 | ]
229 | if labels is not None:
230 | final_labels[batch_indices, text_to_overwrite] = labels[
231 | batch_indices, non_speech_indices
232 | ]
233 |
234 | # 5. Fill the embeddings corresponding to the speechs. Anything that is not `text_positions` needs filling (#29835)
235 | speech_to_overwrite = torch.full(
236 | (batch_size, max_embed_dim),
237 | True,
238 | dtype=torch.bool,
239 | device=inputs_embeds.device,
240 | )
241 | speech_to_overwrite[batch_indices, text_to_overwrite] = False
242 | if speech_lens is not None:
243 | speech_pad_position = speech_to_overwrite.cumsum(-1) <= speech_lens[:, None]
244 | speech_to_overwrite &= speech_to_overwrite.cumsum(-1) - 1 >= nb_speech_pad[
245 | :, None
246 | ].to(target_device)
247 |
248 | if speech_to_overwrite.sum() != speech_features.shape[:-1].numel():
249 | raise ValueError(
250 | f"The input provided to the model are wrong. The number of speech tokens is {torch.sum(special_speech_token_mask)} while"
251 | f" the number of speech given to the model is {num_speechs}. This prevents correct indexing and breaks batch generation."
252 | )
253 |
254 | final_embedding[speech_to_overwrite] = (
255 | speech_features.contiguous().reshape(-1, embed_dim).to(target_device)
256 | )
257 | if speech_lens is not None:
258 | speech_to_overwrite &= speech_pad_position
259 | final_attention_mask |= speech_to_overwrite
260 |
261 | # 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.
262 | batch_indices, pad_indices = torch.where(
263 | input_ids == self.llm.config.pad_token_id
264 | )
265 | indices_to_mask = new_token_positions[batch_indices, pad_indices]
266 |
267 | final_embedding[batch_indices, indices_to_mask] = 0
268 |
269 | if labels is None:
270 | final_labels = None
271 |
272 | return final_embedding, final_attention_mask, final_labels #, position_ids
273 |
--------------------------------------------------------------------------------
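
A worked example may make `_merge_input_ids_with_speech_features` above easier to follow. Take a batch of one utterance whose prompt ids are `[t1, <speech>, t2, t3]` (sequence_length = 4) and whose projected speech features have `speech_len = 5` frames. Then `max_embed_dim = 1 * (5 - 1) + 4 = 8`, the cumulative-sum trick scatters the three text embeddings to positions `[0, 6, 7]`, and the remaining positions 1 through 5 are filled with the five speech frames, giving the merged sequence `[t1, s1, s2, s3, s4, s5, t2, t3]` with an all-ones attention mask when there is no text padding. Note that the method immediately sets `speech_lens = None`, so the per-utterance speech-length masking branch is effectively disabled during inference.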
/fireredasr/models/module/adapter.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class Adapter(nn.Module):
6 | def __init__(self, encoder_dim, llm_dim, downsample_rate=2):
7 | super().__init__()
8 | self.ds = downsample_rate
9 | self.linear1 = nn.Linear(encoder_dim * downsample_rate, llm_dim)
10 | self.relu = nn.ReLU()
11 | self.linear2 = nn.Linear(llm_dim, llm_dim)
12 |
13 | def forward(self, x, x_lens):
14 | batch_size, seq_len, feat_dim = x.size()
15 | num_frames_to_discard = seq_len % self.ds
16 | if num_frames_to_discard > 0:
17 | x = x[:, :-num_frames_to_discard, :]
18 | seq_len = x.size(1)
19 |
20 | x = x.contiguous()
21 | x = x.view(
22 | batch_size, seq_len // self.ds, feat_dim * self.ds
23 | )
24 |
25 | x = self.linear1(x)
26 | x = self.relu(x)
27 | x = self.linear2(x)
28 |
29 | new_x_lens = torch.clamp(x_lens, max=seq_len) // self.ds
30 | return x, new_x_lens
31 |
--------------------------------------------------------------------------------
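
A quick shape check for the `Adapter` above (all dimensions are chosen only for illustration, not taken from the released models):

```python
import torch
from fireredasr.models.module.adapter import Adapter

adapter = Adapter(encoder_dim=1280, llm_dim=3584, downsample_rate=2)
x = torch.randn(2, 101, 1280)      # (batch, frames, encoder_dim); 101 is odd on purpose
x_lens = torch.tensor([101, 80])
y, y_lens = adapter(x, x_lens)
print(y.shape)   # torch.Size([2, 50, 3584]): the odd frame is dropped, pairs are stacked and projected
print(y_lens)    # tensor([50, 40])
```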
/fireredasr/models/module/conformer_encoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class ConformerEncoder(nn.Module):
7 | def __init__(self, idim, n_layers, n_head, d_model,
8 | residual_dropout=0.1, dropout_rate=0.1, kernel_size=33,
9 | pe_maxlen=5000):
10 | super().__init__()
11 | self.odim = d_model
12 |
13 | self.input_preprocessor = Conv2dSubsampling(idim, d_model)
14 | self.positional_encoding = RelPositionalEncoding(d_model)
15 | self.dropout = nn.Dropout(residual_dropout)
16 |
17 | self.layer_stack = nn.ModuleList()
18 | for l in range(n_layers):
19 | block = RelPosEmbConformerBlock(d_model, n_head,
20 | residual_dropout,
21 | dropout_rate, kernel_size)
22 | self.layer_stack.append(block)
23 |
24 | def forward(self, padded_input, input_lengths, pad=True):
25 | if pad:
26 | padded_input = F.pad(padded_input,
27 | (0, 0, 0, self.input_preprocessor.context - 1), 'constant', 0.0)
28 | src_mask = self.padding_position_is_0(padded_input, input_lengths)
29 |
30 | embed_output, input_lengths, src_mask = self.input_preprocessor(padded_input, src_mask)
31 | enc_output = self.dropout(embed_output)
32 |
33 | pos_emb = self.dropout(self.positional_encoding(embed_output))
34 |
35 | enc_outputs = []
36 | for enc_layer in self.layer_stack:
37 | enc_output = enc_layer(enc_output, pos_emb, slf_attn_mask=src_mask,
38 | pad_mask=src_mask)
39 | enc_outputs.append(enc_output)
40 |
41 | return enc_output, input_lengths, src_mask
42 |
43 | def padding_position_is_0(self, padded_input, input_lengths):
44 | N, T = padded_input.size()[:2]
45 | mask = torch.ones((N, T)).to(padded_input.device)
46 | for i in range(N):
47 | mask[i, input_lengths[i]:] = 0
48 | mask = mask.unsqueeze(dim=1)
49 | return mask.to(torch.uint8)
50 |
51 |
52 | class RelPosEmbConformerBlock(nn.Module):
53 | def __init__(self, d_model, n_head,
54 | residual_dropout=0.1,
55 | dropout_rate=0.1, kernel_size=33):
56 | super().__init__()
57 | self.ffn1 = ConformerFeedForward(d_model, dropout_rate)
58 | self.mhsa = RelPosMultiHeadAttention(n_head, d_model,
59 | residual_dropout)
60 | self.conv = ConformerConvolution(d_model, kernel_size,
61 | dropout_rate)
62 | self.ffn2 = ConformerFeedForward(d_model, dropout_rate)
63 | self.layer_norm = nn.LayerNorm(d_model)
64 |
65 | def forward(self, x, pos_emb, slf_attn_mask=None, pad_mask=None):
66 | out = 0.5 * x + 0.5 * self.ffn1(x)
67 | out = self.mhsa(out, out, out, pos_emb, mask=slf_attn_mask)[0]
68 | out = self.conv(out, pad_mask)
69 | out = 0.5 * out + 0.5 * self.ffn2(out)
70 | out = self.layer_norm(out)
71 | return out
72 |
73 |
74 | class Swish(nn.Module):
75 | def forward(self, x):
76 | return x * torch.sigmoid(x)
77 |
78 |
79 | class Conv2dSubsampling(nn.Module):
80 | def __init__(self, idim, d_model, out_channels=32):
81 | super().__init__()
82 | self.conv = nn.Sequential(
83 | nn.Conv2d(1, out_channels, 3, 2),
84 | nn.ReLU(),
85 | nn.Conv2d(out_channels, out_channels, 3, 2),
86 | nn.ReLU(),
87 | )
88 | subsample_idim = ((idim - 1) // 2 - 1) // 2
89 | self.out = nn.Linear(out_channels * subsample_idim, d_model)
90 |
91 | self.subsampling = 4
92 |         left_context = right_context = 3  # both exclude the current frame
93 | self.context = left_context + 1 + right_context # 7
94 |
95 | def forward(self, x, x_mask):
96 | x = x.unsqueeze(1)
97 | x = self.conv(x)
98 | N, C, T, D = x.size()
99 | x = self.out(x.transpose(1, 2).contiguous().view(N, T, C * D))
100 | mask = x_mask[:, :, :-2:2][:, :, :-2:2]
101 | input_lengths = mask[:, -1, :].sum(dim=-1)
102 | return x, input_lengths, mask
103 |
104 |
105 | class RelPositionalEncoding(torch.nn.Module):
106 | def __init__(self, d_model, max_len=5000):
107 | super().__init__()
108 | pe_positive = torch.zeros(max_len, d_model, requires_grad=False)
109 | pe_negative = torch.zeros(max_len, d_model, requires_grad=False)
110 | position = torch.arange(0, max_len).unsqueeze(1).float()
111 | div_term = torch.exp(torch.arange(0, d_model, 2).float() *
112 | -(torch.log(torch.tensor(10000.0)).item()/d_model))
113 | pe_positive[:, 0::2] = torch.sin(position * div_term)
114 | pe_positive[:, 1::2] = torch.cos(position * div_term)
115 | pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
116 | pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
117 |
118 | pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
119 | pe_negative = pe_negative[1:].unsqueeze(0)
120 | pe = torch.cat([pe_positive, pe_negative], dim=1)
121 | self.register_buffer('pe', pe)
122 |
123 | def forward(self, x):
124 | # Tmax = 2 * max_len - 1
125 | Tmax, T = self.pe.size(1), x.size(1)
126 | pos_emb = self.pe[:, Tmax // 2 - T + 1 : Tmax // 2 + T].clone().detach()
127 | return pos_emb
128 |
129 |
130 | class ConformerFeedForward(nn.Module):
131 | def __init__(self, d_model, dropout_rate=0.1):
132 | super().__init__()
133 | pre_layer_norm = nn.LayerNorm(d_model)
134 | linear_expand = nn.Linear(d_model, d_model*4)
135 | nonlinear = Swish()
136 | dropout_pre = nn.Dropout(dropout_rate)
137 | linear_project = nn.Linear(d_model*4, d_model)
138 | dropout_post = nn.Dropout(dropout_rate)
139 | self.net = nn.Sequential(pre_layer_norm,
140 | linear_expand,
141 | nonlinear,
142 | dropout_pre,
143 | linear_project,
144 | dropout_post)
145 |
146 | def forward(self, x):
147 | residual = x
148 | output = self.net(x)
149 | output = output + residual
150 | return output
151 |
152 |
153 | class ConformerConvolution(nn.Module):
154 | def __init__(self, d_model, kernel_size=33, dropout_rate=0.1):
155 | super().__init__()
156 | assert kernel_size % 2 == 1
157 | self.pre_layer_norm = nn.LayerNorm(d_model)
158 | self.pointwise_conv1 = nn.Conv1d(d_model, d_model*4, kernel_size=1, bias=False)
159 | self.glu = F.glu
160 | self.padding = (kernel_size - 1) // 2
161 | self.depthwise_conv = nn.Conv1d(d_model*2, d_model*2,
162 | kernel_size, stride=1,
163 | padding=self.padding,
164 | groups=d_model*2, bias=False)
165 | self.batch_norm = nn.LayerNorm(d_model*2)
166 | self.swish = Swish()
167 | self.pointwise_conv2 = nn.Conv1d(d_model*2, d_model, kernel_size=1, bias=False)
168 | self.dropout = nn.Dropout(dropout_rate)
169 |
170 | def forward(self, x, mask=None):
171 | residual = x
172 | out = self.pre_layer_norm(x)
173 | out = out.transpose(1, 2)
174 | if mask is not None:
175 | out.masked_fill_(mask.ne(1), 0.0)
176 | out = self.pointwise_conv1(out)
177 | out = F.glu(out, dim=1)
178 | out = self.depthwise_conv(out)
179 |
180 | out = out.transpose(1, 2)
181 | out = self.swish(self.batch_norm(out))
182 | out = out.transpose(1, 2)
183 |
184 | out = self.dropout(self.pointwise_conv2(out))
185 | if mask is not None:
186 | out.masked_fill_(mask.ne(1), 0.0)
187 | out = out.transpose(1, 2)
188 | return out + residual
189 |
190 |
191 | class EncoderMultiHeadAttention(nn.Module):
192 | def __init__(self, n_head, d_model,
193 | residual_dropout=0.1):
194 | super().__init__()
195 | assert d_model % n_head == 0
196 | self.n_head = n_head
197 | self.d_k = d_model // n_head
198 | self.d_v = self.d_k
199 |
200 | self.w_qs = nn.Linear(d_model, n_head * self.d_k, bias=False)
201 | self.w_ks = nn.Linear(d_model, n_head * self.d_k, bias=False)
202 | self.w_vs = nn.Linear(d_model, n_head * self.d_v, bias=False)
203 |
204 | self.layer_norm_q = nn.LayerNorm(d_model)
205 | self.layer_norm_k = nn.LayerNorm(d_model)
206 | self.layer_norm_v = nn.LayerNorm(d_model)
207 |
208 | self.attention = ScaledDotProductAttention(temperature=self.d_k ** 0.5)
209 | self.fc = nn.Linear(n_head * self.d_v, d_model, bias=False)
210 | self.dropout = nn.Dropout(residual_dropout)
211 |
212 | def forward(self, q, k, v, mask=None):
213 | sz_b, len_q = q.size(0), q.size(1)
214 |
215 | residual = q
216 | q, k, v = self.forward_qkv(q, k, v)
217 |
218 | output, attn = self.attention(q, k, v, mask=mask)
219 |
220 | output = self.forward_output(output, residual, sz_b, len_q)
221 | return output, attn
222 |
223 | def forward_qkv(self, q, k, v):
224 | d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
225 | sz_b, len_q, len_k, len_v = q.size(0), q.size(1), k.size(1), v.size(1)
226 |
227 | q = self.layer_norm_q(q)
228 | k = self.layer_norm_k(k)
229 | v = self.layer_norm_v(v)
230 |
231 | q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)
232 | k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)
233 | v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)
234 | q = q.transpose(1, 2)
235 | k = k.transpose(1, 2)
236 | v = v.transpose(1, 2)
237 | return q, k, v
238 |
239 | def forward_output(self, output, residual, sz_b, len_q):
240 | output = output.transpose(1, 2).contiguous().view(sz_b, len_q, -1)
241 | fc_out = self.fc(output)
242 | output = self.dropout(fc_out)
243 | output = output + residual
244 | return output
245 |
246 |
247 | class ScaledDotProductAttention(nn.Module):
248 | def __init__(self, temperature):
249 | super().__init__()
250 | self.temperature = temperature
251 | self.dropout = nn.Dropout(0.0)
252 | self.INF = float('inf')
253 |
254 | def forward(self, q, k, v, mask=None):
255 | attn = torch.matmul(q, k.transpose(2, 3)) / self.temperature
256 | output, attn = self.forward_attention(attn, v, mask)
257 | return output, attn
258 |
259 | def forward_attention(self, attn, v, mask=None):
260 | if mask is not None:
261 | mask = mask.unsqueeze(1)
262 | mask = mask.eq(0)
263 | attn = attn.masked_fill(mask, -self.INF)
264 | attn = torch.softmax(attn, dim=-1).masked_fill(mask, 0.0)
265 | else:
266 | attn = torch.softmax(attn, dim=-1)
267 |
268 | d_attn = self.dropout(attn)
269 | output = torch.matmul(d_attn, v)
270 |
271 | return output, attn
272 |
273 |
274 | class RelPosMultiHeadAttention(EncoderMultiHeadAttention):
275 | def __init__(self, n_head, d_model,
276 | residual_dropout=0.1):
277 | super().__init__(n_head, d_model,
278 | residual_dropout)
279 | d_k = d_model // n_head
280 | self.scale = 1.0 / (d_k ** 0.5)
281 | self.linear_pos = nn.Linear(d_model, n_head * d_k, bias=False)
282 | self.pos_bias_u = nn.Parameter(torch.FloatTensor(n_head, d_k))
283 | self.pos_bias_v = nn.Parameter(torch.FloatTensor(n_head, d_k))
284 | torch.nn.init.xavier_uniform_(self.pos_bias_u)
285 | torch.nn.init.xavier_uniform_(self.pos_bias_v)
286 |
287 | def _rel_shift(self, x):
288 | N, H, T1, T2 = x.size()
289 | zero_pad = torch.zeros((N, H, T1, 1), device=x.device, dtype=x.dtype)
290 | x_padded = torch.cat([zero_pad, x], dim=-1)
291 |
292 | x_padded = x_padded.view(N, H, T2 + 1, T1)
293 | x = x_padded[:, :, 1:].view_as(x)
294 | x = x[:, :, :, : x.size(-1) // 2 + 1]
295 | return x
296 |
297 | def forward(self, q, k, v, pos_emb, mask=None):
298 | sz_b, len_q = q.size(0), q.size(1)
299 |
300 | residual = q
301 | q, k, v = self.forward_qkv(q, k, v)
302 |
303 | q = q.transpose(1, 2)
304 | n_batch_pos = pos_emb.size(0)
305 | p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.n_head, self.d_k)
306 | p = p.transpose(1, 2)
307 |
308 | q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
309 | q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
310 |
311 | matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
312 |
313 | matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
314 | matrix_bd = self._rel_shift(matrix_bd)
315 |
316 | attn_scores = matrix_ac + matrix_bd
317 | attn_scores.mul_(self.scale)
318 |
319 | output, attn = self.attention.forward_attention(attn_scores, v, mask=mask)
320 |
321 | output = self.forward_output(output, residual, sz_b, len_q)
322 | return output, attn
323 |
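A minimal sanity-check sketch (not from the source file), assuming torch is installed and the fireredasr package is on the path; it exercises two of the self-contained modules defined above and confirms that both keep the (batch, time, d_model) layout:

import torch
from fireredasr.models.module.conformer_encoder import (
    ConformerConvolution, EncoderMultiHeadAttention)

# Illustrative shape check: both modules keep the (B, T, D) layout.
B, T, D, H = 2, 50, 64, 4
x = torch.randn(B, T, D)
conv = ConformerConvolution(d_model=D, kernel_size=33)
mha = EncoderMultiHeadAttention(n_head=H, d_model=D)
y_conv = conv(x)             # (B, T, D): conv block output plus residual
y_attn, attn = mha(x, x, x)  # (B, T, D) output and (B, H, T, T) attention weights
assert y_conv.shape == x.shape and y_attn.shape == x.shape
print(tuple(y_conv.shape), tuple(attn.shape))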
--------------------------------------------------------------------------------
/fireredasr/models/module/transformer_decoder.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional, Dict
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from torch import Tensor
7 |
8 |
9 | class TransformerDecoder(nn.Module):
10 | def __init__(
11 | self, sos_id, eos_id, pad_id, odim,
12 | n_layers, n_head, d_model,
13 | residual_dropout=0.1, pe_maxlen=5000):
14 | super().__init__()
15 | self.INF = 1e10
16 | # parameters
17 | self.pad_id = pad_id
18 | self.sos_id = sos_id
19 | self.eos_id = eos_id
20 | self.n_layers = n_layers
21 |
22 | # Components
23 | self.tgt_word_emb = nn.Embedding(odim, d_model, padding_idx=self.pad_id)
24 | self.positional_encoding = PositionalEncoding(d_model, max_len=pe_maxlen)
25 | self.dropout = nn.Dropout(residual_dropout)
26 |
27 | self.layer_stack = nn.ModuleList()
28 | for l in range(n_layers):
29 | block = DecoderLayer(d_model, n_head, residual_dropout)
30 | self.layer_stack.append(block)
31 |
32 | self.tgt_word_prj = nn.Linear(d_model, odim, bias=False)
33 | self.layer_norm_out = nn.LayerNorm(d_model)
34 |
35 | self.tgt_word_prj.weight = self.tgt_word_emb.weight
36 | self.scale = (d_model ** 0.5)
37 |
38 | def batch_beam_search(self, encoder_outputs, src_masks,
39 | beam_size=1, nbest=1, decode_max_len=0,
40 | softmax_smoothing=1.0, length_penalty=0.0, eos_penalty=1.0):
41 | B = beam_size
42 | N, Ti, H = encoder_outputs.size()
43 | device = encoder_outputs.device
44 | maxlen = decode_max_len if decode_max_len > 0 else Ti
45 | assert eos_penalty > 0.0 and eos_penalty <= 1.0
46 |
47 | # Init
48 | encoder_outputs = encoder_outputs.unsqueeze(1).repeat(1, B, 1, 1).view(N*B, Ti, H)
49 | src_mask = src_masks.unsqueeze(1).repeat(1, B, 1, 1).view(N*B, -1, Ti)
50 | ys = torch.ones(N*B, 1).fill_(self.sos_id).long().to(device)
51 | caches: List[Optional[Tensor]] = []
52 | for _ in range(self.n_layers):
53 | caches.append(None)
54 | scores = torch.tensor([0.0] + [-self.INF]*(B-1)).float().to(device)
55 | scores = scores.repeat(N).view(N*B, 1)
56 | is_finished = torch.zeros_like(scores)
57 |
58 | # Autoregressive Prediction
59 | for t in range(maxlen):
60 | tgt_mask = self.ignored_target_position_is_0(ys, self.pad_id)
61 |
62 | dec_output = self.dropout(
63 | self.tgt_word_emb(ys) * self.scale +
64 | self.positional_encoding(ys))
65 |
66 | i = 0
67 | for dec_layer in self.layer_stack:
68 | dec_output = dec_layer.forward(
69 | dec_output, encoder_outputs,
70 | tgt_mask, src_mask,
71 | cache=caches[i])
72 | caches[i] = dec_output
73 | i += 1
74 |
75 | dec_output = self.layer_norm_out(dec_output)
76 |
77 | t_logit = self.tgt_word_prj(dec_output[:, -1])
78 | t_scores = F.log_softmax(t_logit / softmax_smoothing, dim=-1)
79 |
80 | if eos_penalty != 1.0:
81 | t_scores[:, self.eos_id] *= eos_penalty
82 |
83 | t_topB_scores, t_topB_ys = torch.topk(t_scores, k=B, dim=1)
84 | t_topB_scores = self.set_finished_beam_score_to_zero(t_topB_scores, is_finished)
85 | t_topB_ys = self.set_finished_beam_y_to_eos(t_topB_ys, is_finished)
86 |
87 | # Accumulated
88 | scores = scores + t_topB_scores
89 |
90 | # Pruning
91 | scores = scores.view(N, B*B)
92 | scores, topB_score_ids = torch.topk(scores, k=B, dim=1)
93 | scores = scores.view(-1, 1)
94 |
95 |             topB_row_number_in_each_B_rows_of_ys = torch.div(topB_score_ids, B, rounding_mode='floor').view(N*B)
96 | stride = B * torch.arange(N).view(N, 1).repeat(1, B).view(N*B).to(device)
97 | topB_row_number_in_ys = topB_row_number_in_each_B_rows_of_ys.long() + stride.long()
98 |
99 | # Update ys
100 | ys = ys[topB_row_number_in_ys]
101 | t_ys = torch.gather(t_topB_ys.view(N, B*B), dim=1, index=topB_score_ids).view(N*B, 1)
102 | ys = torch.cat((ys, t_ys), dim=1)
103 |
104 | # Update caches
105 | new_caches: List[Optional[Tensor]] = []
106 | for cache in caches:
107 | if cache is not None:
108 | new_caches.append(cache[topB_row_number_in_ys])
109 | caches = new_caches
110 |
111 | # Update finished state
112 | is_finished = t_ys.eq(self.eos_id)
113 | if is_finished.sum().item() == N*B:
114 | break
115 |
116 | # Length penalty (follow GNMT)
117 | scores = scores.view(N, B)
118 | ys = ys.view(N, B, -1)
119 | ys_lengths = self.get_ys_lengths(ys)
120 | if length_penalty > 0.0:
121 | penalty = torch.pow((5+ys_lengths.float())/(5.0+1), length_penalty)
122 | scores /= penalty
123 | nbest_scores, nbest_ids = torch.topk(scores, k=int(nbest), dim=1)
124 | nbest_scores = -1.0 * nbest_scores
125 | index = nbest_ids + B * torch.arange(N).view(N, 1).to(device).long()
126 | nbest_ys = ys.view(N*B, -1)[index.view(-1)]
127 | nbest_ys = nbest_ys.view(N, nbest_ids.size(1), -1)
128 | nbest_ys_lengths = ys_lengths.view(N*B)[index.view(-1)].view(N, -1)
129 |
130 | # result
131 | nbest_hyps: List[List[Dict[str, Tensor]]] = []
132 | for n in range(N):
133 | n_nbest_hyps: List[Dict[str, Tensor]] = []
134 | for i, score in enumerate(nbest_scores[n]):
135 | new_hyp = {
136 | "yseq": nbest_ys[n, i, 1:nbest_ys_lengths[n, i]]
137 | }
138 | n_nbest_hyps.append(new_hyp)
139 | nbest_hyps.append(n_nbest_hyps)
140 | return nbest_hyps
141 |
142 | def ignored_target_position_is_0(self, padded_targets, ignore_id):
143 | mask = torch.ne(padded_targets, ignore_id)
144 | mask = mask.unsqueeze(dim=1)
145 | T = padded_targets.size(-1)
146 |         upper_tri_0_mask = self.upper_triangular_is_0(T).unsqueeze(0)
147 |         upper_tri_0_mask = upper_tri_0_mask.to(mask.dtype).to(mask.device)
148 | return mask.to(torch.uint8) & upper_tri_0_mask.to(torch.uint8)
149 |
150 | def upper_triangular_is_0(self, size):
151 | ones = torch.ones(size, size)
152 | tri_left_ones = torch.tril(ones)
153 | return tri_left_ones.to(torch.uint8)
154 |
155 | def set_finished_beam_score_to_zero(self, scores, is_finished):
156 | NB, B = scores.size()
157 | is_finished = is_finished.float()
158 | mask_score = torch.tensor([0.0] + [-self.INF]*(B-1)).float().to(scores.device)
159 | mask_score = mask_score.view(1, B).repeat(NB, 1)
160 | return scores * (1 - is_finished) + mask_score * is_finished
161 |
162 | def set_finished_beam_y_to_eos(self, ys, is_finished):
163 | is_finished = is_finished.long()
164 | return ys * (1 - is_finished) + self.eos_id * is_finished
165 |
166 | def get_ys_lengths(self, ys):
167 | N, B, Tmax = ys.size()
168 | ys_lengths = torch.sum(torch.ne(ys, self.eos_id), dim=-1)
169 | return ys_lengths.int()
170 |
171 |
172 |
173 | class DecoderLayer(nn.Module):
174 | def __init__(self, d_model, n_head, dropout):
175 | super().__init__()
176 | self.self_attn_norm = nn.LayerNorm(d_model)
177 | self.self_attn = DecoderMultiHeadAttention(d_model, n_head, dropout)
178 |
179 | self.cross_attn_norm = nn.LayerNorm(d_model)
180 | self.cross_attn = DecoderMultiHeadAttention(d_model, n_head, dropout)
181 |
182 | self.mlp_norm = nn.LayerNorm(d_model)
183 | self.mlp = PositionwiseFeedForward(d_model, d_model*4, dropout)
184 |
185 | def forward(self, dec_input, enc_output, self_attn_mask, cross_attn_mask,
186 | cache=None):
187 | x = dec_input
188 | residual = x
189 | x = self.self_attn_norm(x)
190 | if cache is not None:
191 | xq = x[:, -1:, :]
192 | residual = residual[:, -1:, :]
193 | self_attn_mask = self_attn_mask[:, -1:, :]
194 | else:
195 | xq = x
196 | x = self.self_attn(xq, x, x, mask=self_attn_mask)
197 | x = residual + x
198 |
199 | residual = x
200 | x = self.cross_attn_norm(x)
201 | x = self.cross_attn(x, enc_output, enc_output, mask=cross_attn_mask)
202 | x = residual + x
203 |
204 | residual = x
205 | x = self.mlp_norm(x)
206 | x = residual + self.mlp(x)
207 |
208 | if cache is not None:
209 | x = torch.cat([cache, x], dim=1)
210 |
211 | return x
212 |
213 |
214 | class DecoderMultiHeadAttention(nn.Module):
215 | def __init__(self, d_model, n_head, dropout=0.1):
216 | super().__init__()
217 | self.d_model = d_model
218 | self.n_head = n_head
219 | self.d_k = d_model // n_head
220 |
221 | self.w_qs = nn.Linear(d_model, n_head * self.d_k)
222 | self.w_ks = nn.Linear(d_model, n_head * self.d_k, bias=False)
223 | self.w_vs = nn.Linear(d_model, n_head * self.d_k)
224 |
225 | self.attention = DecoderScaledDotProductAttention(
226 | temperature=self.d_k ** 0.5)
227 | self.fc = nn.Linear(n_head * self.d_k, d_model)
228 | self.dropout = nn.Dropout(dropout)
229 |
230 | def forward(self, q, k, v, mask=None):
231 | bs = q.size(0)
232 |
233 | q = self.w_qs(q).view(bs, -1, self.n_head, self.d_k)
234 | k = self.w_ks(k).view(bs, -1, self.n_head, self.d_k)
235 | v = self.w_vs(v).view(bs, -1, self.n_head, self.d_k)
236 | q = q.transpose(1, 2)
237 | k = k.transpose(1, 2)
238 | v = v.transpose(1, 2)
239 |
240 | if mask is not None:
241 | mask = mask.unsqueeze(1)
242 |
243 | output = self.attention(q, k, v, mask=mask)
244 |
245 | output = output.transpose(1, 2).contiguous().view(bs, -1, self.d_model)
246 | output = self.fc(output)
247 | output = self.dropout(output)
248 |
249 | return output
250 |
251 |
252 | class DecoderScaledDotProductAttention(nn.Module):
253 | def __init__(self, temperature):
254 | super().__init__()
255 | self.temperature = temperature
256 | self.INF = float("inf")
257 |
258 | def forward(self, q, k, v, mask=None):
259 | attn = torch.matmul(q, k.transpose(2, 3)) / self.temperature
260 | if mask is not None:
261 | mask = mask.eq(0)
262 | attn = attn.masked_fill(mask, -self.INF)
263 | attn = torch.softmax(attn, dim=-1).masked_fill(mask, 0.0)
264 | else:
265 | attn = torch.softmax(attn, dim=-1)
266 | output = torch.matmul(attn, v)
267 | return output
268 |
269 |
270 | class PositionwiseFeedForward(nn.Module):
271 | def __init__(self, d_model, d_ff, dropout=0.1):
272 | super().__init__()
273 | self.w_1 = nn.Linear(d_model, d_ff)
274 | self.act = nn.GELU()
275 | self.w_2 = nn.Linear(d_ff, d_model)
276 | self.dropout = nn.Dropout(dropout)
277 |
278 | def forward(self, x):
279 | output = self.w_2(self.act(self.w_1(x)))
280 | output = self.dropout(output)
281 | return output
282 |
283 |
284 | class PositionalEncoding(nn.Module):
285 | def __init__(self, d_model, max_len=5000):
286 | super().__init__()
287 | assert d_model % 2 == 0
288 | pe = torch.zeros(max_len, d_model, requires_grad=False)
289 | position = torch.arange(0, max_len).unsqueeze(1).float()
290 | div_term = torch.exp(torch.arange(0, d_model, 2).float() *
291 | -(torch.log(torch.tensor(10000.0)).item()/d_model))
292 | pe[:, 0::2] = torch.sin(position * div_term)
293 | pe[:, 1::2] = torch.cos(position * div_term)
294 | pe = pe.unsqueeze(0)
295 | self.register_buffer('pe', pe)
296 |
297 | def forward(self, x):
298 | length = x.size(1)
299 | return self.pe[:, :length].clone().detach()
300 |
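A small usage sketch of PositionalEncoding (illustrative, assuming torch is installed and the fireredasr package is importable): it returns a (1, T, d_model) slice of the precomputed sinusoidal table that depends only on the input length, which is what batch_beam_search adds to the scaled token embeddings:

import torch
from fireredasr.models.module.transformer_decoder import PositionalEncoding

d_model, T = 8, 5
pos_enc = PositionalEncoding(d_model, max_len=100)
pe = pos_enc(torch.zeros(2, T, d_model))  # only the length of dim 1 matters
assert pe.shape == (1, T, d_model)
print(pe[0, 0])  # position 0 alternates sin(0)=0 and cos(0)=1: [0., 1., 0., 1., ...]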
--------------------------------------------------------------------------------
/fireredasr/speech2text.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | import argparse
4 | import glob
5 | import os
6 | import sys
7 |
8 | from fireredasr.models.fireredasr import FireRedAsr
9 |
10 |
11 | parser = argparse.ArgumentParser()
12 | parser.add_argument('--asr_type', type=str, required=True, choices=["aed", "llm"])
13 | parser.add_argument('--model_dir', type=str, required=True)
14 |
15 | # Input / Output
16 | parser.add_argument("--wav_path", type=str)
17 | parser.add_argument("--wav_paths", type=str, nargs="*")
18 | parser.add_argument("--wav_dir", type=str)
19 | parser.add_argument("--wav_scp", type=str)
20 | parser.add_argument("--output", type=str)
21 |
22 | # Decode Options
23 | parser.add_argument('--use_gpu', type=int, default=1)
24 | parser.add_argument("--batch_size", type=int, default=1)
25 | parser.add_argument("--beam_size", type=int, default=1)
26 | parser.add_argument("--decode_max_len", type=int, default=0)
27 | # FireRedASR-AED
28 | parser.add_argument("--nbest", type=int, default=1)
29 | parser.add_argument("--softmax_smoothing", type=float, default=1.0)
30 | parser.add_argument("--aed_length_penalty", type=float, default=0.0)
31 | parser.add_argument("--eos_penalty", type=float, default=1.0)
32 | # FireRedASR-LLM
33 | parser.add_argument("--decode_min_len", type=int, default=0)
34 | parser.add_argument("--repetition_penalty", type=float, default=1.0)
35 | parser.add_argument("--llm_length_penalty", type=float, default=0.0)
36 | parser.add_argument("--temperature", type=float, default=1.0)
37 |
38 |
39 | def main(args):
40 | wavs = get_wav_info(args)
41 | fout = open(args.output, "w") if args.output else None
42 |
43 | model = FireRedAsr.from_pretrained(args.asr_type, args.model_dir)
44 |
45 | batch_uttid = []
46 | batch_wav_path = []
47 | for i, wav in enumerate(wavs):
48 | uttid, wav_path = wav
49 | batch_uttid.append(uttid)
50 | batch_wav_path.append(wav_path)
51 | if len(batch_wav_path) < args.batch_size and i != len(wavs) - 1:
52 | continue
53 |
54 | results = model.transcribe(
55 | batch_uttid,
56 | batch_wav_path,
57 | {
58 | "use_gpu": args.use_gpu,
59 | "beam_size": args.beam_size,
60 | "nbest": args.nbest,
61 | "decode_max_len": args.decode_max_len,
62 | "softmax_smoothing": args.softmax_smoothing,
63 | "aed_length_penalty": args.aed_length_penalty,
64 | "eos_penalty": args.eos_penalty,
65 | "decode_min_len": args.decode_min_len,
66 | "repetition_penalty": args.repetition_penalty,
67 | "llm_length_penalty": args.llm_length_penalty,
68 | "temperature": args.temperature
69 | }
70 | )
71 |
72 | for result in results:
73 | print(result)
74 | if fout is not None:
75 | fout.write(f"{result['uttid']}\t{result['text']}\n")
76 |
77 | batch_uttid = []
78 | batch_wav_path = []
79 |
80 |
81 | def get_wav_info(args):
82 | """
83 | Returns:
84 | wavs: list of (uttid, wav_path)
85 | """
86 | base = lambda p: os.path.basename(p).replace(".wav", "")
87 | if args.wav_path:
88 | wavs = [(base(args.wav_path), args.wav_path)]
89 | elif args.wav_paths and len(args.wav_paths) >= 1:
90 | wavs = [(base(p), p) for p in sorted(args.wav_paths)]
91 | elif args.wav_scp:
92 | wavs = [line.strip().split() for line in open(args.wav_scp)]
93 | elif args.wav_dir:
94 | wavs = glob.glob(f"{args.wav_dir}/**/*.wav", recursive=True)
95 | wavs = [(base(p), p) for p in sorted(wavs)]
96 | else:
97 | raise ValueError("Please provide valid wav info")
98 | print(f"#wavs={len(wavs)}")
99 | return wavs
100 |
101 |
102 | if __name__ == "__main__":
103 | args = parser.parse_args()
104 | print(args)
105 | main(args)
106 |
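The CLI above is a batching loop around two calls, FireRedAsr.from_pretrained() and model.transcribe(), so the same API can be used directly from Python. A minimal sketch, assuming the fireredasr package is importable and a checkpoint has been downloaded (the model directory, utterance id, and decoding values below are illustrative):

from fireredasr.models.fireredasr import FireRedAsr

model = FireRedAsr.from_pretrained("aed", "pretrained_models/<your-aed-model-dir>")
results = model.transcribe(
    ["BAC009S0764W0121"],
    ["examples/fireredasr/wav/BAC009S0764W0121.wav"],
    {
        "use_gpu": 0, "beam_size": 3, "nbest": 1, "decode_max_len": 0,
        "softmax_smoothing": 1.0, "aed_length_penalty": 0.0, "eos_penalty": 1.0,
        "decode_min_len": 0, "repetition_penalty": 1.0, "llm_length_penalty": 0.0,
        "temperature": 1.0,
    },
)
for result in results:  # each result carries at least "uttid" and "text"
    print(result["uttid"], result["text"])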
--------------------------------------------------------------------------------
/fireredasr/tokenizer/aed_tokenizer.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import re
3 |
4 | import sentencepiece as spm
5 |
6 | from fireredasr.data.token_dict import TokenDict
7 |
8 |
9 | class ChineseCharEnglishSpmTokenizer:
10 | """
11 | - One Chinese char is a token.
12 |     - Split each English word with the SPM model; one piece is a token.
13 |     - Spaces between Chinese chars are ignored.
14 |     - Spaces between English words are replaced with "▁" by the spm_model.
15 |     - SPM pieces need to be present in the dict file.
16 |     - If spm_model is not set, fall back to English chars and the space symbol.
17 | """
18 | SPM_SPACE = "▁"
19 |
20 |     def __init__(self, dict_path, spm_model, unk="<unk>", space="<space>"):
21 | self.dict = TokenDict(dict_path, unk=unk)
22 | self.space = space
23 | if spm_model:
24 | self.sp = spm.SentencePieceProcessor()
25 | self.sp.Load(spm_model)
26 | else:
27 | self.sp = None
28 | print("[WRAN] Not set spm_model, will use English char")
29 | print("[WARN] Please check how to deal with ' '(space)")
30 | if self.space not in self.dict:
31 | print("Please add to your dict, or it will be ")
32 |
33 | def tokenize(self, text, replace_punc=True):
34 | #if text == "":
35 | # logging.info(f"empty text")
36 | text = text.upper()
37 | tokens = []
38 | if replace_punc:
39 |             text = re.sub(r"[,。?!,\.?!]", " ", text)
40 | pattern = re.compile(r'([\u3400-\u4dbf\u4e00-\u9fff])')
41 | parts = pattern.split(text.strip())
42 | parts = [p for p in parts if len(p.strip()) > 0]
43 | for part in parts:
44 | if pattern.fullmatch(part) is not None:
45 | tokens.append(part)
46 | else:
47 | if self.sp:
48 | for piece in self.sp.EncodeAsPieces(part.strip()):
49 | tokens.append(piece)
50 | else:
51 | for char in part.strip():
52 | tokens.append(char if char != " " else self.space)
53 | tokens_id = []
54 | for token in tokens:
55 | tokens_id.append(self.dict.get(token, self.dict.unk))
56 | return tokens, tokens_id
57 |
58 | def detokenize(self, inputs, join_symbol="", replace_spm_space=True):
59 | """inputs is ids or tokens, do not need self.sp"""
60 |         if len(inputs) > 0 and isinstance(inputs[0], int):
61 | tokens = [self.dict[id] for id in inputs]
62 | else:
63 | tokens = inputs
64 | s = f"{join_symbol}".join(tokens)
65 | if replace_spm_space:
66 | s = s.replace(self.SPM_SPACE, ' ').strip()
67 | return s
68 |
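To make the splitting rule in tokenize() concrete, here is a standalone sketch (the sample sentence is made up) of the same CJK-capturing regex: Chinese characters come out one per token, while English stays in chunks for SentencePiece to split further:

import re

pattern = re.compile(r'([\u3400-\u4dbf\u4e00-\u9fff])')  # same pattern as tokenize()
parts = [p for p in pattern.split("今天WEATHER很好".strip()) if len(p.strip()) > 0]
print(parts)  # ['今', '天', 'WEATHER', '很', '好']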
--------------------------------------------------------------------------------
/fireredasr/tokenizer/llm_tokenizer.py:
--------------------------------------------------------------------------------
1 | import re
2 |
3 | import torch
4 | from transformers import AutoTokenizer
5 | from transformers.trainer_pt_utils import LabelSmoother
6 |
7 | DEFAULT_SPEECH_TOKEN = "<speech>"
8 | IGNORE_TOKEN_ID = LabelSmoother.ignore_index
9 |
10 |
11 | class LlmTokenizerWrapper:
12 | @classmethod
13 | def build_llm_tokenizer(cls, llm_path, use_flash_attn=False):
14 | tokenizer = AutoTokenizer.from_pretrained(llm_path)
15 | if use_flash_attn:
16 | tokenizer.padding_side = "left"
17 | else:
18 | tokenizer.padding_side = "right"
19 | special_tokens_dict = {"additional_special_tokens": [DEFAULT_SPEECH_TOKEN]}
20 | tokenizer.add_special_tokens(special_tokens_dict)
21 | return tokenizer
22 |
23 | @classmethod
24 | def clean_text(cls, origin_text):
25 | """remove punc, remove space between Chinese and keep space between English"""
26 | # remove punc
27 | text = re.sub("[,。?!,\.!?《》()\·“”、\\/]", "", origin_text)
28 | # merge space
29 |         text = re.sub(r"\s+", " ", text)
30 |
31 | # remove space between Chinese and keep space between English
32 | pattern = re.compile(r'([\u3400-\u4dbf\u4e00-\u9fff])') # Chinese
33 | parts = pattern.split(text.strip())
34 | parts = [p for p in parts if len(p.strip()) > 0]
35 | text = "".join(parts)
36 | text = text.strip()
37 |
38 | text = text.lower()
39 | return text
40 |
41 | @classmethod
42 | def preprocess_texts(cls, origin_texts, tokenizer, max_len, decode=False):
43 | messages = []
44 | clean_texts = []
45 | for i, origin_text in enumerate(origin_texts):
46 | text = cls.clean_text(origin_text)
47 | clean_texts.append(text)
48 | text = text if not decode else ""
49 | message = [
50 | {"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}请转写音频为文字"},
51 | {"role": "assistant", "content": text},
52 | ]
53 | messages.append(message)
54 |
55 | texts = []
56 | if not decode:
57 | TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{ '<|im_end|>'}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"
58 | else:
59 | TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{''}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"
60 | for i, msg in enumerate(messages):
61 | texts.append(
62 | tokenizer.apply_chat_template(
63 | msg,
64 | tokenize=True,
65 | chat_template=TEMPLATE,
66 | add_generation_prompt=False,
67 | padding="longest",
68 | max_length=max_len,
69 | truncation=True,
70 | )
71 | )
72 |
73 | # Padding texts
74 | max_len_texts = max([len(text) for text in texts])
75 | if tokenizer.padding_side == "right":
76 | texts = [
77 | text + [tokenizer.pad_token_id] * (max_len_texts - len(text))
78 | for text in texts
79 | ]
80 | else:
81 | texts = [
82 | [tokenizer.pad_token_id] * (max_len_texts - len(text)) + text
83 | for text in texts
84 | ]
85 | input_ids = torch.tensor(texts, dtype=torch.int)
86 |
87 | target_ids = input_ids.clone()
88 | target_ids[target_ids == tokenizer.pad_token_id] = IGNORE_TOKEN_ID
89 |
90 |         # mask the prompt portion (up to and just past the "assistant" tag) so the loss only covers the response
91 | mask_prompt = True
92 | if mask_prompt:
93 | mask_indices = torch.where(
94 | input_ids == tokenizer.convert_tokens_to_ids("assistant")
95 | )
96 | for i in range(mask_indices[0].size(0)):
97 | row = mask_indices[0][i]
98 | col = mask_indices[1][i]
99 | target_ids[row, : col + 2] = IGNORE_TOKEN_ID
100 |
101 | attention_mask = input_ids.ne(tokenizer.pad_token_id)
102 |
103 | target_ids = target_ids.type(torch.LongTensor)
104 | input_ids = input_ids.type(torch.LongTensor)
105 | return input_ids, attention_mask, target_ids, clean_texts
106 |
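An illustrative call of clean_text() (the sample string is made up), assuming the fireredasr package and its dependencies are installed: punctuation is removed, spaces between Chinese characters are dropped, spaces between English words are kept, and the result is lower-cased:

from fireredasr.tokenizer.llm_tokenizer import LlmTokenizerWrapper

print(LlmTokenizerWrapper.clean_text("你好, 世界!  Hello  World."))
# -> 你好世界 hello world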
--------------------------------------------------------------------------------
/fireredasr/utils/param.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | import torch
4 |
5 |
6 | def count_model_parameters(model):
7 | if not isinstance(model, torch.nn.Module):
8 | return 0, 0
9 | name = f"{model.__class__.__name__} {model.__class__}"
10 | num = sum(p.numel() for p in model.parameters() if p.requires_grad)
11 | size = num * 4.0 / 1024.0 / 1024.0 # float32, MB
12 | logging.info(f"#param of {name} is {num} = {size:.1f} MB (float32)")
13 | return num, size
14 |
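A quick usage sketch (illustrative, assuming torch and the fireredasr package are available): torch.nn.Linear(10, 10) with bias has 10*10 + 10 = 110 trainable parameters, so count_model_parameters should report 110 and roughly 110 * 4 bytes expressed in MB:

import logging
import torch
from fireredasr.utils.param import count_model_parameters

logging.basicConfig(level=logging.INFO)
num, size_mb = count_model_parameters(torch.nn.Linear(10, 10))
print(num, f"{size_mb:.5f} MB")  # 110, ~0.00042 MB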
--------------------------------------------------------------------------------
/fireredasr/utils/wer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | import argparse
4 | import re
5 | from collections import OrderedDict
6 |
7 |
8 | parser = argparse.ArgumentParser()
9 | parser.add_argument("--ref", type=str, required=True)
10 | parser.add_argument("--hyp", type=str, required=True)
11 | parser.add_argument("--print_sentence_wer", type=int, default=0)
12 | parser.add_argument("--do_tn", type=int, default=0, help="simple tn by cn2an")
13 | parser.add_argument("--rm_special", type=int, default=0, help="remove <\|.*?\|>")
14 |
15 |
16 | def main(args):
17 | uttid2refs = read_uttid2tokens(args.ref, args.do_tn, args.rm_special)
18 | uttid2hyps = read_uttid2tokens(args.hyp, args.do_tn, args.rm_special)
19 | uttid2wer_info, wer_stat, en_dig_stat = compute_uttid2wer_info(
20 | uttid2refs, uttid2hyps, args.print_sentence_wer)
21 | wer_stat.print()
22 | en_dig_stat.print()
23 |
24 |
25 | def read_uttid2tokens(filename, do_tn=False, rm_special=False):
26 | print(f">>> Read uttid to tokens: {filename}", flush=True)
27 | uttid2tokens = OrderedDict()
28 | uttid2text = read_uttid2text(filename, do_tn, rm_special)
29 | for uttid, text in uttid2text.items():
30 | tokens = text2tokens(text)
31 | uttid2tokens[uttid] = tokens
32 | return uttid2tokens
33 |
34 |
35 | def read_uttid2text(filename, do_tn=False, rm_special=False):
36 | uttid2text = OrderedDict()
37 | with open(filename, "r", encoding="utf8") as fin:
38 | for i, line in enumerate(fin):
39 | cols = line.split()
40 | if len(cols) == 0:
41 | print("[WARN] empty line, continue", i, flush=True)
42 | continue
43 | assert cols[0] not in uttid2text, f"repeated uttid: {line}"
44 | if len(cols) == 1:
45 | uttid2text[cols[0]] = ""
46 | continue
47 | txt = " ".join(cols[1:])
48 | if rm_special:
49 | txt = " ".join([t for t in re.split("<\|.*?\|>", txt) if t.strip() != ""])
50 | if do_tn:
51 | import cn2an
52 | txt = cn2an.transform(txt, "an2cn")
53 | uttid2text[cols[0]] = txt
54 | return uttid2text
55 |
56 |
57 | def text2tokens(text):
58 | PUNCTUATIONS = ",。?!,\.?!"#$%&'()*+-/:;<=>@[\]^_`{|}~⦅⦆「」、 、〃〈〉《》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏﹑﹔·。\":" + "()\[\]{}/;`|=+"
59 | if text == "":
60 | return []
61 | tokens = []
62 |
63 |     text = re.sub("<unk>", "", text)  # drop <unk> placeholders
64 | text = re.sub(r"[%s]+" % PUNCTUATIONS, " ", text)
65 |
66 | pattern = re.compile(r'([\u4e00-\u9fff])')
67 | parts = pattern.split(text.strip().upper())
68 | parts = [p for p in parts if len(p.strip()) > 0]
69 | for part in parts:
70 | if pattern.fullmatch(part) is not None:
71 | tokens.append(part)
72 | else:
73 | for word in part.strip().split():
74 | tokens.append(word)
75 | return tokens
76 |
77 |
78 | def compute_uttid2wer_info(refs, hyps, print_sentence_wer=False):
79 | print(f">>> Compute uttid to wer info", flush=True)
80 |
81 | uttid2wer_info = OrderedDict()
82 | wer_stat = WerStats()
83 | en_dig_stat = EnDigStats()
84 |
85 | for uttid, ref in refs.items():
86 | if uttid not in hyps:
87 | print(f"[WARN] No hyp for {uttid}", flush=True)
88 | continue
89 | hyp = hyps[uttid]
90 |
91 | if len(hyp) - len(ref) >= 8:
92 | print(f"[BidLengthDiff]: {uttid} {len(ref)} {len(hyp)}#{' '.join(ref)}#{' '.join(hyp)}")
93 | #continue
94 |
95 | wer_info = compute_one_wer_info(ref, hyp)
96 | uttid2wer_info[uttid] = wer_info
97 |         ns = count_english_digit(ref, hyp, wer_info)
98 | wer_stat.add(wer_info)
99 | en_dig_stat.add(*ns)
100 | if print_sentence_wer:
101 | print(f"{uttid} {wer_info}")
102 |
103 | return uttid2wer_info, wer_stat, en_dig_stat
104 |
105 |
106 | COST_SUB = 3
107 | COST_DEL = 3
108 | COST_INS = 3
109 |
110 | ALIGN_CRT = 0
111 | ALIGN_SUB = 1
112 | ALIGN_DEL = 2
113 | ALIGN_INS = 3
114 | ALIGN_END = 4
115 |
116 |
117 | def compute_one_wer_info(ref, hyp):
118 | """Impl minimum edit distance and backtrace.
119 | Args:
120 | ref, hyp: List[str]
121 | Returns:
122 | WerInfo
123 | """
124 | ref_len = len(ref)
125 | hyp_len = len(hyp)
126 |
127 | class _DpPoint:
128 | def __init__(self, cost, align):
129 | self.cost = cost
130 | self.align = align
131 |
132 | dp = []
133 | for i in range(0, ref_len + 1):
134 | dp.append([])
135 | for j in range(0, hyp_len + 1):
136 | dp[-1].append(_DpPoint(i * j, ALIGN_CRT))
137 |
138 | # Initialize
139 | for i in range(1, hyp_len + 1):
140 |         dp[0][i].cost = dp[0][i - 1].cost + COST_INS
141 | dp[0][i].align = ALIGN_INS
142 | for i in range(1, ref_len + 1):
143 | dp[i][0].cost = dp[i - 1][0].cost + COST_DEL
144 | dp[i][0].align = ALIGN_DEL
145 |
146 | # DP
147 | for i in range(1, ref_len + 1):
148 | for j in range(1, hyp_len + 1):
149 | min_cost = 0
150 | min_align = ALIGN_CRT
151 | if hyp[j - 1] == ref[i - 1]:
152 | min_cost = dp[i - 1][j - 1].cost
153 | min_align = ALIGN_CRT
154 | else:
155 | min_cost = dp[i - 1][j - 1].cost + COST_SUB
156 | min_align = ALIGN_SUB
157 |
158 | del_cost = dp[i - 1][j].cost + COST_DEL
159 | if del_cost < min_cost:
160 | min_cost = del_cost
161 | min_align = ALIGN_DEL
162 |
163 | ins_cost = dp[i][j - 1].cost + COST_INS
164 | if ins_cost < min_cost:
165 | min_cost = ins_cost
166 | min_align = ALIGN_INS
167 |
168 | dp[i][j].cost = min_cost
169 | dp[i][j].align = min_align
170 |
171 | # Backtrace
172 | crt = sub = ins = det = 0
173 | i = ref_len
174 | j = hyp_len
175 | align = []
176 | while i > 0 or j > 0:
177 | if dp[i][j].align == ALIGN_CRT:
178 | align.append((i, j, ALIGN_CRT))
179 | i -= 1
180 | j -= 1
181 | crt += 1
182 | elif dp[i][j].align == ALIGN_SUB:
183 | align.append((i, j, ALIGN_SUB))
184 | i -= 1
185 | j -= 1
186 | sub += 1
187 | elif dp[i][j].align == ALIGN_DEL:
188 | align.append((i, j, ALIGN_DEL))
189 | i -= 1
190 | det += 1
191 | elif dp[i][j].align == ALIGN_INS:
192 | align.append((i, j, ALIGN_INS))
193 | j -= 1
194 | ins += 1
195 |
196 | err = sub + det + ins
197 | align.reverse()
198 | wer_info = WerInfo(ref_len, err, crt, sub, det, ins, align)
199 | return wer_info
200 |
201 |
202 |
203 | class WerInfo:
204 | def __init__(self, ref, err, crt, sub, dele, ins, ali):
205 | self.r = ref
206 | self.e = err
207 | self.c = crt
208 | self.s = sub
209 | self.d = dele
210 | self.i = ins
211 | self.ali = ali
212 | r = max(self.r, 1)
213 | self.wer = 100.0 * (self.s + self.d + self.i) / r
214 |
215 | def __repr__(self):
216 | s = f"wer {self.wer:.2f} ref {self.r:2d} sub {self.s:2d} del {self.d:2d} ins {self.i:2d}"
217 | return s
218 |
219 |
220 | class WerStats:
221 | def __init__(self):
222 | self.infos = []
223 |
224 | def add(self, wer_info):
225 | self.infos.append(wer_info)
226 |
227 | def print(self):
228 | r = sum(info.r for info in self.infos)
229 | if r <= 0:
230 | print(f"REF len is {r}, check")
231 | r = 1
232 | s = sum(info.s for info in self.infos)
233 | d = sum(info.d for info in self.infos)
234 | i = sum(info.i for info in self.infos)
235 | se = 100.0 * s / r
236 | de = 100.0 * d / r
237 | ie = 100.0 * i / r
238 | wer = 100.0 * (s + d + i) / r
239 | sen = max(len(self.infos), 1)
240 | errsen = sum(info.e > 0 for info in self.infos)
241 | ser = 100.0 * errsen / sen
242 | print("-"*80)
243 | print(f"ref{r:6d} sub{s:6d} del{d:6d} ins{i:6d}")
244 | print(f"WER{wer:6.2f} sub{se:6.2f} del{de:6.2f} ins{ie:6.2f}")
245 | print(f"SER{ser:6.2f} = {errsen} / {sen}")
246 | print("-"*80)
247 |
248 |
249 | class EnDigStats:
250 | def __init__(self):
251 | self.n_en_word = 0
252 | self.n_en_correct = 0
253 | self.n_dig_word = 0
254 | self.n_dig_correct = 0
255 |
256 | def add(self, n_en_word, n_en_correct, n_dig_word, n_dig_correct):
257 | self.n_en_word += n_en_word
258 | self.n_en_correct += n_en_correct
259 | self.n_dig_word += n_dig_word
260 | self.n_dig_correct += n_dig_correct
261 |
262 | def print(self):
263 | print(f"English #word={self.n_en_word}, #correct={self.n_en_correct}\n"
264 | f"Digit #word={self.n_dig_word}, #correct={self.n_dig_correct}")
265 | print("-"*80)
266 |
267 |
268 |
269 | def count_english_digit(ref, hyp, wer_info):
270 |     patt_en = r"[a-zA-Z\.\-\']+"
271 | patt_dig = "[0-9]+"
272 | patt_cjk = re.compile(r'([\u4e00-\u9fff])')
273 | n_en_word = 0
274 | n_en_correct = 0
275 | n_dig_word = 0
276 | n_dig_correct = 0
277 | ali = wer_info.ali
278 | for i, token in enumerate(ref):
279 | if re.match(patt_en, token):
280 | n_en_word += 1
281 | for y in ali:
282 | if y[0] == i+1 and y[2] == ALIGN_CRT:
283 | j = y[1] - 1
284 | n_en_correct += 1
285 | break
286 | if re.match(patt_dig, token):
287 | n_dig_word += 1
288 | for y in ali:
289 | if y[0] == i+1 and y[2] == ALIGN_CRT:
290 | j = y[1] - 1
291 | n_dig_correct += 1
292 | break
293 | if not re.match(patt_cjk, token) and not re.match(patt_en, token) \
294 | and not re.match(patt_dig, token):
295 | print("[WiredChar]:", token)
296 | return n_en_word, n_en_correct, n_dig_word, n_dig_correct
297 |
298 |
299 |
300 | if __name__ == "__main__":
301 | args = parser.parse_args()
302 | print(args, flush=True)
303 | main(args)
304 |
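A worked example of compute_one_wer_info (toy ref/hyp, purely illustrative), assuming the fireredasr package is importable: with ref = ['a', 'b', 'c'] and hyp = ['a', 'x', 'c', 'd'], the backtrace finds one substitution and one insertion, so WER = 2 / 3 ≈ 66.67%:

from fireredasr.utils.wer import compute_one_wer_info

info = compute_one_wer_info("a b c".split(), "a x c d".split())
print(info)  # wer 66.67 ref  3 sub  1 del  0 ins  1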
--------------------------------------------------------------------------------
/pretrained_models/README.md:
--------------------------------------------------------------------------------
1 | Put pretrained models here.
2 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | cn2an>=0.5.23
2 | kaldiio>=2.18.0
3 | kaldi_native_fbank>=1.15
4 | numpy>=1.26.1
5 | peft>=0.13.2
6 | sentencepiece
7 | torch>=2.0.0
8 | transformers>=4.46.3
9 |
--------------------------------------------------------------------------------