├── .gitignore ├── LICENSE ├── README.md ├── audio ├── 0000c7286ebc7edef1c505b78d5ed1a3.wav ├── 0000e12e2402775c2d506d77b6dbb411.wav ├── 000af5671fdbaa3e55c5e2bd0bdf8cdd_hi.pcm ├── 000af5671fdbaa3e55c5e2bd0bdf8cdd_hi.wav ├── 000eae543947c70feb9401f82da03dcf_hi.wav └── gongqu-4.5_0000.wav ├── docs ├── model_convert.md └── model_train.md ├── model_convert ├── export_onnx.py ├── model │ ├── classifier.py │ ├── cmvn.py │ ├── fsmn.py │ ├── kws_model.py │ ├── loss.py │ ├── mdtc.py │ ├── subsampling.py │ └── tcn.py └── utils │ ├── checkpoint.py │ ├── cmvn.py │ ├── executor.py │ ├── file_utils.py │ ├── mask.py │ └── train_utils.py ├── onnxruntime ├── CMakeLists.txt ├── README.md ├── bin │ ├── CMakeLists.txt │ ├── device_test.cc │ ├── kws_main.cc │ ├── stream_kws_main.cc │ └── stream_kws_testing.cc ├── cmake │ ├── onnxruntime.cmake │ └── portaudio.cmake ├── frontend │ ├── CMakeLists.txt │ ├── fbank.h │ ├── feature_pipeline.cc │ ├── feature_pipeline.h │ ├── fft.cc │ ├── fft.h │ └── wav.h ├── kws │ ├── CMakeLists.txt │ ├── keyword_spotting.cc │ ├── keyword_spotting.h │ ├── maxpooling_keyword.txt │ ├── tokens.txt │ ├── utils.cpp │ └── utils.h └── utils │ ├── blocking_queue.h │ └── log.h └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | *.py[cod] 3 | *$py.class 4 | 5 | # Visual Studio Code files 6 | .vscode 7 | .vs 8 | 9 | # PyCharm files and clion files 10 | .idea 11 | 12 | # Eclipse Project settings 13 | *.*project 14 | .settings 15 | 16 | # Sublime Text settings 17 | *.sublime-workspace 18 | *.sublime-project 19 | 20 | # Editor temporaries 21 | *.swn 22 | *.swo 23 | *.swp 24 | *.swm 25 | *~ 26 | 27 | # IPython notebook checkpoints 28 | .ipynb_checkpoints 29 | 30 | # macOS dir files 31 | .DS_Store 32 | 33 | dict 34 | exp 35 | data 36 | raw_wav 37 | tensorboard 38 | **/*build* 39 | 40 | #third package 41 | fc_base/ 42 | 43 | # 训练过程中产生的中间文件 44 | models/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 ChenYang (cyang8050@163.com) 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 关键词检测 2 | 3 | 本工程主要基于[wekws](https://github.com/wenet-e2e/wekws/tree/main)进行构建,旨在搭建基于当下新训练框架(HF, modelscope)来实现更高效、快速的模型训练,微调及部署落地。 4 | 5 | 6 | 7 | # Features 8 | 9 | - 支持CTC和Max-Pooling方案的唤醒词模型推理。 10 | - 支持模型转换,Pytorch2ONNX, ONNX2ORT(端侧部署)。 11 | - 支持CPP onnxruntime流式推理。 12 | - 非流式测试:支持wav、pcm格式的音频输入。 13 | 14 | 15 | 16 | # Change Log 17 | 18 | - 2024/03/21 : 提供完整的CTC和Max-Pooling唤醒词方案的onnx模型cpp推理测试。 19 | - 2024/03/26: 提供模型转换工具,支持模型从Pytorch转换到onnx,再转换到ort用于端侧部署。支持CPP onnxruntime流式推理。 20 | - 2024/04/17: 提供模型训练方法。 21 | 22 | 23 | 24 | 25 | 26 | # 推理测试 27 | 28 | 以下示例 CPP ONNX推理测试. 29 | 30 | ```shell 31 | git clone https://github.com/chenyangMl/keyword-spot.git 32 | cd keyword-spot/onnxruntime/ 33 | mkdir build && cd build 34 | cmake .. 35 | cmake --build . --target kws_main 36 | 37 | #不同模型使用如下对应参数进行模型推理。 38 | ``` 39 | 40 | 41 | 42 | ## Max-Pooling方案模型 43 | 44 | ``` 45 | cd build/bin 46 | ./kws_main [solution_type, int] [num_bins, int] [batch_size, int] [model_path, str] [wave_path,str] 47 | 48 | #eg 49 | ./kws_main 0 40 1 keyword-spot-dstcn-maxpooling-wenwen/onnx/keyword-spot-dstcn-maxpooling-wenwen.ort ../../../audio/0000c7286ebc7edef1c505b78d5ed1a3.wav 50 | ``` 51 | 52 | 53 | 54 | ## CTC 方案模型 55 | 56 | ``` 57 | cd build/bin 58 | ./kws_main [solution_type, int] [num_bins, int] [batch_size, int] [model_path, str] [wave_path,str] [key_word,str] 59 | 60 | #eg 61 | ./kws_main 1 80 1 keyword-spot-fsmn-ctc-wenwen/onnx/keyword_spot_fsmn_ctc_wenwen.ort ../../../audio/0000c7286ebc7edef1c505b78d5ed1a3.wav 你好问问 62 | 63 | ./kws_main 1 80 1 keyword-spot-fsmn-ctc-wenwen/onnx/keyword_spot_fsmn_ctc_wenwen.ort ../../../audio/000af5671fdbaa3e55c5e2bd0bdf8cdd_hi.pcm 嗨小问 64 | ``` 65 | 66 | 更多详细信息参考: [onnx runtime](onnxruntime/README.md) 67 | 68 | PS: solution_type:{0:表示max-pooling方案, 1:表示ctc方案} 69 | 70 | ​ key_word: {你好问问,嗨小问} 71 | 72 | 73 | 74 | 如需要其他端测的推理测试,可参考wekws提供的[Android, RaspberryPI示例](https://github.com/wenet-e2e/wekws/tree/main/runtime)。 75 | 76 | 77 | 78 | # 模型训练 79 | 80 | 详细内容参考 [唤醒词自定义和模型训练](docs/model_train.md)。 81 | 82 | 83 | 84 | 85 | 86 | # 模型转换 87 | 88 | - pytorch2onnx: 将训练好的pytorch模型转换为onnx模型。onnx模型是常见的中间态模型,支持转换其他平台的模型(ncnn, tensorRT等各类推理引擎模型)。 89 | - onnx2ort: 将onnx模型转换成ort模型,用于端侧部署。 90 | 91 | 详细内容参考[唤醒词模型转换](docs/model_convert.md) 92 | 93 | 94 | 95 | 96 | 97 | ## 模型列表 98 | 99 | | 损失函数 | 模型名称 | 模型(Pytorch ckpt) | 模型(ONNX) | 端侧模型 | 100 | | ----------- | ---------------- | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | 101 | | Max-Pooling | DS_TCN(你好问问) | [DSTCN-MaxPooling, wekws训练](https://modelscope.cn/models/daydream-factory/keyword-spot-dstcn-maxpooling-wenwen/summary) | [ONNX](https://modelscope.cn/models/daydream-factory/keyword-spot-dstcn-maxpooling-wenwen/files) | [ORT](https://modelscope.cn/models/daydream-factory/keyword-spot-dstcn-maxpooling-wenwen/files) | 102 | | | | | | | 103 | | CTC | FSMN(你好问问) | [FSMN-CTC, wekws训练](https://modelscope.cn/models/daydream-factory/keyword-spot-fsmn-ctc-wenwen/summary) | [ONNX](https://modelscope.cn/models/daydream-factory/keyword-spot-fsmn-ctc-wenwen/files) | [ORT](https://modelscope.cn/models/daydream-factory/keyword-spot-fsmn-ctc-wenwen/files) | 104 | | | | | | | 105 | | CTC | FSMN(你好问问) | [FSMN-CTC, 
modelscope训练](https://modelscope.cn/models/daydream-factory/keyword-spot-fsmn-ctc-nihaowenwen/summary) | [ONNX](https://modelscope.cn/models/daydream-factory/keyword-spot-fsmn-ctc-nihaowenwen/files) | [ORT](https://modelscope.cn/models/daydream-factory/keyword-spot-fsmn-ctc-nihaowenwen/files) | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | ## 参考&鸣谢 114 | 115 | 本工程主要基于[wekws](https://github.com/wenet-e2e/wekws/tree/main)进行语音唤醒的模型训练,模型转换,推理,部署等流程构建,特此感谢。 116 | 117 | - [Sequence Modeling With CTC](https://distill.pub/2017/ctc/) : CTC的设计思路和原理。 118 | - [魔搭: 你好问问 唤醒词检测体验测试Demo](https://modelscope.cn/studios/thuduj12/KWS_Nihao_Xiaojing/summary) 119 | - https://modelscope.cn/models/iic/speech_charctc_kws_phone-wenwen/summary 120 | 121 | 122 | 123 | # [License](./LICENSE) 124 | 125 | MIT -------------------------------------------------------------------------------- /audio/0000c7286ebc7edef1c505b78d5ed1a3.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenyangMl/keyword-spot/0979503c7a6ee6a6b7191da524de00943c0bb52b/audio/0000c7286ebc7edef1c505b78d5ed1a3.wav -------------------------------------------------------------------------------- /audio/0000e12e2402775c2d506d77b6dbb411.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenyangMl/keyword-spot/0979503c7a6ee6a6b7191da524de00943c0bb52b/audio/0000e12e2402775c2d506d77b6dbb411.wav -------------------------------------------------------------------------------- /audio/000af5671fdbaa3e55c5e2bd0bdf8cdd_hi.pcm: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenyangMl/keyword-spot/0979503c7a6ee6a6b7191da524de00943c0bb52b/audio/000af5671fdbaa3e55c5e2bd0bdf8cdd_hi.pcm -------------------------------------------------------------------------------- /audio/000af5671fdbaa3e55c5e2bd0bdf8cdd_hi.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenyangMl/keyword-spot/0979503c7a6ee6a6b7191da524de00943c0bb52b/audio/000af5671fdbaa3e55c5e2bd0bdf8cdd_hi.wav -------------------------------------------------------------------------------- /audio/000eae543947c70feb9401f82da03dcf_hi.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenyangMl/keyword-spot/0979503c7a6ee6a6b7191da524de00943c0bb52b/audio/000eae543947c70feb9401f82da03dcf_hi.wav -------------------------------------------------------------------------------- /audio/gongqu-4.5_0000.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenyangMl/keyword-spot/0979503c7a6ee6a6b7191da524de00943c0bb52b/audio/gongqu-4.5_0000.wav -------------------------------------------------------------------------------- /docs/model_convert.md: -------------------------------------------------------------------------------- 1 | # 模型转换 2 | 3 | ## 运行环境 4 | 5 | ``` 6 | git clone https://github.com/chenyangMl/keyword-spot.git 7 | conda create -n kws python=3.9 8 | conda activate kws 9 | pip install -r requirements.txt 10 | 11 | cd keyword-spot 12 | mkdir models 13 | ``` 14 | 15 | 16 | 17 | 下面示例模型转换,示例均使用“你好问问”数据训练的模型。 18 | 19 | ## Max-pooling方案模型转换 20 | 21 | 1. 
先下载模型 22 | 23 | ``` 24 | cd models 25 | 26 | 下载方式1 27 | git clone https://www.modelscope.cn/daydream-factory/keyword-spot-dstcn-maxpooling-wenwen.git 28 | 29 | 下载方式2 30 | from modelscope import snapshot_download 31 | model_dir = snapshot_download('daydream-factory/keyword-spot-dstcn-maxpooling-wenwen') 32 | ``` 33 | 34 | 模型目录结构如下 >> tree keyword-spot-dstcn-maxpooling-wenwen 35 | 36 | ``` 37 | keyword-spot-dstcn-maxpooling-wenwen 38 | ├── avg_30.pt 39 | ├── configuration.json 40 | ├── config.yaml 41 | ├── global_cmvn 42 | ├── README.md 43 | ``` 44 | 45 | 46 | 47 | 2 模型转换。 48 | 49 | 确定下当前在主目录路径,比如示例目录/path/keyword-spotting/models/ 50 | 51 | ``` 52 | >> cd ../ 53 | >> pwd 54 | ``` 55 | 56 | 模型转换 1) pytorch to onnx 57 | 58 | ``` 59 | python model_convert/export_onnx.py \ 60 | --config models/keyword-spot-dstcn-maxpooling-wenwen/config.yaml \ 61 | --checkpoint models/keyword-spot-dstcn-maxpooling-wenwen/avg_30.pt \ 62 | --onnx_model models/keyword-spot-dstcn-maxpooling-wenwen/onnx/keyword-spot-dstcn-maxpooling-wenwen.onnx 63 | ``` 64 | 65 | 66 | 67 | 2) onnx2ort. 用于端侧设备部署. 68 | 69 | ``` 70 | python -m onnxruntime.tools.convert_onnx_models_to_ort models/keyword-spot-dstcn-maxpooling-wenwen/onnx/keyword-spot-dstcn-maxpooling-wenwen.onnx 71 | ``` 72 | 73 | 3) 输出模型结构,>> tree models/keyword-spot-dstcn-maxpooling-wenwen 74 | 75 | ``` 76 | models/keyword-spot-dstcn-maxpooling-wenwen 77 | ├── avg_30.pt 78 | ├── configuration.json 79 | ├── config.yaml 80 | ├── global_cmvn 81 | ├── onnx 82 | │ ├── keyword-spot-dstcn-maxpooling-wenwen.onnx #中间模型 83 | │ ├── keyword-spot-dstcn-maxpooling-wenwen.ort #用于端侧部署的ort模型 84 | │ ├── keyword-spot-dstcn-maxpooling-wenwen.required_operators.config 85 | │ ├── keyword-spot-dstcn-maxpooling-wenwen.required_operators.with_runtime_opt.config 86 | │ └── keyword-spot-dstcn-maxpooling-wenwen.with_runtime_opt.ort 87 | ├── README.md 88 | └── words.txt 89 | ``` 90 | 91 | 92 | 93 | ## CTC方案模型转换 94 | 95 | 1 下载模型 96 | 97 | ``` 98 | cd models 99 | 100 | 下载方式1 101 | git clone https://www.modelscope.cn/daydream-factory/keyword-spot-fsmn-ctc-wenwen.git 102 | 103 | 下载方式2 104 | from modelscope import snapshot_download 105 | model_dir = snapshot_download('daydream-factory/keyword-spot-fsmn-ctc-wenwen') 106 | ``` 107 | 108 | 模型目录查看>> tree keyword-spot-fsmn-ctc-wenwen/ 109 | 110 | ``` 111 | keyword-spot-fsmn-ctc-wenwen/ 112 | ├── avg_30.pt 113 | ├── configuration.json 114 | ├── config.yaml 115 | ├── global_cmvn.kaldi 116 | ├── lexicon.txt 117 | ├── README.md 118 | └── tokens.txt 119 | ``` 120 | 121 | 122 | 123 | 2 模型转换 124 | 125 | 模型转换 1) pytorch to onnx 126 | 127 | ``` 128 | cd path_to/keyword-spotting/ 129 | #在工程根目录运行 130 | python model_convert/export_onnx.py \ 131 | --config models/keyword-spot-fsmn-ctc-wenwen/config.yaml \ 132 | --checkpoint models/keyword-spot-fsmn-ctc-wenwen/avg_30.pt \ 133 | --onnx_model models/keyword-spot-fsmn-ctc-wenwen/onnx/keyword_spot_fsmn_ctc_wenwen.onnx 134 | ``` 135 | 136 | 137 | 138 | 2) onnx2ort. 用于端侧设备部署. 
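在执行转换前,可以先用 onnxruntime 对上一步导出的 onnx 模型做一次冒烟测试。下面是一个示意脚本(仅为草稿:模型路径按实际情况替换;cache 的形状在导出图中是静态的,可直接从输入信息读出,export_onnx.py 也同时把 cache_dim / cache_len 写进了模型元数据):

```python
import numpy as np
import onnxruntime as ort

# 假设:上一步导出的 FSMN-CTC onnx 模型(路径按实际情况替换)
sess = ort.InferenceSession(
    "models/keyword-spot-fsmn-ctc-wenwen/onnx/keyword_spot_fsmn_ctc_wenwen.onnx")

# cache 输入的形状在导出图中是固定的,直接按名字读取
cache_shape = {i.name: i.shape for i in sess.get_inputs()}["cache"]
cache = np.zeros(cache_shape, dtype=np.float32)

# FSMN-CTC 模型使用 80 维 fbank 特征(对应推理命令中的 num_bins=80),T 取 100 帧测试
feats = np.random.randn(1, 100, 80).astype(np.float32)
probs, r_cache = sess.run(None, {"input": feats, "cache": cache})
print("output:", probs.shape, "r_cache:", r_cache.shape)
```

确认模型可以正常加载和前向之后,再执行下面的命令完成转换: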
139 | 140 | ``` 141 | python -m onnxruntime.tools.convert_onnx_models_to_ort models/keyword-spot-fsmn-ctc-wenwen/onnx/keyword_spot_fsmn_ctc_wenwen.onnx 142 | ``` 143 | 144 | 3) 输出模型结构 145 | 146 | ``` 147 | models/keyword-spot-fsmn-ctc-wenwen 148 | ├── avg_30.pt 149 | ├── configuration.json 150 | ├── config.yaml 151 | ├── global_cmvn.kaldi 152 | ├── lexicon.txt 153 | ├── onnx 154 | │ ├── keyword_spot_fsmn_ctc_wenwen.onnx #中间模型 155 | │ ├── keyword_spot_fsmn_ctc_wenwen.ort #用于端侧部署的ort模型 156 | │ ├── keyword_spot_fsmn_ctc_wenwen.required_operators.config 157 | │ ├── keyword_spot_fsmn_ctc_wenwen.required_operators.with_runtime_opt.config 158 | │ └── keyword_spot_fsmn_ctc_wenwen.with_runtime_opt.ort 159 | ├── README.md 160 | └── tokens.txt 161 | ``` 162 | 163 | 164 | 165 | 166 | 167 | ## 模型可视化工具netron 168 | 169 | 使用模型可视化工具可以方便查看模型的整体结构,输入输出信息等,便于校验转换模型。 170 | 171 | ``` 172 | pip install netron 173 | ``` 174 | 175 | 查看模型使用命令 176 | 177 | ``` 178 | netron path_to_model 179 | ``` 180 | 181 | 打开提供的链接即可浏览器查看。 -------------------------------------------------------------------------------- /docs/model_train.md: -------------------------------------------------------------------------------- 1 | ## 语言唤醒——模型训练 2 | 3 | 工程示例提供了基于问问开源的唤醒词模型,支持{你好小问,嗨小问}两个唤醒词。 4 | 5 | 如果需要定制新的唤醒词可模型微调。 6 | 7 | - 先确定好唤醒词,可以是单唤醒词,也可以是多唤醒词。比如{你好小问,嗨小问} 8 | 9 | - 采集或模拟真实应用场景的唤醒词数据。  10 | - 进行模型微调得到新的唤醒词模型。 11 | 12 | 13 | 14 | ## [max-pooling方案训练](https://zhuanlan.zhihu.com/p/686365901) 15 | 16 | 17 | 18 | 19 | 20 | ## CTC方案训练 21 | 22 | ### 数据采集 23 | 24 | 训练数据需包含一定数量对应关键词(正例)和非关键词(负例)样本。 25 | 26 | **[达摩小云小云模型训练](https://www.modelscope.cn/models/iic/speech_charctc_kws_phone-xiaoyun/summary)建议关键词数据在25小时以上,混合负样本比例在1:2到1:10之间,实际性能与训练数据量、数据质量、场景匹配度、正负样本比例等诸多因素有关,需要具体分析和调整**。 27 | 28 | 29 | 30 | ``` 31 | 示例数据目录结构 32 | unittest/example_kws 33 | └── wav #wav格式的音频数据 34 | ├── 1330806238146100615.wav 35 | ├── 20200707_spk57db_storenoise52db_40cm_xiaoyun_sox_10.wav 36 | ├── 20200707_spk57db_storenoise52db_40cm_xiaoyun_sox_11.wav 37 | ├── 20200707_spk57db_storenoise52db_40cm_xiaoyun_sox_12.wav 38 | ├── cv_wav.scp #验证集 39 | ├── test_wav.scp #测试集 40 | ├── train_wav.scp #训练集 41 | ├── merge_trans.txt 42 | 43 | #示例文件内容 44 | $ cat cv_wav.scp 45 | kws_pos_example1 /home/admin/data/test/audios/kws_pos_example1.wav 46 | kws_pos_example2 /home/admin/data/test/audios/kws_pos_example2.wav 47 | ... 48 | kws_neg_example1 /home/admin/data/test/audios/kws_neg_example1.wav 49 | kws_neg_example2 /home/admin/data/test/audios/kws_neg_example2.wav 50 | ... 51 | 52 | $ cat merge_trans.txt 53 | kws_pos_example1 小 云 小 云 54 | kws_pos_example2 小 云 小 云 55 | ... 56 | kws_neg_example1 帮 我 导航 一下 回 临江 路 一百零八 还要 几个 小时 57 | kws_neg_example2 明天 的 天气 怎么样 58 | ... 
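#补充说明:*_wav.scp 每行为 "utt_id 音频文件路径";merge_trans.txt 每行为 "utt_id 空格分隔的转写"
#正例的转写即唤醒词文本,负例的转写为普通语音内容,正负样本都需要提供转写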
59 | ``` 60 | 61 | 62 | 63 | ### 模型训练 64 | 65 | 模型训练采用"basetrain + finetune"的模式,**basetrain过程使用大量内部移动端数据**,在此基础上,使用1万条设备端录制安静场景“小云小云”数据进行微调,得到最终面向业务的模型。由于采用了中文char全量token建模,并使用充分数据进行basetrain,本模型支持基本的唤醒词/命令词自定义功能,但具体性能无法评估。 66 | 67 | 68 | 69 | ``` 70 | #下载训练工具 71 | git clone https://www.modelscope.cn/iic/speech_charctc_kws_phone-xiaoyun.git 72 | cd unittest/ 73 | 74 | CUDA_VISIBLE_DEVICES=0 python example_kws.py 75 | ``` 76 | 77 | 78 | 79 | -------------------------------------------------------------------------------- /model_convert/export_onnx.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Binbin Zhang(binbzha@qq.com) 2 | # Copyright (c) 2024 Yang Chen (cyang8050@163.com) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import argparse 17 | 18 | import torch 19 | import yaml,json 20 | import os, sys 21 | sys.path.insert(0, os.getcwd()) 22 | 23 | import onnx 24 | import onnxruntime as ort 25 | 26 | from model_convert.model.kws_model import init_model 27 | from model_convert.utils.checkpoint import load_checkpoint 28 | 29 | 30 | def get_args(): 31 | parser = argparse.ArgumentParser(description='export to onnx model') 32 | parser.add_argument('--config', required=True, help='config file') 33 | parser.add_argument('--onnx_model', 34 | required=True, 35 | help='output onnx model') 36 | parser.add_argument('--checkpoint', required=True, help='checkpoint model') 37 | args = parser.parse_args() 38 | return args 39 | 40 | 41 | def main(): 42 | args = get_args() 43 | if args.config.endswith("json"): 44 | with open(args.config) as f: 45 | configs = json.load(f) 46 | else: 47 | with open(args.config, 'r') as fin: 48 | configs = yaml.load(fin, Loader=yaml.FullLoader) 49 | feature_dim = configs['model']['input_dim'] 50 | model = init_model(configs['model']) 51 | is_fsmn = configs['model']['backbone']['type'] == 'fsmn' 52 | num_layers = configs['model']['backbone']['num_layers'] 53 | if configs['training_config'].get('criterion', 'max_pooling') == 'ctc': 54 | # if we use ctc_loss, the logits need to be convert into probs 55 | model.forward = model.forward_softmax 56 | print(model) 57 | 58 | load_checkpoint(model, args.checkpoint) 59 | model.eval() 60 | # dummy_input: (batch, time, feature_dim) 61 | dummy_input = torch.randn(1, 100, feature_dim, dtype=torch.float) 62 | cache = torch.zeros(1, 63 | model.hdim, 64 | model.backbone.padding, 65 | dtype=torch.float) 66 | if is_fsmn: 67 | cache = cache.unsqueeze(-1).expand(-1, -1, -1, num_layers) 68 | torch.onnx.export(model, (dummy_input, cache), 69 | args.onnx_model, 70 | input_names=['input', 'cache'], 71 | output_names=['output', 'r_cache'], 72 | dynamic_axes={ 73 | 'input': { 74 | 1: 'T' 75 | }, 76 | 'output': { 77 | 1: 'T' 78 | }}, 79 | opset_version=13, 80 | verbose=False, 81 | do_constant_folding=True) 82 | 83 | # Add hidden dim and cache size 84 | onnx_model = onnx.load(args.onnx_model) 85 | meta = onnx_model.metadata_props.add() 86 | 
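# Persist the streaming-cache geometry as ONNX metadata so that a
# downstream runtime can recover the cache tensor shape from the model
# file alone, without re-reading config.yaml.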
meta.key, meta.value = 'cache_dim', str(model.hdim) 87 | meta = onnx_model.metadata_props.add() 88 | meta.key, meta.value = 'cache_len', str(model.backbone.padding) 89 | onnx.save(onnx_model, args.onnx_model) 90 | 91 | # Verify onnx precision 92 | torch_output = model(dummy_input, cache) 93 | ort_sess = ort.InferenceSession(args.onnx_model) 94 | onnx_output = ort_sess.run(None, { 95 | 'input': dummy_input.numpy(), 96 | 'cache': cache.numpy() 97 | }) 98 | 99 | if torch.allclose(torch_output[0], 100 | torch.tensor(onnx_output[0]), atol=1e-6) and \ 101 | torch.allclose(torch_output[1], 102 | torch.tensor(onnx_output[1]), atol=1e-6): 103 | print('Export to onnx succeed!') 104 | else: 105 | print('''Export to onnx succeed, but pytorch/onnx have different 106 | outputs when given the same input, please check!!!''') 107 | 108 | 109 | if __name__ == '__main__': 110 | main() 111 | -------------------------------------------------------------------------------- /model_convert/model/classifier.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Jingyong Hou 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | import torch.nn as nn 17 | 18 | 19 | class GlobalClassifier(nn.Module): 20 | """Add a global average pooling before the classifier""" 21 | def __init__(self, classifier: nn.Module): 22 | super(GlobalClassifier, self).__init__() 23 | self.classifier = classifier 24 | 25 | def forward(self, x: torch.Tensor): 26 | x = torch.mean(x, dim=1) 27 | return self.classifier(x) 28 | 29 | 30 | class LastClassifier(nn.Module): 31 | """Select last frame to do the classification""" 32 | def __init__(self, classifier: nn.Module): 33 | super(LastClassifier, self).__init__() 34 | self.classifier = classifier 35 | 36 | def forward(self, x: torch.Tensor): 37 | x = x[:, -1, :] 38 | return self.classifier(x) 39 | 40 | class ElementClassifier(nn.Module): 41 | """Classify all the frames in an utterance""" 42 | def __init__(self, classifier: nn.Module): 43 | super(ElementClassifier, self).__init__() 44 | self.classifier = classifier 45 | 46 | def forward(self, x: torch.Tensor): 47 | return self.classifier(x) 48 | 49 | class LinearClassifier(nn.Module): 50 | """ Wrapper of Linear """ 51 | def __init__(self, input_dim, output_dim): 52 | super().__init__() 53 | self.linear = torch.nn.Linear(input_dim, output_dim) 54 | self.quant = torch.quantization.QuantStub() 55 | self.dequant = torch.quantization.DeQuantStub() 56 | 57 | def forward(self, x: torch.Tensor): 58 | x = self.quant(x) 59 | x = self.linear(x) 60 | x = self.dequant(x) 61 | return x 62 | -------------------------------------------------------------------------------- /model_convert/model/cmvn.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) 2020 Binbin Zhang 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file 
except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import torch 17 | 18 | 19 | class GlobalCMVN(torch.nn.Module): 20 | def __init__(self, 21 | mean: torch.Tensor, 22 | istd: torch.Tensor, 23 | norm_var: bool = True): 24 | """ 25 | Args: 26 | mean (torch.Tensor): mean stats 27 | istd (torch.Tensor): inverse std, std which is 1.0 / std 28 | """ 29 | super().__init__() 30 | assert mean.shape == istd.shape 31 | self.norm_var = norm_var 32 | # The buffer can be accessed from this module using self.mean 33 | self.register_buffer("mean", mean) 34 | self.register_buffer("istd", istd) 35 | 36 | def forward(self, x: torch.Tensor): 37 | """ 38 | Args: 39 | x (torch.Tensor): (batch, max_len, feat_dim) 40 | 41 | Returns: 42 | (torch.Tensor): normalized feature 43 | """ 44 | x = x - self.mean 45 | if self.norm_var: 46 | x = x * self.istd 47 | return x 48 | -------------------------------------------------------------------------------- /model_convert/model/fsmn.py: -------------------------------------------------------------------------------- 1 | ''' 2 | FSMN implementation. 3 | 4 | Copyright: 2022-03-09 yueyue.nyy 5 | 2023 Jing Du 6 | ''' 7 | 8 | from typing import Tuple 9 | 10 | import numpy as np 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | 15 | 16 | def toKaldiMatrix(np_mat): 17 | np.set_printoptions(threshold=np.inf, linewidth=np.nan) 18 | out_str = str(np_mat) 19 | out_str = out_str.replace('[', '') 20 | out_str = out_str.replace(']', '') 21 | return '[ %s ]\n' % out_str 22 | 23 | 24 | def printTensor(torch_tensor): 25 | re_str = '' 26 | x = torch_tensor.detach().squeeze().numpy() 27 | re_str += toKaldiMatrix(x) 28 | # re_str += '\n' 29 | print(re_str) 30 | 31 | 32 | class LinearTransform(nn.Module): 33 | 34 | def __init__(self, input_dim, output_dim): 35 | super(LinearTransform, self).__init__() 36 | self.input_dim = input_dim 37 | self.output_dim = output_dim 38 | self.linear = nn.Linear(input_dim, output_dim, bias=False) 39 | self.quant = torch.quantization.QuantStub() 40 | self.dequant = torch.quantization.DeQuantStub() 41 | 42 | def forward(self, 43 | input: Tuple[torch.Tensor, torch.Tensor]): 44 | if isinstance(input, tuple): 45 | input, in_cache = input 46 | else: 47 | in_cache = torch.zeros(0, 0, 0, 0, dtype=torch.float) 48 | output = self.quant(input) 49 | output = self.linear(output) 50 | output = self.dequant(output) 51 | 52 | return (output, in_cache) 53 | 54 | def to_kaldi_net(self): 55 | re_str = '' 56 | re_str += ' %d %d\n' % (self.output_dim, 57 | self.input_dim) 58 | re_str += ' 1\n' 59 | 60 | linear_weights = self.state_dict()['linear.weight'] 61 | x = linear_weights.squeeze().numpy() 62 | re_str += toKaldiMatrix(x) 63 | # re_str += '\n' 64 | 65 | return re_str 66 | 67 | def to_pytorch_net(self, fread): 68 | linear_line = fread.readline() 69 | linear_split = linear_line.strip().split() 70 | assert len(linear_split) == 3 71 | assert linear_split[0] == '' 72 | self.output_dim = int(linear_split[1]) 73 | self.input_dim = int(linear_split[2]) 74 | 75 | learn_rate_line = fread.readline() 76 | assert 
learn_rate_line.find('LearnRateCoef') != -1 77 | 78 | self.linear.reset_parameters() 79 | 80 | # linear_weights = self.state_dict()['linear.weight'] 81 | # print(linear_weights.shape) 82 | new_weights = torch.zeros((self.output_dim, self.input_dim), 83 | dtype=torch.float32) 84 | for i in range(self.output_dim): 85 | line = fread.readline() 86 | splits = line.strip().strip('[]').strip().split() 87 | assert len(splits) == self.input_dim 88 | cols = torch.tensor([float(item) for item in splits], 89 | dtype=torch.float32) 90 | new_weights[i, :] = cols 91 | 92 | self.linear.weight.data = new_weights 93 | 94 | 95 | class AffineTransform(nn.Module): 96 | 97 | def __init__(self, input_dim, output_dim): 98 | super(AffineTransform, self).__init__() 99 | self.input_dim = input_dim 100 | self.output_dim = output_dim 101 | 102 | self.linear = nn.Linear(input_dim, output_dim) 103 | self.quant = torch.quantization.QuantStub() 104 | self.dequant = torch.quantization.DeQuantStub() 105 | 106 | def forward(self, 107 | input: Tuple[torch.Tensor, torch.Tensor]): 108 | if isinstance(input, tuple): 109 | input, in_cache = input 110 | else: 111 | in_cache = torch.zeros(0, 0, 0, 0, dtype=torch.float) 112 | output = self.quant(input) 113 | output = self.linear(output) 114 | output = self.dequant(output) 115 | 116 | return (output, in_cache) 117 | 118 | def to_kaldi_net(self): 119 | re_str = '' 120 | re_str += ' %d %d\n' % (self.output_dim, 121 | self.input_dim) 122 | re_str += ' 1 1 0\n' 123 | 124 | linear_weights = self.state_dict()['linear.weight'] 125 | x = linear_weights.squeeze().numpy() 126 | re_str += toKaldiMatrix(x) 127 | 128 | linear_bias = self.state_dict()['linear.bias'] 129 | x = linear_bias.squeeze().numpy() 130 | re_str += toKaldiMatrix(x) 131 | # re_str += '\n' 132 | 133 | return re_str 134 | 135 | def to_pytorch_net(self, fread): 136 | affine_line = fread.readline() 137 | affine_split = affine_line.strip().split() 138 | assert len(affine_split) == 3 139 | assert affine_split[0] == '' 140 | self.output_dim = int(affine_split[1]) 141 | self.input_dim = int(affine_split[2]) 142 | print('AffineTransform output/input dim: %d %d' % 143 | (self.output_dim, self.input_dim)) 144 | 145 | learn_rate_line = fread.readline() 146 | assert learn_rate_line.find('LearnRateCoef') != -1 147 | 148 | # linear_weights = self.state_dict()['linear.weight'] 149 | # print(linear_weights.shape) 150 | self.linear.reset_parameters() 151 | 152 | new_weights = torch.zeros((self.output_dim, self.input_dim), 153 | dtype=torch.float32) 154 | for i in range(self.output_dim): 155 | line = fread.readline() 156 | splits = line.strip().strip('[]').strip().split() 157 | assert len(splits) == self.input_dim 158 | cols = torch.tensor([float(item) for item in splits], 159 | dtype=torch.float32) 160 | new_weights[i, :] = cols 161 | 162 | self.linear.weight.data = new_weights 163 | 164 | # linear_bias = self.state_dict()['linear.bias'] 165 | # print(linear_bias.shape) 166 | bias_line = fread.readline() 167 | splits = bias_line.strip().strip('[]').strip().split() 168 | assert len(splits) == self.output_dim 169 | new_bias = torch.tensor([float(item) for item in splits], 170 | dtype=torch.float32) 171 | 172 | self.linear.bias.data = new_bias 173 | 174 | 175 | class FSMNBlock(nn.Module): 176 | 177 | def __init__( 178 | self, 179 | input_dim: int, 180 | output_dim: int, 181 | lorder=None, 182 | rorder=None, 183 | lstride=1, 184 | rstride=1, 185 | ): 186 | super(FSMNBlock, self).__init__() 187 | 188 | self.dim = input_dim 189 | 190 | if lorder is 
None: 191 | return 192 | 193 | self.lorder = lorder 194 | self.rorder = rorder 195 | self.lstride = lstride 196 | self.rstride = rstride 197 | 198 | self.conv_left = nn.Conv2d( 199 | self.dim, 200 | self.dim, [lorder, 1], 201 | dilation=[lstride, 1], 202 | groups=self.dim, 203 | bias=False) 204 | 205 | if rorder > 0: 206 | self.conv_right = nn.Conv2d( 207 | self.dim, 208 | self.dim, [rorder, 1], 209 | dilation=[rstride, 1], 210 | groups=self.dim, 211 | bias=False) 212 | else: 213 | self.conv_right = None 214 | 215 | self.quant = torch.quantization.QuantStub() 216 | self.dequant = torch.quantization.DeQuantStub() 217 | 218 | def forward(self, 219 | input: Tuple[torch.Tensor, torch.Tensor]): 220 | if isinstance(input, tuple): 221 | input, in_cache = input 222 | else : 223 | in_cache = torch.zeros(0, 0, 0, 0, dtype=torch.float) 224 | x = torch.unsqueeze(input, 1) 225 | x_per = x.permute(0, 3, 2, 1) 226 | 227 | if in_cache is None or len(in_cache) == 0 : 228 | x_pad = F.pad(x_per, [0, 0, (self.lorder - 1) * self.lstride 229 | + self.rorder * self.rstride, 0]) 230 | else: 231 | in_cache = in_cache.to(x_per.device) 232 | x_pad = torch.cat((in_cache, x_per), dim=2) 233 | in_cache = x_pad[:, :, -((self.lorder - 1) * self.lstride 234 | + self.rorder * self.rstride):, :] 235 | y_left = x_pad[:, :, :-self.rorder * self.rstride, :] 236 | y_left = self.quant(y_left) 237 | y_left = self.conv_left(y_left) 238 | y_left = self.dequant(y_left) 239 | out = x_pad[:, :, (self.lorder - 1) * self.lstride: -self.rorder * 240 | self.rstride, :] + y_left 241 | 242 | if self.conv_right is not None: 243 | # y_right = F.pad(x_per, [0, 0, 0, (self.rorder) * self.rstride]) 244 | y_right = x_pad[:, :, -( 245 | x_per.size(2) + self.rorder * self.rstride):, :] 246 | y_right = y_right[:, :, self.rstride:, :] 247 | y_right = self.quant(y_right) 248 | y_right = self.conv_right(y_right) 249 | y_right = self.dequant(y_right) 250 | out += y_right 251 | 252 | out_per = out.permute(0, 3, 2, 1) 253 | output = out_per.squeeze(1) 254 | 255 | return (output, in_cache) 256 | 257 | def to_kaldi_net(self): 258 | re_str = '' 259 | re_str += ' %d %d\n' % (self.dim, self.dim) 260 | re_str += ' %d %d %d ' \ 261 | ' %d %d 0\n' % ( 262 | 1, self.lorder, self.rorder, self.lstride, self.rstride) 263 | 264 | # print(self.conv_left.weight,self.conv_right.weight) 265 | lfiters = self.state_dict()['conv_left.weight'] 266 | x = np.flipud(lfiters.squeeze().numpy().T) 267 | re_str += toKaldiMatrix(x) 268 | 269 | if self.conv_right is not None: 270 | rfiters = self.state_dict()['conv_right.weight'] 271 | x = (rfiters.squeeze().numpy().T) 272 | re_str += toKaldiMatrix(x) 273 | # re_str += '\n' 274 | 275 | return re_str 276 | 277 | def to_pytorch_net(self, fread): 278 | fsmn_line = fread.readline() 279 | fsmn_split = fsmn_line.strip().split() 280 | assert len(fsmn_split) == 3 281 | assert fsmn_split[0] == '' 282 | self.dim = int(fsmn_split[1]) 283 | 284 | params_line = fread.readline() 285 | params_split = params_line.strip().strip('[]').strip().split() 286 | assert len(params_split) == 12 287 | assert params_split[0] == '' 288 | assert params_split[2] == '' 289 | self.lorder = int(params_split[3]) 290 | assert params_split[4] == '' 291 | self.rorder = int(params_split[5]) 292 | assert params_split[6] == '' 293 | self.lstride = int(params_split[7]) 294 | assert params_split[8] == '' 295 | self.rstride = int(params_split[9]) 296 | assert params_split[10] == '' 297 | 298 | # lfilters = self.state_dict()['conv_left.weight'] 299 | # print(lfilters.shape) 300 
| print('read conv_left weight') 301 | new_lfilters = torch.zeros((self.lorder, 1, self.dim, 1), 302 | dtype=torch.float32) 303 | for i in range(self.lorder): 304 | print('read conv_left weight -- %d' % i) 305 | line = fread.readline() 306 | splits = line.strip().strip('[]').strip().split() 307 | assert len(splits) == self.dim 308 | cols = torch.tensor([float(item) for item in splits], 309 | dtype=torch.float32) 310 | new_lfilters[self.lorder - 1 - i, 0, :, 0] = cols 311 | 312 | new_lfilters = torch.transpose(new_lfilters, 0, 2) 313 | # print(new_lfilters.shape) 314 | 315 | self.conv_left.reset_parameters() 316 | self.conv_left.weight.data = new_lfilters 317 | # print(self.conv_left.weight.shape) 318 | 319 | if self.rorder > 0: 320 | # rfilters = self.state_dict()['conv_right.weight'] 321 | # print(rfilters.shape) 322 | print('read conv_right weight') 323 | new_rfilters = torch.zeros((self.rorder, 1, self.dim, 1), 324 | dtype=torch.float32) 325 | line = fread.readline() 326 | for i in range(self.rorder): 327 | print('read conv_right weight -- %d' % i) 328 | line = fread.readline() 329 | splits = line.strip().strip('[]').strip().split() 330 | assert len(splits) == self.dim 331 | cols = torch.tensor([float(item) for item in splits], 332 | dtype=torch.float32) 333 | new_rfilters[i, 0, :, 0] = cols 334 | 335 | new_rfilters = torch.transpose(new_rfilters, 0, 2) 336 | # print(new_rfilters.shape) 337 | self.conv_right.reset_parameters() 338 | self.conv_right.weight.data = new_rfilters 339 | # print(self.conv_right.weight.shape) 340 | 341 | 342 | class RectifiedLinear(nn.Module): 343 | 344 | def __init__(self, input_dim, output_dim): 345 | super(RectifiedLinear, self).__init__() 346 | self.dim = input_dim 347 | self.relu = nn.ReLU() 348 | self.dropout = nn.Dropout(0.1) 349 | 350 | def forward(self, 351 | input: Tuple[torch.Tensor, torch.Tensor]): 352 | if isinstance(input, tuple): 353 | input, in_cache = input 354 | else : 355 | in_cache = torch.zeros(0, 0, 0, 0, dtype=torch.float) 356 | out = self.relu(input) 357 | # out = self.dropout(out) 358 | return (out, in_cache) 359 | 360 | def to_kaldi_net(self): 361 | re_str = '' 362 | re_str += ' %d %d\n' % (self.dim, self.dim) 363 | # re_str += '\n' 364 | return re_str 365 | 366 | # re_str = '' 367 | # re_str += ' %d %d\n' % (self.dim, self.dim) 368 | # re_str += ' 0 0\n' 369 | # re_str += toKaldiMatrix(np.ones((self.dim), dtype = 'int32')) 370 | # re_str += toKaldiMatrix(np.zeros((self.dim), dtype = 'int32')) 371 | # re_str += '\n' 372 | # return re_str 373 | 374 | def to_pytorch_net(self, fread): 375 | line = fread.readline() 376 | splits = line.strip().split() 377 | assert len(splits) == 3 378 | assert splits[0] == '' 379 | assert int(splits[1]) == int(splits[2]) 380 | assert int(splits[1]) == self.dim 381 | self.dim = int(splits[1]) 382 | 383 | 384 | def _build_repeats( 385 | fsmn_layers: int, 386 | linear_dim: int, 387 | proj_dim: int, 388 | lorder: int, 389 | rorder: int, 390 | lstride=1, 391 | rstride=1, 392 | ): 393 | repeats = [ 394 | nn.Sequential( 395 | LinearTransform(linear_dim, proj_dim), 396 | FSMNBlock(proj_dim, proj_dim, lorder, rorder, 1, 1), 397 | AffineTransform(proj_dim, linear_dim), 398 | RectifiedLinear(linear_dim, linear_dim)) 399 | for i in range(fsmn_layers) 400 | ] 401 | 402 | return nn.Sequential(*repeats) 403 | 404 | 405 | class FSMN(nn.Module): 406 | 407 | def __init__( 408 | self, 409 | input_dim: int, 410 | input_affine_dim: int, 411 | fsmn_layers: int, 412 | linear_dim: int, 413 | proj_dim: int, 414 | lorder: int, 415 | 
rorder: int, 416 | lstride: int, 417 | rstride: int, 418 | output_affine_dim: int, 419 | output_dim: int, 420 | ): 421 | """ 422 | Args: 423 | input_dim: input dimension 424 | input_affine_dim: input affine layer dimension 425 | fsmn_layers: no. of fsmn units 426 | linear_dim: fsmn input dimension 427 | proj_dim: fsmn projection dimension 428 | lorder: fsmn left order 429 | rorder: fsmn right order 430 | lstride: fsmn left stride 431 | rstride: fsmn right stride 432 | output_affine_dim: output affine layer dimension 433 | output_dim: output dimension 434 | """ 435 | super(FSMN, self).__init__() 436 | 437 | self.input_dim = input_dim 438 | self.input_affine_dim = input_affine_dim 439 | self.fsmn_layers = fsmn_layers 440 | self.linear_dim = linear_dim 441 | self.proj_dim = proj_dim 442 | self.lorder = lorder 443 | self.rorder = rorder 444 | self.lstride = lstride 445 | self.rstride = rstride 446 | self.output_affine_dim = output_affine_dim 447 | self.output_dim = output_dim 448 | 449 | self.padding = (self.lorder - 1) * self.lstride \ 450 | + self.rorder * self.rstride 451 | 452 | self.in_linear1 = AffineTransform(input_dim, input_affine_dim) 453 | self.in_linear2 = AffineTransform(input_affine_dim, linear_dim) 454 | self.relu = RectifiedLinear(linear_dim, linear_dim) 455 | 456 | self.fsmn = _build_repeats(fsmn_layers, linear_dim, proj_dim, lorder, 457 | rorder, lstride, rstride) 458 | 459 | self.out_linear1 = AffineTransform(linear_dim, output_affine_dim) 460 | self.out_linear2 = AffineTransform(output_affine_dim, output_dim) 461 | # self.softmax = nn.Softmax(dim = -1) 462 | 463 | def fuse_modules(self): 464 | pass 465 | 466 | def forward( 467 | self, 468 | input: torch.Tensor, 469 | in_cache: torch.Tensor = torch.zeros(0, 0, 0, dtype=torch.float) 470 | ) -> Tuple[torch.Tensor, torch.Tensor]: 471 | """ 472 | Args: 473 | input (torch.Tensor): Input tensor (B, T, D) 474 | in_cache(torch.Tensor): (B, D, C), C is the accumulated cache size 475 | """ 476 | 477 | if in_cache is None or len(in_cache) == 0 : 478 | in_cache = [torch.zeros(0, 0, 0, 0, dtype=torch.float) 479 | for _ in range(len(self.fsmn))] 480 | else: 481 | in_cache = [in_cache[:, :, :, i: i + 1] for i in range(in_cache.size(-1))] 482 | input = (input, in_cache) 483 | x1 = self.in_linear1(input) 484 | x2 = self.in_linear2(x1) 485 | x3 = self.relu(x2) 486 | # x4 = self.fsmn(x3) 487 | x4, _ = x3 488 | for layer, module in enumerate(self.fsmn): 489 | x4, in_cache[layer] = module((x4, in_cache[layer])) 490 | x5 = self.out_linear1(x4) 491 | x6 = self.out_linear2(x5) 492 | # x7 = self.softmax(x6) 493 | x7, _ = x6 494 | # return x7, None 495 | return x7, torch.cat(in_cache, dim=-1) 496 | 497 | def to_kaldi_net(self): 498 | re_str = '' 499 | re_str += '\n' 500 | re_str += self.in_linear1.to_kaldi_net() 501 | re_str += self.in_linear2.to_kaldi_net() 502 | re_str += self.relu.to_kaldi_net() 503 | 504 | for fsmn in self.fsmn: 505 | re_str += fsmn[0].to_kaldi_net() 506 | re_str += fsmn[1].to_kaldi_net() 507 | re_str += fsmn[2].to_kaldi_net() 508 | re_str += fsmn[3].to_kaldi_net() 509 | 510 | re_str += self.out_linear1.to_kaldi_net() 511 | re_str += self.out_linear2.to_kaldi_net() 512 | re_str += ' %d %d\n' % (self.output_dim, self.output_dim) 513 | # re_str += '\n' 514 | re_str += '\n' 515 | 516 | return re_str 517 | 518 | def to_pytorch_net(self, kaldi_file): 519 | with open(kaldi_file, 'r', encoding='utf8') as fread: 520 | fread = open(kaldi_file, 'r') 521 | nnet_start_line = fread.readline() 522 | assert nnet_start_line.strip() == '' 523 | 
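            # Read the components back in exactly the order to_kaldi_net()
            # writes them: the two input affines, the ReLU, then for every
            # FSMN layer a <lineartransform>, <fsmn>, <affinetransform>,
            # <relu> block, and finally the two output affines plus the
            # trailing <softmax> line.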
524 | self.in_linear1.to_pytorch_net(fread) 525 | self.in_linear2.to_pytorch_net(fread) 526 | self.relu.to_pytorch_net(fread) 527 | 528 | for fsmn in self.fsmn: 529 | fsmn[0].to_pytorch_net(fread) 530 | fsmn[1].to_pytorch_net(fread) 531 | fsmn[2].to_pytorch_net(fread) 532 | fsmn[3].to_pytorch_net(fread) 533 | 534 | self.out_linear1.to_pytorch_net(fread) 535 | self.out_linear2.to_pytorch_net(fread) 536 | 537 | softmax_line = fread.readline() 538 | softmax_split = softmax_line.strip().split() 539 | assert softmax_split[0].strip() == '' 540 | assert int(softmax_split[1]) == self.output_dim 541 | assert int(softmax_split[2]) == self.output_dim 542 | # '\n' 543 | 544 | nnet_end_line = fread.readline() 545 | assert nnet_end_line.strip() == '' 546 | fread.close() 547 | 548 | 549 | if __name__ == '__main__': 550 | fsmn = FSMN(400, 140, 4, 250, 128, 10, 2, 1, 1, 140, 2599) 551 | print(fsmn) 552 | 553 | num_params = sum(p.numel() for p in fsmn.parameters()) 554 | print('the number of model params: {}'.format(num_params)) 555 | x = torch.zeros(128, 200, 400) # batch-size * time * dim 556 | y, _ = fsmn(x) # batch-size * time * dim 557 | print('input shape: {}'.format(x.shape)) 558 | print('output shape: {}'.format(y.shape)) 559 | 560 | print(fsmn.to_kaldi_net()) 561 | -------------------------------------------------------------------------------- /model_convert/model/kws_model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Binbin Zhang 2 | # 2023 Jing Du 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import sys 17 | from typing import Optional, Tuple 18 | 19 | import torch 20 | import torch.nn as nn 21 | 22 | from model_convert.model.cmvn import GlobalCMVN 23 | from model_convert.model.classifier import (GlobalClassifier, LastClassifier, 24 | LinearClassifier) 25 | from model_convert.model.subsampling import (LinearSubsampling1, Conv1dSubsampling1, 26 | NoSubsampling) 27 | from model_convert.model.tcn import TCN, CnnBlock, DsCnnBlock 28 | from model_convert.model.mdtc import MDTC 29 | from model_convert.utils.cmvn import load_cmvn, load_kaldi_cmvn 30 | from model_convert.model.fsmn import FSMN 31 | 32 | 33 | class KWSModel(nn.Module): 34 | """Our model consists of four parts: 35 | 1. global_cmvn: Optional, (idim, idim) 36 | 2. preprocessing: feature dimention projection, (idim, hdim) 37 | 3. backbone: backbone of the whole network, (hdim, hdim) 38 | 4. classifier: output layer or classifier of KWS model, (hdim, odim) 39 | 5. 
activation: 40 | nn.Sigmoid for wakeup word 41 | nn.Identity for speech command dataset 42 | """ 43 | def __init__( 44 | self, 45 | idim: int, 46 | odim: int, 47 | hdim: int, 48 | global_cmvn: Optional[nn.Module], 49 | preprocessing: Optional[nn.Module], 50 | backbone: nn.Module, 51 | classifier: nn.Module, 52 | activation: nn.Module, 53 | ): 54 | super().__init__() 55 | self.idim = idim 56 | self.odim = odim 57 | self.hdim = hdim 58 | self.global_cmvn = global_cmvn 59 | self.preprocessing = preprocessing 60 | self.backbone = backbone 61 | self.classifier = classifier 62 | self.activation = activation 63 | 64 | def forward( 65 | self, 66 | x: torch.Tensor, 67 | in_cache: torch.Tensor = torch.zeros(0, 0, 0, dtype=torch.float) 68 | ) -> Tuple[torch.Tensor, torch.Tensor]: 69 | if self.global_cmvn is not None: 70 | x = self.global_cmvn(x) 71 | x = self.preprocessing(x) 72 | x, out_cache = self.backbone(x, in_cache) 73 | x = self.classifier(x) 74 | x = self.activation(x) 75 | return x, out_cache 76 | 77 | def forward_softmax(self, 78 | x: torch.Tensor, 79 | in_cache: torch.Tensor = torch.zeros( 80 | 0, 0, 0, dtype=torch.float) 81 | ) -> Tuple[torch.Tensor, torch.Tensor]: 82 | if self.global_cmvn is not None: 83 | x = self.global_cmvn(x) 84 | x = self.preprocessing(x) 85 | x, out_cache = self.backbone(x, in_cache) 86 | x = self.classifier(x) 87 | x = self.activation(x) 88 | x = x.softmax(2) 89 | return x, out_cache 90 | 91 | def fuse_modules(self): 92 | self.preprocessing.fuse_modules() 93 | self.backbone.fuse_modules() 94 | 95 | 96 | def init_model(configs): 97 | cmvn = configs.get('cmvn', {}) 98 | if 'cmvn_file' in cmvn and cmvn['cmvn_file'] is not None: 99 | if "kaldi" in cmvn['cmvn_file']: 100 | mean, istd = load_kaldi_cmvn(cmvn['cmvn_file']) 101 | else: 102 | mean, istd = load_cmvn(cmvn['cmvn_file']) 103 | global_cmvn = GlobalCMVN( 104 | torch.from_numpy(mean).float(), 105 | torch.from_numpy(istd).float(), 106 | cmvn['norm_var'], 107 | ) 108 | else: 109 | global_cmvn = None 110 | 111 | input_dim = configs['input_dim'] 112 | output_dim = configs['output_dim'] 113 | hidden_dim = configs['hidden_dim'] 114 | 115 | prep_type = configs['preprocessing']['type'] 116 | if prep_type == 'linear': 117 | preprocessing = LinearSubsampling1(input_dim, hidden_dim) 118 | elif prep_type == 'cnn1d_s1': 119 | preprocessing = Conv1dSubsampling1(input_dim, hidden_dim) 120 | elif prep_type == 'none': 121 | preprocessing = NoSubsampling() 122 | else: 123 | print('Unknown preprocessing type {}'.format(prep_type)) 124 | sys.exit(1) 125 | 126 | backbone_type = configs['backbone']['type'] 127 | if backbone_type == 'gru': 128 | num_layers = configs['backbone']['num_layers'] 129 | backbone = torch.nn.GRU(hidden_dim, 130 | hidden_dim, 131 | num_layers=num_layers, 132 | batch_first=True) 133 | elif backbone_type == 'tcn': 134 | # Depthwise Separable 135 | num_layers = configs['backbone']['num_layers'] 136 | ds = configs['backbone'].get('ds', False) 137 | if ds: 138 | block_class = DsCnnBlock 139 | else: 140 | block_class = CnnBlock 141 | kernel_size = configs['backbone'].get('kernel_size', 8) 142 | dropout = configs['backbone'].get('drouput', 0.1) 143 | backbone = TCN(num_layers, hidden_dim, kernel_size, dropout, 144 | block_class) 145 | elif backbone_type == 'mdtc': 146 | stack_size = configs['backbone']['stack_size'] 147 | num_stack = configs['backbone']['num_stack'] 148 | kernel_size = configs['backbone']['kernel_size'] 149 | hidden_dim = configs['backbone']['hidden_dim'] 150 | causal = configs['backbone']['causal'] 
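        # Note: for MDTC the model-level hidden_dim read earlier is
        # overridden by the backbone's own hidden_dim, so the classifier
        # built below matches the MDTC output width.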
151 | backbone = MDTC(num_stack, 152 | stack_size, 153 | hidden_dim, 154 | hidden_dim, 155 | kernel_size, 156 | causal=causal) 157 | elif backbone_type == 'fsmn': 158 | input_affine_dim = configs['backbone']['input_affine_dim'] 159 | num_layers = configs['backbone']['num_layers'] 160 | linear_dim = configs['backbone']['linear_dim'] 161 | proj_dim = configs['backbone']['proj_dim'] 162 | left_order = configs['backbone']['left_order'] 163 | right_order = configs['backbone']['right_order'] 164 | left_stride = configs['backbone']['left_stride'] 165 | right_stride = configs['backbone']['right_stride'] 166 | output_affine_dim = configs['backbone']['output_affine_dim'] 167 | backbone = FSMN(input_dim, input_affine_dim, num_layers, linear_dim, 168 | proj_dim, left_order, right_order, left_stride, 169 | right_stride, output_affine_dim, output_dim) 170 | 171 | else: 172 | print('Unknown body type {}'.format(backbone_type)) 173 | sys.exit(1) 174 | if 'classifier' in configs: 175 | # For speech command dataset, we use 2 FC layer as classifier, 176 | # we add dropout after first FC layer to prevent overfitting 177 | classifier_type = configs['classifier']['type'] 178 | dropout = configs['classifier']['dropout'] 179 | 180 | classifier_base = nn.Sequential(nn.Linear(hidden_dim, 64), nn.ReLU(), 181 | nn.Dropout(dropout), 182 | nn.Linear(64, output_dim)) 183 | if classifier_type == 'global': 184 | # global means we add a global average pooling before classifier 185 | classifier = GlobalClassifier(classifier_base) 186 | elif classifier_type == 'last': 187 | # last means we use last frame to do backpropagation, so the model 188 | # can be infered streamingly 189 | classifier = LastClassifier(classifier_base) 190 | elif classifier_type == 'identity': 191 | classifier = nn.Identity() 192 | else: 193 | print('Unknown classifier type {}'.format(classifier_type)) 194 | sys.exit(1) 195 | activation = nn.Identity() 196 | else: 197 | classifier = LinearClassifier(hidden_dim, output_dim) 198 | activation = nn.Sigmoid() 199 | 200 | # Here we add a possible "activation_type", 201 | # one can choose to use other activation function. 202 | # We use nn.Identity just for CTC loss 203 | if "activation" in configs: 204 | activation_type = configs["activation"]["type"] 205 | if activation_type == 'identity': 206 | activation = nn.Identity() 207 | else: 208 | print('Unknown activation type {}'.format(activation_type)) 209 | sys.exit(1) 210 | 211 | kws_model = KWSModel(input_dim, output_dim, hidden_dim, global_cmvn, 212 | preprocessing, backbone, classifier, activation) 213 | return kws_model 214 | -------------------------------------------------------------------------------- /model_convert/model/loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Binbin Zhang 2 | # 2023 Jing Du 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
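# This module implements the training criteria for KWS: max-pooling loss
# for wake-word models, CTC loss (with an utterance-level accuracy computed
# via CTC prefix beam search plus edit distance), and frame-wise cross
# entropy, all dispatched through criterion().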
15 | 16 | import torch 17 | import math 18 | import sys 19 | import torch.nn.functional as F 20 | from collections import defaultdict 21 | from typing import List, Tuple 22 | 23 | from model_convert.utils.mask import padding_mask 24 | 25 | 26 | def max_pooling_loss(logits: torch.Tensor, 27 | target: torch.Tensor, 28 | lengths: torch.Tensor, 29 | min_duration: int = 0): 30 | ''' Max-pooling loss 31 | For keyword, select the frame with the highest posterior. 32 | The keyword is triggered when any of the frames is triggered. 33 | For none keyword, select the hardest frame, namely the frame 34 | with lowest filler posterior(highest keyword posterior). 35 | the keyword is not triggered when all frames are not triggered. 36 | 37 | Attributes: 38 | logits: (B, T, D), D is the number of keywords 39 | target: (B) 40 | lengths: (B) 41 | min_duration: min duration of the keyword 42 | Returns: 43 | (float): loss of current batch 44 | (float): accuracy of current batch 45 | ''' 46 | mask = padding_mask(lengths) 47 | num_utts = logits.size(0) 48 | num_keywords = logits.size(2) 49 | 50 | target = target.cpu() 51 | loss = 0.0 52 | for i in range(num_utts): 53 | for j in range(num_keywords): 54 | # Add entropy loss CE = -(t * log(p) + (1 - t) * log(1 - p)) 55 | if target[i] == j: 56 | # For the keyword, do max-polling 57 | prob = logits[i, :, j] 58 | m = mask[i].clone().detach() 59 | m[:min_duration] = True 60 | prob = prob.masked_fill(m, 0.0) 61 | prob = torch.clamp(prob, 1e-8, 1.0) 62 | max_prob = prob.max() 63 | loss += -torch.log(max_prob) 64 | else: 65 | # For other keywords or filler, do min-polling 66 | prob = 1 - logits[i, :, j] 67 | prob = prob.masked_fill(mask[i], 1.0) 68 | prob = torch.clamp(prob, 1e-8, 1.0) 69 | min_prob = prob.min() 70 | loss += -torch.log(min_prob) 71 | loss = loss / num_utts 72 | 73 | # Compute accuracy of current batch 74 | mask = mask.unsqueeze(-1) 75 | logits = logits.masked_fill(mask, 0.0) 76 | max_logits, index = logits.max(1) 77 | num_correct = 0 78 | for i in range(num_utts): 79 | max_p, idx = max_logits[i].max(0) 80 | # Predict correct as the i'th keyword 81 | if max_p > 0.5 and idx == target[i]: 82 | num_correct += 1 83 | # Predict correct as the filler, filler id < 0 84 | if max_p < 0.5 and target[i] < 0: 85 | num_correct += 1 86 | acc = num_correct / num_utts 87 | # acc = 0.0 88 | return loss, acc 89 | 90 | 91 | def acc_frame( 92 | logits: torch.Tensor, 93 | target: torch.Tensor, 94 | ): 95 | if logits is None: 96 | return 0 97 | pred = logits.max(1, keepdim=True)[1] 98 | correct = pred.eq(target.long().view_as(pred)).sum().item() 99 | return correct * 100.0 / logits.size(0) 100 | 101 | def acc_utterance(logits: torch.Tensor, target: torch.Tensor, 102 | logits_length: torch.Tensor, target_length: torch.Tensor): 103 | if logits is None: 104 | return 0 105 | 106 | logits = logits.softmax(2) # (1, maxlen, vocab_size) 107 | logits = logits.cpu() 108 | target = target.cpu() 109 | 110 | total_word = 0 111 | total_ins = 0 112 | total_sub = 0 113 | total_del = 0 114 | calculator = Calculator() 115 | for i in range(logits.size(0)): 116 | score = logits[i][:logits_length[i]] 117 | hyps = ctc_prefix_beam_search(score, logits_length[i], None, 3, 5) 118 | lab = [str(item) for item in target[i][:target_length[i]].tolist()] 119 | rec = [] 120 | if len(hyps) > 0: 121 | rec = [str(item) for item in hyps[0][0]] 122 | result = calculator.calculate(lab, rec) 123 | # print(f'result:{result}') 124 | if result['all'] != 0: 125 | total_word += result['all'] 126 | total_ins += 
result['ins'] 127 | total_sub += result['sub'] 128 | total_del += result['del'] 129 | 130 | return float(total_word - total_ins - total_sub 131 | - total_del) * 100.0 / total_word 132 | 133 | def ctc_loss(logits: torch.Tensor, 134 | target: torch.Tensor, 135 | logits_lengths: torch.Tensor, 136 | target_lengths: torch.Tensor, 137 | need_acc: bool = False): 138 | """ CTC Loss 139 | Args: 140 | logits: (B, D), D is the number of keywords plus 1 (non-keyword) 141 | target: (B) 142 | logits_lengths: (B) 143 | target_lengths: (B) 144 | Returns: 145 | (float): loss of current batch 146 | """ 147 | 148 | acc = 0.0 149 | if need_acc: 150 | acc = acc_utterance(logits, target, logits_lengths, target_lengths) 151 | 152 | # logits: (B, L, D) -> (L, B, D) 153 | logits = logits.transpose(0, 1) 154 | logits = logits.log_softmax(2) 155 | loss = F.ctc_loss( 156 | logits, target, logits_lengths, target_lengths, reduction='sum') 157 | loss = loss / logits.size(1) # batch mean 158 | 159 | return loss, acc 160 | 161 | def cross_entropy(logits: torch.Tensor, target: torch.Tensor): 162 | """ Cross Entropy Loss 163 | Attributes: 164 | logits: (B, D), D is the number of keywords plus 1 (non-keyword) 165 | target: (B) 166 | lengths: (B) 167 | min_duration: min duration of the keyword 168 | Returns: 169 | (float): loss of current batch 170 | (float): accuracy of current batch 171 | """ 172 | loss = F.cross_entropy(logits, target.type(torch.int64)) 173 | acc = acc_frame(logits, target) 174 | return loss, acc 175 | 176 | 177 | def criterion(type: str, 178 | logits: torch.Tensor, 179 | target: torch.Tensor, 180 | lengths: torch.Tensor, 181 | target_lengths: torch.Tensor = None, 182 | min_duration: int = 0, 183 | validation: bool = False, ): 184 | if type == 'ce': 185 | loss, acc = cross_entropy(logits, target) 186 | return loss, acc 187 | elif type == 'max_pooling': 188 | loss, acc = max_pooling_loss(logits, target, lengths, min_duration) 189 | return loss, acc 190 | elif type == 'ctc': 191 | loss, acc = ctc_loss( 192 | logits, target, lengths, target_lengths, validation) 193 | return loss, acc 194 | else: 195 | exit(1) 196 | 197 | def ctc_prefix_beam_search( 198 | logits: torch.Tensor, 199 | logits_lengths: torch.Tensor, 200 | keywords_tokenset: set = None, 201 | score_beam_size: int = 3, 202 | path_beam_size: int = 20, 203 | ) -> Tuple[List[List[int]], torch.Tensor]: 204 | """ CTC prefix beam search inner implementation 205 | 206 | Args: 207 | logits (torch.Tensor): (1, max_len, vocab_size) 208 | logits_lengths (torch.Tensor): (1, ) 209 | keywords_tokenset (set): token set for filtering score 210 | score_beam_size (int): beam size for score 211 | path_beam_size (int): beam size for path 212 | 213 | Returns: 214 | List[List[int]]: nbest results 215 | """ 216 | maxlen = logits.size(0) 217 | # ctc_probs = logits.softmax(1) # (1, maxlen, vocab_size) 218 | ctc_probs = logits 219 | 220 | cur_hyps = [(tuple(), (1.0, 0.0, []))] 221 | 222 | # 2. 
222 |     # 2. CTC beam search step by step
223 |     for t in range(0, maxlen):
224 |         probs = ctc_probs[t]  # (vocab_size,)
225 |         # key: prefix; value: (pb, pnb, nodes), default value (0.0, 0.0, [])
226 |         next_hyps = defaultdict(lambda: (0.0, 0.0, []))
227 | 
228 |         # 2.1 First beam prune: select topk best
229 |         top_k_probs, top_k_index = probs.topk(
230 |             score_beam_size)  # (score_beam_size,)
231 | 
232 |         # filter out token scores that are too small
233 |         filter_probs = []
234 |         filter_index = []
235 |         for prob, idx in zip(top_k_probs.tolist(), top_k_index.tolist()):
236 |             if keywords_tokenset is not None:
237 |                 if prob > 0.05 and idx in keywords_tokenset:
238 |                     filter_probs.append(prob)
239 |                     filter_index.append(idx)
240 |             else:
241 |                 if prob > 0.05:
242 |                     filter_probs.append(prob)
243 |                     filter_index.append(idx)
244 | 
245 |         if len(filter_index) == 0:
246 |             continue
247 | 
248 |         for s in filter_index:
249 |             ps = probs[s].item()
250 | 
251 |             for prefix, (pb, pnb, cur_nodes) in cur_hyps:
252 |                 last = prefix[-1] if len(prefix) > 0 else None
253 |                 if s == 0:  # blank
254 |                     n_pb, n_pnb, nodes = next_hyps[prefix]
255 |                     n_pb = n_pb + pb * ps + pnb * ps
256 |                     nodes = cur_nodes.copy()
257 |                     next_hyps[prefix] = (n_pb, n_pnb, nodes)
258 |                 elif s == last:
259 |                     if not math.isclose(pnb, 0.0, abs_tol=0.000001):
260 |                         # Update *ss -> *s;
261 |                         n_pb, n_pnb, nodes = next_hyps[prefix]
262 |                         n_pnb = n_pnb + pnb * ps
263 |                         nodes = cur_nodes.copy()
264 |                         if ps > nodes[-1]['prob']:  # update frame and prob
265 |                             nodes[-1]['prob'] = ps
266 |                             nodes[-1]['frame'] = t
267 |                         next_hyps[prefix] = (n_pb, n_pnb, nodes)
268 | 
269 |                     if not math.isclose(pb, 0.0, abs_tol=0.000001):
270 |                         # Update *s-s -> *ss, - is for blank
271 |                         n_prefix = prefix + (s, )
272 |                         n_pb, n_pnb, nodes = next_hyps[n_prefix]
273 |                         n_pnb = n_pnb + pb * ps
274 |                         nodes = cur_nodes.copy()
275 |                         nodes.append(dict(token=s, frame=t,
276 |                                           prob=ps))  # to record token prob
277 |                         next_hyps[n_prefix] = (n_pb, n_pnb, nodes)
278 |                 else:
279 |                     n_prefix = prefix + (s, )
280 |                     n_pb, n_pnb, nodes = next_hyps[n_prefix]
281 |                     if nodes:
282 |                         if ps > nodes[-1]['prob']:  # update frame and prob
283 |                             # nodes[-1]['prob'] = ps
284 |                             # nodes[-1]['frame'] = t
285 |                             # avoid changing other beams that share this node:
286 | nodes.pop() 287 | nodes.append(dict(token=s, frame=t, prob=ps)) 288 | else: 289 | nodes = cur_nodes.copy() 290 | nodes.append(dict(token=s, frame=t, 291 | prob=ps)) # to record token prob 292 | n_pnb = n_pnb + pb * ps + pnb * ps 293 | next_hyps[n_prefix] = (n_pb, n_pnb, nodes) 294 | 295 | # 2.2 Second beam prune 296 | next_hyps = sorted( 297 | next_hyps.items(), key=lambda x: (x[1][0] + x[1][1]), reverse=True) 298 | 299 | cur_hyps = next_hyps[:path_beam_size] 300 | 301 | hyps = [(y[0], y[1][0] + y[1][1], y[1][2]) for y in cur_hyps] 302 | return hyps 303 | 304 | 305 | class Calculator: 306 | 307 | def __init__(self): 308 | self.data = {} 309 | self.space = [] 310 | self.cost = {} 311 | self.cost['cor'] = 0 312 | self.cost['sub'] = 1 313 | self.cost['del'] = 1 314 | self.cost['ins'] = 1 315 | 316 | def calculate(self, lab, rec): 317 | # Initialization 318 | lab.insert(0, '') 319 | rec.insert(0, '') 320 | while len(self.space) < len(lab): 321 | self.space.append([]) 322 | for row in self.space: 323 | for element in row: 324 | element['dist'] = 0 325 | element['error'] = 'non' 326 | while len(row) < len(rec): 327 | row.append({'dist': 0, 'error': 'non'}) 328 | for i in range(len(lab)): 329 | self.space[i][0]['dist'] = i 330 | self.space[i][0]['error'] = 'del' 331 | for j in range(len(rec)): 332 | self.space[0][j]['dist'] = j 333 | self.space[0][j]['error'] = 'ins' 334 | self.space[0][0]['error'] = 'non' 335 | for token in lab: 336 | if token not in self.data and len(token) > 0: 337 | self.data[token] = { 338 | 'all': 0, 339 | 'cor': 0, 340 | 'sub': 0, 341 | 'ins': 0, 342 | 'del': 0 343 | } 344 | for token in rec: 345 | if token not in self.data and len(token) > 0: 346 | self.data[token] = { 347 | 'all': 0, 348 | 'cor': 0, 349 | 'sub': 0, 350 | 'ins': 0, 351 | 'del': 0 352 | } 353 | # Computing edit distance 354 | for i, lab_token in enumerate(lab): 355 | for j, rec_token in enumerate(rec): 356 | if i == 0 or j == 0: 357 | continue 358 | min_dist = sys.maxsize 359 | min_error = 'none' 360 | dist = self.space[i - 1][j]['dist'] + self.cost['del'] 361 | error = 'del' 362 | if dist < min_dist: 363 | min_dist = dist 364 | min_error = error 365 | dist = self.space[i][j - 1]['dist'] + self.cost['ins'] 366 | error = 'ins' 367 | if dist < min_dist: 368 | min_dist = dist 369 | min_error = error 370 | if lab_token == rec_token: 371 | dist = self.space[i - 1][j - 1]['dist'] + self.cost['cor'] 372 | error = 'cor' 373 | else: 374 | dist = self.space[i - 1][j - 1]['dist'] + self.cost['sub'] 375 | error = 'sub' 376 | if dist < min_dist: 377 | min_dist = dist 378 | min_error = error 379 | self.space[i][j]['dist'] = min_dist 380 | self.space[i][j]['error'] = min_error 381 | # Tracing back 382 | result = { 383 | 'lab': [], 384 | 'rec': [], 385 | 'all': 0, 386 | 'cor': 0, 387 | 'sub': 0, 388 | 'ins': 0, 389 | 'del': 0 390 | } 391 | i = len(lab) - 1 392 | j = len(rec) - 1 393 | while True: 394 | if self.space[i][j]['error'] == 'cor': # correct 395 | if len(lab[i]) > 0: 396 | self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1 397 | self.data[lab[i]]['cor'] = self.data[lab[i]]['cor'] + 1 398 | result['all'] = result['all'] + 1 399 | result['cor'] = result['cor'] + 1 400 | result['lab'].insert(0, lab[i]) 401 | result['rec'].insert(0, rec[j]) 402 | i = i - 1 403 | j = j - 1 404 | elif self.space[i][j]['error'] == 'sub': # substitution 405 | if len(lab[i]) > 0: 406 | self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1 407 | self.data[lab[i]]['sub'] = self.data[lab[i]]['sub'] + 1 408 | result['all'] = 
result['all'] + 1 409 | result['sub'] = result['sub'] + 1 410 | result['lab'].insert(0, lab[i]) 411 | result['rec'].insert(0, rec[j]) 412 | i = i - 1 413 | j = j - 1 414 | elif self.space[i][j]['error'] == 'del': # deletion 415 | if len(lab[i]) > 0: 416 | self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1 417 | self.data[lab[i]]['del'] = self.data[lab[i]]['del'] + 1 418 | result['all'] = result['all'] + 1 419 | result['del'] = result['del'] + 1 420 | result['lab'].insert(0, lab[i]) 421 | result['rec'].insert(0, '') 422 | i = i - 1 423 | elif self.space[i][j]['error'] == 'ins': # insertion 424 | if len(rec[j]) > 0: 425 | self.data[rec[j]]['ins'] = self.data[rec[j]]['ins'] + 1 426 | result['ins'] = result['ins'] + 1 427 | result['lab'].insert(0, '') 428 | result['rec'].insert(0, rec[j]) 429 | j = j - 1 430 | elif self.space[i][j]['error'] == 'non': # starting point 431 | break 432 | else: # shouldn't reach here 433 | print( 434 | 'this should not happen, ' 435 | 'i = {i} , j = {j} , error = {error}' 436 | .format(i=i, j=j, error=self.space[i][j]['error'])) 437 | return result 438 | 439 | def overall(self): 440 | result = {'all': 0, 'cor': 0, 'sub': 0, 'ins': 0, 'del': 0} 441 | for token in self.data: 442 | result['all'] = result['all'] + self.data[token]['all'] 443 | result['cor'] = result['cor'] + self.data[token]['cor'] 444 | result['sub'] = result['sub'] + self.data[token]['sub'] 445 | result['ins'] = result['ins'] + self.data[token]['ins'] 446 | result['del'] = result['del'] + self.data[token]['del'] 447 | return result 448 | 449 | def cluster(self, data): 450 | result = {'all': 0, 'cor': 0, 'sub': 0, 'ins': 0, 'del': 0} 451 | for token in data: 452 | if token in self.data: 453 | result['all'] = result['all'] + self.data[token]['all'] 454 | result['cor'] = result['cor'] + self.data[token]['cor'] 455 | result['sub'] = result['sub'] + self.data[token]['sub'] 456 | result['ins'] = result['ins'] + self.data[token]['ins'] 457 | result['del'] = result['del'] + self.data[token]['del'] 458 | return result 459 | 460 | def keys(self): 461 | return list(self.data.keys()) 462 | -------------------------------------------------------------------------------- /model_convert/model/mdtc.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) 2021 Jingyong Hou (houjingyong@gmail.com) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
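# --- Illustrative sketch (editor addition): using the Calculator from
# loss.py above. calculate() aligns a reference against a hypothesis with
# unit edit costs and accumulates per-token counts; note that it mutates
# both input lists in place.
#
#     from model_convert.model.loss import Calculator
#     calc = Calculator()
#     res = calc.calculate(['ni', 'hao', 'wen', 'wen'], ['ni', 'hao', 'wen'])
#     # res: all=4, cor=3, del=1, ins=0, sub=0, so the utterance accuracy
#     # (all - ins - sub - del) * 100.0 / all == 75.0, as in acc_utterance()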
15 | 16 | from typing import Tuple 17 | 18 | import torch 19 | import torch.nn as nn 20 | import torch.nn.functional as F 21 | 22 | 23 | class DSDilatedConv1d(nn.Module): 24 | """Dilated Depthwise-Separable Convolution""" 25 | def __init__( 26 | self, 27 | in_channels: int, 28 | out_channels: int, 29 | kernel_size: int, 30 | dilation: int = 1, 31 | stride: int = 1, 32 | bias: bool = True, 33 | ): 34 | super(DSDilatedConv1d, self).__init__() 35 | self.padding = dilation * (kernel_size - 1) 36 | self.conv = nn.Conv1d( 37 | in_channels, 38 | in_channels, 39 | kernel_size, 40 | padding=0, 41 | dilation=dilation, 42 | stride=stride, 43 | groups=in_channels, 44 | bias=bias, 45 | ) 46 | self.bn = nn.BatchNorm1d(in_channels) 47 | self.pointwise = nn.Conv1d(in_channels, 48 | out_channels, 49 | kernel_size=1, 50 | padding=0, 51 | dilation=1, 52 | bias=bias) 53 | 54 | def forward(self, inputs: torch.Tensor): 55 | outputs = self.conv(inputs) 56 | outputs = self.bn(outputs) 57 | outputs = self.pointwise(outputs) 58 | return outputs 59 | 60 | 61 | class TCNBlock(nn.Module): 62 | def __init__( 63 | self, 64 | in_channels: int, 65 | res_channels: int, 66 | kernel_size: int, 67 | dilation: int, 68 | causal: bool, 69 | ): 70 | super(TCNBlock, self).__init__() 71 | self.in_channels = in_channels 72 | self.res_channels = res_channels 73 | self.kernel_size = kernel_size 74 | self.dilation = dilation 75 | self.causal = causal 76 | self.padding = dilation * (kernel_size - 1) 77 | self.half_padding = self.padding // 2 78 | self.conv1 = DSDilatedConv1d( 79 | in_channels=in_channels, 80 | out_channels=res_channels, 81 | kernel_size=kernel_size, 82 | dilation=dilation, 83 | ) 84 | self.bn1 = nn.BatchNorm1d(res_channels) 85 | self.relu1 = nn.ReLU() 86 | 87 | self.conv2 = nn.Conv1d(in_channels=res_channels, 88 | out_channels=res_channels, 89 | kernel_size=1) 90 | self.bn2 = nn.BatchNorm1d(res_channels) 91 | self.relu2 = nn.ReLU() 92 | 93 | def forward( 94 | self, 95 | inputs: torch.Tensor, 96 | cache: torch.Tensor = torch.zeros(0, 0, 0, dtype=torch.float) 97 | ) -> Tuple[torch.Tensor, torch.Tensor]: 98 | """ 99 | Args: 100 | inputs(torch.Tensor): Input tensor (B, D, T) 101 | cache(torch.Tensor): Input cache(B, D, self.padding) 102 | Returns: 103 | torch.Tensor(B, D, T): outputs 104 | torch.Tensor(B, D, self.padding): new cache 105 | """ 106 | if cache.size(0) == 0: 107 | outputs = F.pad(inputs, (self.padding, 0), value=0.0) 108 | else: 109 | outputs = torch.cat((cache, inputs), dim=2) 110 | assert outputs.size(2) > self.padding 111 | new_cache = outputs[:, :, -self.padding:] 112 | 113 | outputs = self.relu1(self.bn1(self.conv1(outputs))) 114 | outputs = self.bn2(self.conv2(outputs)) 115 | if self.in_channels == self.res_channels: 116 | res_out = self.relu2(outputs + inputs) 117 | else: 118 | res_out = self.relu2(outputs) 119 | return res_out, new_cache 120 | 121 | 122 | class TCNStack(nn.Module): 123 | def __init__( 124 | self, 125 | in_channels: int, 126 | stack_num: int, 127 | stack_size: int, 128 | res_channels: int, 129 | kernel_size: int, 130 | causal: bool, 131 | ): 132 | super(TCNStack, self).__init__() 133 | self.in_channels = in_channels 134 | self.stack_num = stack_num 135 | self.stack_size = stack_size 136 | self.res_channels = res_channels 137 | self.kernel_size = kernel_size 138 | self.causal = causal 139 | self.res_blocks = self.stack_tcn_blocks() 140 | self.padding = self.calculate_padding() 141 | 142 | def calculate_padding(self): 143 | padding = 0 144 | for block in self.res_blocks: 145 | padding += 
block.padding
146 |         return padding
147 | 
148 |     def build_dilations(self):
149 |         dilations = []
150 |         for s in range(0, self.stack_size):
151 |             for l in range(0, self.stack_num):
152 |                 dilations.append(2**l)
153 |         return dilations
154 | 
155 |     def stack_tcn_blocks(self):
156 |         dilations = self.build_dilations()
157 |         res_blocks = nn.ModuleList()
158 | 
159 |         res_blocks.append(
160 |             TCNBlock(
161 |                 self.in_channels,
162 |                 self.res_channels,
163 |                 self.kernel_size,
164 |                 dilations[0],
165 |                 self.causal,
166 |             ))
167 |         for dilation in dilations[1:]:
168 |             res_blocks.append(
169 |                 TCNBlock(
170 |                     self.res_channels,
171 |                     self.res_channels,
172 |                     self.kernel_size,
173 |                     dilation,
174 |                     self.causal,
175 |                 ))
176 |         return res_blocks
177 | 
178 |     def forward(
179 |         self,
180 |         inputs: torch.Tensor,
181 |         in_cache: torch.Tensor = torch.zeros(0, 0, 0, dtype=torch.float)
182 |     ) -> Tuple[torch.Tensor, torch.Tensor]:
183 |         outputs = inputs  # (B, D, T)
184 |         out_caches = []
185 |         offset = 0
186 |         for block in self.res_blocks:
187 |             if in_cache.size(0) > 0:
188 |                 c_in = in_cache[:, :, offset:offset + block.padding]
189 |             else:
190 |                 c_in = torch.zeros(0, 0, 0)
191 |             outputs, c_out = block(outputs, c_in)
192 |             out_caches.append(c_out)
193 |             offset += block.padding
194 |         new_cache = torch.cat(out_caches, dim=2)
195 |         return outputs, new_cache
196 | 
197 | 
198 | class MDTC(nn.Module):
199 |     """Multi-scale Depthwise Temporal Convolution (MDTC).
200 |     In MDTC, stacked depthwise one-dimensional (1-D) convolution with
201 |     dilated connections is adopted to efficiently model long-range
202 |     dependency of speech. With a large receptive field while
203 |     keeping a small number of model parameters, the structure
204 |     can model temporal context of speech effectively. It also
205 |     extracts multi-scale features from different hidden layers
206 |     of MDTC with different receptive fields.
207 | """ 208 | def __init__( 209 | self, 210 | stack_num: int, 211 | stack_size: int, 212 | in_channels: int, 213 | res_channels: int, 214 | kernel_size: int, 215 | causal: bool, 216 | ): 217 | super(MDTC, self).__init__() 218 | assert kernel_size % 2 == 1 219 | self.kernel_size = kernel_size 220 | assert causal is True, "we now only support causal mdtc" 221 | self.causal = causal 222 | self.preprocessor = TCNBlock(in_channels, 223 | res_channels, 224 | kernel_size, 225 | dilation=1, 226 | causal=causal) 227 | self.relu = nn.ReLU() 228 | self.blocks = nn.ModuleList() 229 | self.padding = self.preprocessor.padding 230 | for i in range(stack_num): 231 | self.blocks.append( 232 | TCNStack(res_channels, stack_size, 1, res_channels, 233 | kernel_size, causal)) 234 | self.padding += self.blocks[-1].padding 235 | self.half_padding = self.padding // 2 236 | print('Receptive Fields: %d' % self.padding) 237 | 238 | def forward( 239 | self, 240 | x: torch.Tensor, 241 | in_cache: torch.Tensor = torch.zeros(0, 0, 0, dtype=torch.float) 242 | ) -> Tuple[torch.Tensor, torch.Tensor]: 243 | outputs = x.transpose(1, 2) # (B, D, T) 244 | outputs_list = [] 245 | out_caches = [] 246 | offset = 0 247 | if in_cache.size(0) > 0: 248 | c_in = in_cache[:, :, offset:offset + self.preprocessor.padding] 249 | else: 250 | c_in = torch.zeros(0, 0, 0) 251 | 252 | outputs, c_out = self.preprocessor(outputs, c_in) 253 | outputs = self.relu(outputs) 254 | out_caches.append(c_out) 255 | offset += self.preprocessor.padding 256 | for block in self.blocks: 257 | if in_cache.size(0) > 0: 258 | c_in = in_cache[:, :, offset:offset + block.padding] 259 | else: 260 | c_in = torch.zeros(0, 0, 0) 261 | outputs, c_out = block(outputs, c_in) 262 | outputs_list.append(outputs) 263 | out_caches.append(c_out) 264 | offset += block.padding 265 | 266 | outputs = torch.zeros_like(outputs_list[-1], dtype=outputs_list[-1].dtype) 267 | for x in outputs_list: 268 | outputs += x 269 | outputs = outputs.transpose(1, 2) # (B, T, D) 270 | new_cache = torch.cat(out_caches, dim=2) 271 | return outputs, new_cache 272 | 273 | 274 | if __name__ == '__main__': 275 | mdtc = MDTC(3, 4, 64, 64, 5, causal=True) 276 | print(mdtc) 277 | 278 | num_params = sum(p.numel() for p in mdtc.parameters()) 279 | print('the number of model params: {}'.format(num_params)) 280 | x = torch.randn(128, 200, 64) # batch-size * time * dim 281 | y, c = mdtc(x) 282 | print('input shape: {}'.format(x.shape)) 283 | print('output shape: {}'.format(y.shape)) 284 | print('cache shape: {}'.format(c.shape)) 285 | 286 | print('########################################') 287 | for _ in range(10): 288 | y, c = mdtc(y, c) 289 | print('output shape: {}'.format(y.shape)) 290 | print('cache shape: {}'.format(c.shape)) 291 | -------------------------------------------------------------------------------- /model_convert/model/subsampling.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Binbin Zhang 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | import torch
16 | 
17 | # There is no right context or lookahead in our Subsampling design, so
18 | # if there is a CNN in Subsampling, it is a causal CNN.
19 | 
20 | 
21 | class SubsamplingBase(torch.nn.Module):
22 |     def __init__(self):
23 |         super().__init__()
24 |         self.subsampling_rate = 1
25 | 
26 | 
27 | class NoSubsampling(SubsamplingBase):
28 |     """No subsampling, in accordance with the 'none' preprocessing
29 |     """
30 |     def __init__(self):
31 |         super().__init__()
32 | 
33 |     def forward(self, x: torch.Tensor) -> torch.Tensor:
34 |         return x
35 | 
36 | 
37 | class LinearSubsampling1(SubsamplingBase):
38 |     """Linear transform of the input without subsampling
39 |     """
40 |     def __init__(self, idim: int, odim: int):
41 |         super().__init__()
42 |         self.out = torch.nn.Sequential(
43 |             torch.nn.Linear(idim, odim),
44 |             torch.nn.ReLU(),
45 |         )
46 |         self.subsampling_rate = 1
47 |         self.quant = torch.quantization.QuantStub()
48 |         self.dequant = torch.quantization.DeQuantStub()
49 | 
50 |     def forward(self, x: torch.Tensor) -> torch.Tensor:
51 |         x = self.quant(x)
52 |         x = self.out(x)
53 |         x = self.dequant(x)
54 |         return x
55 | 
56 |     def fuse_modules(self):
57 |         torch.quantization.fuse_modules(self, [['out.0', 'out.1']],
58 |                                         inplace=True)
59 | 
60 | 
61 | class Conv1dSubsampling1(SubsamplingBase):
62 |     """Conv1d transform without subsampling
63 |     """
64 |     def __init__(self, idim: int, odim: int):
65 |         super().__init__()
66 |         self.out = torch.nn.Sequential(
67 |             torch.nn.Conv1d(idim, odim, 3),
68 |             torch.nn.BatchNorm1d(odim),
69 |             torch.nn.ReLU(),
70 |         )
71 |         self.subsampling_rate = 1
72 | 
73 |     def forward(self, x: torch.Tensor) -> torch.Tensor:
74 |         x = self.out(x)
75 |         return x
76 | 
--------------------------------------------------------------------------------
/model_convert/model/tcn.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) 2021 Binbin Zhang
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #   http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
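# --- Illustrative sketch (editor addition): the cache contract of the
# Block/TCN classes below. Feeding a TCN chunk by chunk with its returned
# cache reproduces the full-utterance output, because each block zero-pads
# only the first chunk and then carries its last `padding` frames forward:
#
#     tcn = TCN(4, 64, 8, block_class=CnnBlock).eval()
#     x = torch.randn(1, 40, 64)                    # (B, T, D)
#     with torch.no_grad():
#         y_full, _ = tcn(x)
#         cache, chunks = torch.zeros(0, 0, 0), []
#         for t in range(0, 40, 10):                # 10-frame chunks
#             y, cache = tcn(x[:, t:t + 10, :], cache)
#             chunks.append(y)
#     assert torch.allclose(y_full, torch.cat(chunks, dim=1), atol=1e-5)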
15 | 
16 | from typing import Tuple
17 | 
18 | import torch
19 | import torch.nn as nn
20 | import torch.nn.functional as F
21 | 
22 | 
23 | class Block(nn.Module):
24 |     def __init__(self,
25 |                  channel: int,
26 |                  kernel_size: int,
27 |                  dilation: int,
28 |                  dropout: float = 0.1):
29 |         super().__init__()
30 |         self.padding = (kernel_size - 1) * dilation
31 |         self.quant = torch.quantization.QuantStub()
32 |         self.dequant = torch.quantization.DeQuantStub()
33 | 
34 |     def forward(
35 |         self,
36 |         x: torch.Tensor,
37 |         cache: torch.Tensor = torch.zeros(0, 0, 0, dtype=torch.float)
38 |     ) -> Tuple[torch.Tensor, torch.Tensor]:
39 |         """
40 |         Args:
41 |             x (torch.Tensor): Input tensor (B, D, T)
42 |             cache (torch.Tensor): Input cache (B, D, self.padding)
43 |         Returns:
44 |             torch.Tensor(B, D, T): output
45 |             torch.Tensor(B, D, self.padding): new cache
46 |         """
47 |         # The CNN used here is causal convolution
48 |         if cache.size(0) == 0:
49 |             y = F.pad(x, (self.padding, 0), value=0.0)
50 |         else:
51 |             y = torch.cat((cache, x), dim=2)
52 |         assert y.size(2) > self.padding
53 |         new_cache = y[:, :, -self.padding:]
54 | 
55 |         y = self.quant(y)
56 |         # self.cnn is defined in the subclass of Block
57 |         y = self.cnn(y)
58 |         y = self.dequant(y)
59 |         y = y + x  # residual connection
60 |         return y, new_cache
61 | 
62 |     def fuse_modules(self):
63 |         self.cnn.fuse_modules()
64 | 
65 | 
66 | class CnnBlock(Block):
67 |     def __init__(self,
68 |                  channel: int,
69 |                  kernel_size: int,
70 |                  dilation: int,
71 |                  dropout: float = 0.1):
72 |         super().__init__(channel, kernel_size, dilation, dropout)
73 |         self.cnn = nn.Sequential(
74 |             nn.Conv1d(channel,
75 |                       channel,
76 |                       kernel_size,
77 |                       stride=1,
78 |                       dilation=dilation),
79 |             nn.BatchNorm1d(channel),
80 |             nn.ReLU(),
81 |             nn.Dropout(dropout),
82 |         )
83 | 
84 |     def fuse_modules(self):
85 |         torch.quantization.fuse_modules(self, [['cnn.0', 'cnn.1', 'cnn.2']],
86 |                                         inplace=True)
87 | 
88 | 
89 | class DsCnnBlock(Block):
90 |     """ Depthwise Separable Convolution
91 |     """
92 |     def __init__(self,
93 |                  channel: int,
94 |                  kernel_size: int,
95 |                  dilation: int,
96 |                  dropout: float = 0.1):
97 |         super().__init__(channel, kernel_size, dilation, dropout)
98 |         self.cnn = nn.Sequential(
99 |             nn.Conv1d(channel,
100 |                       channel,
101 |                       kernel_size,
102 |                       stride=1,
103 |                       dilation=dilation,
104 |                       groups=channel),
105 |             nn.BatchNorm1d(channel),
106 |             nn.ReLU(),
107 |             nn.Conv1d(channel, channel, kernel_size=1, stride=1),
108 |             nn.BatchNorm1d(channel),
109 |             nn.ReLU(),
110 |             nn.Dropout(dropout),
111 |         )
112 | 
113 |     def fuse_modules(self):
114 |         torch.quantization.fuse_modules(
115 |             self, [['cnn.0', 'cnn.1', 'cnn.2'], ['cnn.3', 'cnn.4', 'cnn.5']],
116 |             inplace=True)
117 | 
118 | 
119 | class TCN(nn.Module):
120 |     def __init__(self,
121 |                  num_layers: int,
122 |                  channel: int,
123 |                  kernel_size: int,
124 |                  dropout: float = 0.1,
125 |                  block_class=CnnBlock):
126 |         super().__init__()
127 |         self.padding = 0
128 |         self.network = nn.ModuleList()
129 |         for i in range(num_layers):
130 |             dilation = 2**i
131 |             self.padding += (kernel_size - 1) * dilation
132 |             self.network.append(
133 |                 block_class(channel, kernel_size, dilation, dropout))
134 | 
135 |     def forward(
136 |         self,
137 |         x: torch.Tensor,
138 |         in_cache: torch.Tensor = torch.zeros(0, 0, 0, dtype=torch.float)
139 |     ) -> Tuple[torch.Tensor, torch.Tensor]:
140 |         """
141 |         Args:
142 |             x (torch.Tensor): Input tensor (B, T, D)
143 |             in_cache (torch.Tensor): (B, D, C), C is the accumulated cache size
144 | 
145 |         Returns:
146 |             torch.Tensor(B, T, D)
147 |             torch.Tensor(B, D, C): C is the accumulated cache size
148 |         """
149 |         x = x.transpose(1, 2)  # (B, D, T)
150 |         out_caches = []
151 |         offset = 0
152 |         for block in self.network:
153 |             if in_cache.size(0) > 0:
154 |                 c_in = in_cache[:, :, offset:offset + block.padding]
155 |             else:
156 |                 c_in = torch.zeros(0, 0, 0)
157 |             x, c_out = block(x, c_in)
158 |             out_caches.append(c_out)
159 |             offset += block.padding
160 |         x = x.transpose(1, 2)  # (B, T, D)
161 |         new_cache = torch.cat(out_caches, dim=2)
162 |         return x, new_cache
163 | 
164 |     def fuse_modules(self):
165 |         for m in self.network:
166 |             m.fuse_modules()
167 | 
168 | 
169 | if __name__ == '__main__':
170 |     tcn = TCN(4, 64, 8, block_class=CnnBlock)
171 |     print(tcn)
172 |     print(tcn.padding)
173 |     num_params = sum(p.numel() for p in tcn.parameters())
174 |     print('the number of model params: {}'.format(num_params))
175 |     x = torch.zeros(3, 15, 64)
176 |     y, c = tcn(x)
177 | 
--------------------------------------------------------------------------------
/model_convert/utils/checkpoint.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 Binbin Zhang
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #   http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | 
16 | import logging
17 | import os
18 | import re
19 | 
20 | import yaml
21 | import torch
22 | 
23 | 
24 | def load_checkpoint(model: torch.nn.Module, path: str) -> dict:
25 |     if torch.cuda.is_available():
26 |         logging.info('Checkpoint: loading from checkpoint %s for GPU' % path)
27 |         checkpoint = torch.load(path)
28 |     else:
29 |         logging.info('Checkpoint: loading from checkpoint %s for CPU' % path)
30 |         checkpoint = torch.load(path, map_location='cpu')
31 |     model.load_state_dict(checkpoint)
32 |     info_path = re.sub(r'\.pt$', '.yaml', path)
33 |     configs = {}
34 |     if os.path.exists(info_path):
35 |         with open(info_path, 'r') as fin:
36 |             configs = yaml.load(fin, Loader=yaml.FullLoader)
37 |     return configs
38 | 
39 | 
40 | def save_checkpoint(model: torch.nn.Module, path: str, infos=None):
41 |     '''
42 |     Args:
43 |         infos (dict or None): any info you want to save.
44 |     '''
45 |     logging.info('Checkpoint: save to checkpoint %s' % path)
46 |     if isinstance(model, torch.nn.DataParallel):
47 |         state_dict = model.module.state_dict()
48 |     elif isinstance(model, torch.nn.parallel.DistributedDataParallel):
49 |         state_dict = model.module.state_dict()
50 |     else:
51 |         state_dict = model.state_dict()
52 |     torch.save(state_dict, path)
53 |     info_path = re.sub(r'\.pt$', '.yaml', path)
54 |     if infos is None:
55 |         infos = {}
56 |     with open(info_path, 'w') as fout:
57 |         data = yaml.dump(infos)
58 |         fout.write(data)
59 | 
--------------------------------------------------------------------------------
/model_convert/utils/cmvn.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) 2020 Binbin Zhang
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import json 17 | import math 18 | import re 19 | 20 | import numpy as np 21 | 22 | 23 | def load_cmvn(json_cmvn_file): 24 | """ Load the json format cmvn stats file and calculate cmvn 25 | 26 | Args: 27 | json_cmvn_file: cmvn stats file in json format 28 | 29 | Returns: 30 | a numpy array of [means, vars] 31 | """ 32 | with open(json_cmvn_file) as f: 33 | cmvn_stats = json.load(f) 34 | 35 | means = cmvn_stats['mean_stat'] 36 | variance = cmvn_stats['var_stat'] 37 | count = cmvn_stats['frame_num'] 38 | for i in range(len(means)): 39 | means[i] /= count 40 | variance[i] = variance[i] / count - means[i] * means[i] 41 | if variance[i] < 1.0e-20: 42 | variance[i] = 1.0e-20 43 | variance[i] = 1.0 / math.sqrt(variance[i]) 44 | cmvn = np.array([means, variance]) 45 | return cmvn 46 | 47 | def load_kaldi_cmvn(cmvn_file): 48 | """ Load the kaldi format cmvn stats file and no need to calculate 49 | 50 | Args: 51 | cmvn_file: cmvn stats file in kaldi format 52 | 53 | Returns: 54 | a numpy array of [means, vars] 55 | """ 56 | 57 | means = None 58 | variance = None 59 | with open(cmvn_file) as f: 60 | all_lines = f.readlines() 61 | for idx, line in enumerate(all_lines): 62 | if line.find('AddShift') != -1: 63 | segs = line.strip().split(' ') 64 | assert len(segs) == 3 65 | next_line = all_lines[idx + 1] 66 | means_str = re.findall(r'[\[](.*?)[\]]', next_line)[0] 67 | means_list = means_str.strip().split(' ') 68 | means = [0 - float(s) for s in means_list] 69 | assert len(means) == int(segs[1]) 70 | elif line.find('Rescale') != -1: 71 | segs = line.strip().split(' ') 72 | assert len(segs) == 3 73 | next_line = all_lines[idx + 1] 74 | vars_str = re.findall(r'[\[](.*?)[\]]', next_line)[0] 75 | vars_list = vars_str.strip().split(' ') 76 | variance = [float(s) for s in vars_list] 77 | assert len(variance) == int(segs[1]) 78 | elif line.find('Splice') != -1: 79 | segs = line.strip().split(' ') 80 | assert len(segs) == 3 81 | next_line = all_lines[idx + 1] 82 | splice_str = re.findall(r'[\[](.*?)[\]]', next_line)[0] 83 | splice_list = splice_str.strip().split(' ') 84 | assert len(splice_list) * int(segs[2]) == int(segs[1]) 85 | copy_times = len(splice_list) 86 | else: 87 | continue 88 | 89 | cmvn = np.array([means, variance]) 90 | cmvn = np.tile(cmvn, (1, copy_times)) 91 | 92 | return cmvn 93 | -------------------------------------------------------------------------------- /model_convert/utils/executor.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Binbin Zhang 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import logging 16 | 17 | import torch 18 | from torch.nn.utils import clip_grad_norm_ 19 | 20 | from model_convert.model.loss import criterion 21 | 22 | 23 | class Executor: 24 | def __init__(self): 25 | self.step = 0 26 | 27 | def train(self, model, optimizer, data_loader, device, writer, args): 28 | ''' Train one epoch 29 | ''' 30 | model.train() 31 | clip = args.get('grad_clip', 50.0) 32 | log_interval = args.get('log_interval', 10) 33 | epoch = args.get('epoch', 0) 34 | min_duration = args.get('min_duration', 0) 35 | 36 | for batch_idx, batch in enumerate(data_loader): 37 | key, feats, target, feats_lengths, label_lengths = batch 38 | feats = feats.to(device) 39 | target = target.to(device) 40 | feats_lengths = feats_lengths.to(device) 41 | label_lengths = label_lengths.to(device) 42 | num_utts = feats_lengths.size(0) 43 | if num_utts == 0: 44 | continue 45 | logits, _ = model(feats) 46 | loss_type = args.get('criterion', 'max_pooling') 47 | loss, acc = criterion(loss_type, logits, target, feats_lengths, 48 | target_lengths=label_lengths, 49 | min_duration=min_duration, 50 | validation=False) 51 | optimizer.zero_grad() 52 | loss.backward() 53 | grad_norm = clip_grad_norm_(model.parameters(), clip) 54 | if torch.isfinite(grad_norm): 55 | optimizer.step() 56 | if batch_idx % log_interval == 0: 57 | logging.debug( 58 | 'TRAIN Batch {}/{} loss {:.8f} acc {:.8f}'.format( 59 | epoch, batch_idx, loss.item(), acc)) 60 | 61 | def cv(self, model, data_loader, device, args): 62 | ''' Cross validation on 63 | ''' 64 | model.eval() 65 | log_interval = args.get('log_interval', 10) 66 | epoch = args.get('epoch', 0) 67 | # in order to avoid division by 0 68 | num_seen_utts = 1 69 | total_loss = 0.0 70 | total_acc = 0.0 71 | with torch.no_grad(): 72 | for batch_idx, batch in enumerate(data_loader): 73 | key, feats, target, feats_lengths, label_lengths = batch 74 | feats = feats.to(device) 75 | target = target.to(device) 76 | feats_lengths = feats_lengths.to(device) 77 | label_lengths = label_lengths.to(device) 78 | num_utts = feats_lengths.size(0) 79 | if num_utts == 0: 80 | continue 81 | logits, _ = model(feats) 82 | loss, acc = criterion(args.get('criterion', 'max_pooling'), 83 | logits, target, feats_lengths, 84 | target_lengths=label_lengths, 85 | min_duration=0, 86 | validation=True) 87 | if torch.isfinite(loss): 88 | num_seen_utts += num_utts 89 | total_loss += loss.item() * num_utts 90 | total_acc += acc * num_utts 91 | if batch_idx % log_interval == 0: 92 | logging.debug( 93 | 'CV Batch {}/{} loss {:.8f} acc {:.8f} history loss {:.8f}' 94 | .format(epoch, batch_idx, loss.item(), acc, 95 | total_loss / num_seen_utts)) 96 | return total_loss / num_seen_utts, total_acc / num_seen_utts 97 | 98 | def test(self, model, data_loader, device, args): 99 | return self.cv(model, data_loader, device, args) 100 | -------------------------------------------------------------------------------- /model_convert/utils/file_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | def read_lists(list_file): 17 | lists = [] 18 | with open(list_file, 'r', encoding='utf8') as fin: 19 | for line in fin: 20 | lists.append(line.strip()) 21 | return lists 22 | 23 | 24 | def read_symbol_table(symbol_table_file): 25 | symbol_table = {} 26 | with open(symbol_table_file, 'r', encoding='utf8') as fin: 27 | for line in fin: 28 | arr = line.strip().split() 29 | assert len(arr) == 2 30 | symbol_table[arr[0]] = int(arr[1]) 31 | return symbol_table 32 | -------------------------------------------------------------------------------- /model_convert/utils/mask.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Binbin Zhang 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | 17 | 18 | def padding_mask(lengths: torch.Tensor) -> torch.Tensor: 19 | """ 20 | Examples: 21 | >>> lengths = torch.tensor([2, 2, 3], dtype=torch.int32) 22 | >>> mask = padding_mask(lengths) 23 | >>> print(mask) 24 | tensor([[False, False, True], 25 | [False, False, True], 26 | [False, False, False]]) 27 | """ 28 | batch_size = lengths.size(0) 29 | max_len = int(lengths.max().item()) 30 | seq = torch.arange(max_len, dtype=torch.int64, device=lengths.device) 31 | seq = seq.expand(batch_size, max_len) 32 | return seq >= lengths.unsqueeze(1) 33 | -------------------------------------------------------------------------------- /model_convert/utils/train_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) 2021 Jingyong Hou (houjingyong@gmail.com) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
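# --- Illustrative sketch (editor addition): the token table format consumed
# by read_symbol_table() in file_utils.py above, and by readToken() in the
# C++ runtime (onnxruntime/kws/tokens.txt): one "<token> <id>" pair per line.
# The ids below are made up for illustration, not the real vocabulary.
#
#     with open('tokens_demo.txt', 'w', encoding='utf8') as f:
#         f.write('<blank> 0\n你 2170\n好 2171\n')
#     table = read_symbol_table('tokens_demo.txt')
#     assert table['<blank>'] == 0 and table['你'] == 2170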
15 | 
16 | import torch
17 | import numpy as np
18 | import random
19 | 
20 | 
21 | def set_mannul_seed(seed):
22 |     np.random.seed(seed)
23 |     random.seed(seed)
24 |     torch.manual_seed(seed)
25 |     torch.cuda.manual_seed(seed)
26 |     torch.backends.cudnn.deterministic = True
27 | 
28 | 
29 | def count_parameters(model):
30 |     return sum(p.numel() for p in model.parameters() if p.requires_grad)
31 | 
--------------------------------------------------------------------------------
/onnxruntime/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | cmake_minimum_required(VERSION 3.13 FATAL_ERROR)
2 | 
3 | project(wekws VERSION 0.1)
4 | 
5 | set(CMAKE_VERBOSE_MAKEFILE on)
6 | 
7 | include(FetchContent)
8 | set(FETCHCONTENT_QUIET OFF)
9 | get_filename_component(fc_base "fc_base" REALPATH BASE_DIR "${CMAKE_CURRENT_SOURCE_DIR}")
10 | set(FETCHCONTENT_BASE_DIR ${fc_base})
11 | list(APPEND CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake)
12 | 
13 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++14 -g -pthread")
14 | include_directories(${CMAKE_CURRENT_SOURCE_DIR})
15 | 
16 | include(portaudio)
17 | include(onnxruntime)
18 | add_subdirectory(frontend)
19 | add_subdirectory(kws)
20 | add_subdirectory(bin)
--------------------------------------------------------------------------------
/onnxruntime/README.md:
--------------------------------------------------------------------------------
1 | 
2 | 
3 | # CPP ONNX Inference Test
4 | 
5 | 
6 | 
7 | ```shell
8 | git clone https://github.com/chenyangMl/keyword-spot.git
9 | cd keyword-spot/onnxruntime/
10 | mkdir build && cd build
11 | cmake ..
12 | cmake --build . --target kws_main
13 | 
14 | # Use the corresponding arguments below to run inference with each model.
15 | ```
16 | 
17 | 
18 | 
19 | ## [Test audio](../audio)
20 | 
21 | | Audio file | Keyword |
22 | | --------------------------------------- | -------------- |
23 | | 000af5671fdbaa3e55c5e2bd0bdf8cdd_hi.wav | 嗨小问 |
24 | | 000eae543947c70feb9401f82da03dcf_hi.wav | 嗨小问 |
25 | | 0000c7286ebc7edef1c505b78d5ed1a3.wav | 你好问问 |
26 | | 0000e12e2402775c2d506d77b6dbb411.wav | 你好问问 |
27 | | gongqu-4.5_0000.wav | Other (negative sample) |
28 | | 000af5671fdbaa3e55c5e2bd0bdf8cdd_hi.pcm | 嗨小问 |
29 | 
30 | 
31 | 
32 | ## Max-Pooling Models
33 | 
34 | - Non-streaming mode
35 | 
36 | ```
37 | cd build/bin
38 | ./kws_main [solution_type, int] [num_bins, int] [batch_size, int] [model_path, str] [wave_path, str]
39 | 
40 | # e.g.
41 | ./kws_main 0 40 1 keyword-spot-dstcn-maxpooling-wenwen/onnx/keyword-spot-dstcn-maxpooling-wenwen.ort ../../../audio/0000c7286ebc7edef1c505b78d5ed1a3.wav
42 | ```
43 | 
44 | Test log: `frame` is the time step currently being processed; the first column of `prob` is the classification probability of keyword 1, and the second column that of keyword 2.
45 | 
46 | ```
47 | > Kws Model Info:
48 | > cache_dim: 256
49 | > cache_len: 105
50 | > frame 0 prob 4.17233e-07 1.49012e-06
51 | > frame 1 prob 3.8743e-07 1.2517e-06
52 | > frame 2 prob 1.19209e-07 5.36442e-07
53 | > frame 3 prob 2.98023e-07 3.01003e-06
54 | > 
55 | > ...
56 | > 
57 | > frame 100 prob 0.963686 activated keyword: 嗨小问 0
58 | > frame 101 prob 0.955697 activated keyword: 嗨小问 0
59 | > frame 102 prob 0.94719 activated keyword: 嗨小问 0
60 | > frame 103 prob 0.909599 activated keyword: 嗨小问 2.98023e-08
61 | > frame 104 prob 0.985421 activated keyword: 嗨小问 2.98023e-08
62 | > frame 105 prob 0.926912 activated keyword: 嗨小问 2.98023e-08
63 | > frame 106 prob 0.980361 activated keyword: 嗨小问 0
64 | > frame 107 prob 0.988708 activated keyword: 嗨小问 0
65 | > frame 108 prob 0.998589 activated keyword: 嗨小问 0
66 | > 
67 | > ...
68 | > 
69 | > frame 149 prob 2.98023e-08 0
70 | > frame 150 prob 8.9407e-08 0
71 | > frame 151 prob 1.19209e-07 2.98023e-08
72 | > frame 152 prob 0 0
73 | > frame 153 prob 2.98023e-08 0
74 | > frame 154 prob 2.98023e-08 0
75 | > 
76 | > Process finished with exit code 0
77 | ```
78 | 
79 | 
80 | 
81 | - Streaming mode
82 | 
83 | 
84 | 
85 | ```
86 | # To test streaming mode, build the streaming target first.
87 | cd build/
88 | cmake --build . --target stream_kws_main
89 | 
90 | cd build/bin
91 | ./stream_kws_main [solution_type, int] [num_bins, int] [batch_size, int] [model_path, str]
92 | 
93 | # e.g.
94 | ./stream_kws_main 0 40 80 keyword-spot-dstcn-maxpooling-wenwen/onnx/keyword-spot-dstcn-maxpooling-wenwen.ort
95 | ```
96 | 
97 | PS: A microphone must be connected in advance to provide audio input.
98 | 
99 | 
100 | 
101 | ## CTC Models
102 | 
103 | ```
104 | cd build/bin
105 | ./kws_main [solution_type, int] [num_bins, int] [batch_size, int] [model_path, str] [wave_path, str] [key_word, str]
106 | 
107 | # e.g.
108 | ./kws_main 1 80 1 keyword-spot-fsmn-ctc-wenwen/onnx/keyword_spot_fsmn_ctc_wenwen.ort ../../../audio/0000c7286ebc7edef1c505b78d5ed1a3.wav 你好问问
109 | ```
110 | 
111 | Test log: the following comes from the CTC prefix beam search.
112 | 
113 | `stepT` is the time step currently being processed; `tokenid` is the token ID recognized at the current frame; `proposed` marks proposed extensions of the current hypotheses (see the illustrations in [Sequence Modeling With CTC](https://distill.pub/2017/ctc/) for background); `prob` is the classification probability of that token.
114 | 
115 | ```
116 | Kws Model Info:
117 | cache_dim: 128
118 | cache_len: 11
119 | stepT= 0 tokenid= 0 proposed i=0 prob=0.952
120 | stepT= 3 tokenid= 0 proposed i=0 prob=0.943
121 | stepT= 6 tokenid= 0 proposed i=0 prob=0.946
122 | stepT= 9 tokenid= 0 proposed i=0 prob=0.965
123 | stepT= 12 tokenid= 0 proposed i=0 prob=0.801
124 | 
125 | ...
126 | 
127 | stepT=129 tokenid= 0 proposed i=0 prob=1
128 | stepT=132 tokenid=2494 proposed i=0 prob=0.954
129 | hitword=你好问问
130 | hitscore=0.954
131 | start frame=69 end frame=132
132 | stepT=135 tokenid=2494 proposed i=0 prob=1
133 | stepT=138 tokenid= 0 proposed i=0 prob=1
134 | stepT=141 tokenid= 0 proposed i=0 prob=1
135 | stepT=144 tokenid= 0 proposed i=0 prob=1
136 | stepT=147 tokenid= 0 proposed i=0 prob=1
137 | stepT=150 tokenid= 0 proposed i=0 prob=1
138 | stepT=153 tokenid= 0 proposed i=0 prob=1
139 | stepT=156 tokenid= 0 proposed i=0 prob=1
140 | stepT=159 tokenid= 0 proposed i=0 prob=1
141 | stepT=162 tokenid= 0 proposed i=0 prob=1
142 | stepT=165 tokenid= 0 proposed i=0 prob=1
143 | 
144 | Process finished with exit code 0
145 | ```
146 | 
147 | 
148 | 
149 | - Streaming mode
150 | 
151 | ```
152 | # To test streaming mode, build the streaming target first.
153 | cd build/
154 | cmake --build . --target stream_kws_main
155 | 
156 | cd build/bin
157 | ./stream_kws_main [solution_type, int] [num_bins, int] [batch_size, int] [model_path, str] [key_word, str]
158 | 
159 | # e.g.
160 | ./stream_kws_main 1 80 80 models/keyword-spot-fsmn-ctc-wenwen/onnx/keyword_spot_fsmn_ctc_wenwen.ort 你好问问
161 | ```
162 | 
163 | PS:
164 | 
165 | - solution_type: {0: the max-pooling approach, 1: the CTC approach}
166 | - key_word: {你好问问, 嗨小问}
167 | - A microphone must be connected in advance to provide audio input.
168 | 
169 | For inference tests on other devices, see the [Android and Raspberry Pi examples](https://github.com/wenet-e2e/wekws/tree/main/runtime) provided by wekws.
170 | 
171 | 
172 | 
173 | - Streaming-mode testing over a directory of files
174 | 
175 | Batch-test the wav files in a directory, feeding the audio in a way that simulates streaming input.
176 | 
177 | ```
178 | # To test streaming mode, build the testing target first.
179 | cd build/
180 | cmake --build . --target stream_kws_testing
181 | 
182 | cd build/bin
183 | ./stream_kws_testing [num_bins, int] [model_path, str] [key_word, str] [test_dir, str] [interval, int]
184 | test_dir: the directory of test wav files.
185 | interval: the interval between audio feeds, in ms.
186 | 
187 | # e.g.
188 | ./stream_kws_testing 80 models/keyword-spot-fsmn-ctc-wenwen/onnx/keyword_spot_fsmn_ctc_wenwen.ort 你好问问 audio/ 200
189 | ```
190 | 
191 | 
--------------------------------------------------------------------------------
/onnxruntime/bin/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | 
2 | find_package(Boost 1.55.0 REQUIRED COMPONENTS system filesystem)
3 | include_directories(untitled ${Boost_INCLUDE_DIRS})
4 | 
5 | add_executable(kws_main kws_main.cc)
6 | target_link_libraries(kws_main PUBLIC onnxruntime frontend kws ${Boost_LIBRARIES})
7 | 
8 | add_executable(stream_kws_testing stream_kws_testing.cc)
9 | target_link_libraries(stream_kws_testing PUBLIC onnxruntime frontend kws ${Boost_LIBRARIES})
10 | 
11 | add_executable(device_test device_test.cc)
12 | target_link_libraries(device_test PUBLIC portaudio_static)
13 | 
14 | add_executable(stream_kws_main stream_kws_main.cc)
15 | target_link_libraries(stream_kws_main PUBLIC onnxruntime frontend kws portaudio_static)
--------------------------------------------------------------------------------
/onnxruntime/bin/device_test.cc:
--------------------------------------------------------------------------------
1 | /*
2 |  * Check whether a microphone input device is properly connected.
3 |  * */
4 | 
5 | #include <portaudio.h>
6 | #include <iostream>
7 | 
8 | int main()
9 | {
10 |     Pa_Initialize();
11 | 
12 |     int devices = Pa_GetDeviceCount();
13 |     if (devices == 0){
14 |         std::cout << "No audio input device found." << std::endl;
15 |     }
16 | 
17 |     for (int i = 0; i != devices; ++i)
18 |     {
19 |         auto * info = Pa_GetDeviceInfo(i);
20 |         std::cout << info->name << std::endl;
21 |     }
22 | 
23 |     Pa_Terminate();
24 | }
--------------------------------------------------------------------------------
/onnxruntime/bin/kws_main.cc:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2022 Binbin Zhang (binbzha@qq.com)
2 | // Copyright (c) 2024 Yang Chen (cyang8050@163.com)
3 | //
4 | // Licensed under the Apache License, Version 2.0 (the "License");
5 | // you may not use this file except in compliance with the License.
6 | // You may obtain a copy of the License at
7 | //
8 | //   http://www.apache.org/licenses/LICENSE-2.0
9 | //
10 | // Unless required by applicable law or agreed to in writing, software
11 | // distributed under the License is distributed on an "AS IS" BASIS,
12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | // See the License for the specific language governing permissions and
14 | // limitations under the License.
15 | 
16 | 
17 | #include "frontend/feature_pipeline.h"
18 | #include "frontend/wav.h"
19 | #include "kws/keyword_spotting.h"
20 | #include "kws/utils.h"
21 | #include "utils/log.h"
22 | 
23 | using namespace wekws;
24 | 
25 | int main(int argc, char *argv[]) {
26 | 
27 |     std::string token_path, key_word;
28 |     wenet::MODEL_TYPE mode_type;
29 | 
30 |     if (argc > 2){
31 |         mode_type = (wenet::MODEL_TYPE)std::stoi(argv[1]);
32 |         if (mode_type == wenet::CTC_TYPE_MODEL){
33 |             if (argc != 7) {
34 |                 LOG(FATAL) << "Usage: kws_main\n ./kws_main [solution_type, int] [num_bins, int] "
35 |                            << "[batch_size, int] [model_path, str] [wave_path, str] [key_word, str]";
36 |             }
37 |             // Input Arguments.
38 |             key_word = argv[6];
39 |             token_path = "../../kws/tokens.txt";
40 |         } else if (mode_type == wenet::MAXPOOLING_TYPE_MODEL){
41 |             if (argc != 6) {
42 |                 LOG(FATAL) << "Usage: kws_main\n [solution_type, int] [num_bins, int] "
43 |                            << "[batch_size, int] [model_path, str] [wave_path, str]";
44 |             }
45 |             token_path = "../../kws/maxpooling_keyword.txt";
46 |         }
47 |     } else {
48 |         LOG(FATAL) << "Usage: kws_main\n [solution_type, int] [num_bins, int] "
49 |                    << "[batch_size, int] [model_path, str] [wave_path, str]";
50 |     }
51 | 
52 |     // Input Arguments.
53 |     const int num_bins = std::stoi(argv[2]);  // num_mel_bins in config.yaml, i.e. the Fbank feature dimension.
54 |     const int batch_size = std::stoi(argv[3]);
55 |     if (batch_size < 1){
56 |         LOG(FATAL) << "batch_size should be greater than 0, it's equal to " << batch_size << " now";
57 |     }
58 |     const std::string model_path = argv[4];
59 |     const std::string wav_path = argv[5];
60 | 
61 |     boost::filesystem::path wavpath(wav_path);
62 |     std::vector<float> wav;
63 |     if (wavpath.extension() == ".wav"){
64 |         // audio reader
65 |         wenet::WavReader wav_reader(wav_path);
66 |         int num_samples = wav_reader.num_samples();
67 |         wav.assign(wav_reader.data(), wav_reader.data() + num_samples);
68 |     } else if (wavpath.extension() == ".pcm"){
69 |         read_pcm(wav_path, wav);
70 |     } else {
71 |         LOG(FATAL) << "Not supported format = " << wavpath.extension();
72 |     }
73 | 
74 |     // Set up the config that converts the audio waveform into fbank features;
75 |     // the third argument selects the model-type-specific frontend settings.
76 |     wenet::FeaturePipelineConfig feature_config(num_bins, 16000, mode_type);
77 |     wenet::FeaturePipeline feature_pipeline(feature_config);
78 |     feature_pipeline.AcceptWaveform(wav);
79 |     feature_pipeline.set_input_finished();
80 | 
81 |     wekws::KeywordSpotting spotter(model_path, wekws::DECODE_PREFIX_BEAM_SEARCH, mode_type);
82 |     spotter.readToken(token_path);
83 |     if (mode_type == wenet::CTC_TYPE_MODEL){
84 |         // set keyword
85 |         spotter.setKeyWord(key_word);
86 |     }
87 | 
88 |     // Simulate streaming, detect batch by batch
89 |     int offset = 0;
90 |     while (true) {
91 |         std::vector<std::vector<float>> feats;
92 |         bool ok = feature_pipeline.Read(batch_size, &feats);
93 |         std::vector<std::vector<float>> probs;
94 |         spotter.Forward(feats, &probs);
95 | 
96 |         if (mode_type == wenet::CTC_TYPE_MODEL){
97 |             // Reach the end of feature pipeline
98 |             spotter.decode_keywords(probs, 0.2);
99 |             // Each detection result is stored in spotter.kwsInfo.
100 | 
101 |         } else {
102 |             int flag = 0;
103 |             float threshold = 0.8;  // > threshold means the keyword is activated; < threshold means it is not.
104 |             for (int i = 0; i < probs.size(); i++) {
105 |                 std::cout << "frame " << offset + i << " prob";
106 |                 for (int j = 0; j < probs[i].size(); j++) {  // size() = number of keywords.
107 | 
108 |                     std::cout << " " << probs[i][j];
109 |                     if (probs[i][j] > threshold){
110 |                         std::cout << " activated keyword: " << spotter.mmaxpooling_keywords[j] << " ";
111 |                     }
112 |                 }
113 |                 std::cout << std::endl;
114 |             }
115 |         }
116 | 
117 |         if (!ok) break;
118 |         offset += probs.size();
119 |     }
120 |     return 0;
121 | }
122 | 
--------------------------------------------------------------------------------
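The `.pcm` branch of `kws_main` above reads headerless audio via `read_pcm`. For reference, a minimal Python sketch of that input path (assuming, as the test files are, 16 kHz 16-bit little-endian mono PCM; the path below is one of the samples shipped in `audio/`):

```python
import numpy as np

def read_pcm(path: str) -> np.ndarray:
    """Raw 16-bit little-endian mono PCM -> float samples (no header)."""
    return np.fromfile(path, dtype='<i2').astype(np.float32)

wav = read_pcm('audio/000af5671fdbaa3e55c5e2bd0bdf8cdd_hi.pcm')
print(wav.shape[0] / 16000.0, 'seconds')  # duration at 16 kHz
```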
/onnxruntime/bin/stream_kws_main.cc:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2022 Zhendong Peng (pzd17@tsinghua.org.cn)
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | //   http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 | 
15 | #include <signal.h>
16 | #include <chrono>
17 | #include <iomanip>
18 | #include <iostream>
19 | 
20 | #include "portaudio.h"  // NOLINT
21 | 
22 | #include "frontend/feature_pipeline.h"
23 | #include "kws/keyword_spotting.h"
24 | #include "utils/log.h"
25 | #include <memory>
26 | 
27 | int g_exiting = 0;
28 | std::shared_ptr<wenet::FeaturePipeline> g_feature_pipeline;
29 | 
30 | void SigRoutine(int dunno) {
31 |     if (dunno == SIGINT) {
32 |         g_exiting = 1;
33 |     }
34 | }
35 | 
36 | static int RecordCallback(const void *input, void *output,
37 |                           unsigned long frames_count,  // NOLINT
38 |                           const PaStreamCallbackTimeInfo *time_info,
39 |                           PaStreamCallbackFlags status_flags, void *user_data) {
40 |     const auto *pcm_data = static_cast<const int16_t *>(input);
41 |     std::vector<int16_t> v(pcm_data, pcm_data + frames_count);
42 |     g_feature_pipeline->AcceptWaveform(v);
43 | 
44 |     if (g_exiting) {
45 |         LOG(INFO) << "Exiting loop.";
46 |         g_feature_pipeline->set_input_finished();
47 |         return paComplete;
48 |     } else {
49 |         return paContinue;
50 |     }
51 | }
52 | 
53 | int main(int argc, char *argv[]) {
54 |     std::string token_path;
55 |     std::string key_word;
56 |     wenet::MODEL_TYPE mode_type;
57 |     if (argc > 2) {
58 |         mode_type = (wenet::MODEL_TYPE) std::stoi(argv[1]);
59 |         if (mode_type == wenet::CTC_TYPE_MODEL) {
60 |             if (argc != 6) {
61 |                 LOG(FATAL) << "Usage: ./stream_kws_main\n [solution_type, int] [num_bins, int] [batch_size, int] "
62 |                            << "[model_path, str] [key_word, str]";
63 |             }
64 |             key_word = argv[5];
65 |             token_path = "../../kws/tokens.txt";
66 |         } else if (mode_type == wenet::MAXPOOLING_TYPE_MODEL) {
67 |             if (argc != 5) {
68 |                 LOG(FATAL) << "Usage: ./stream_kws_main\n [solution_type, int] [num_bins, int] [batch_size, int] "
69 |                            << "[model_path, str]";
70 |             }
71 |             token_path = "../../kws/maxpooling_keyword.txt";
72 |         }
73 |     } else {
74 |         LOG(FATAL)
75 |             << "Usage: ./stream_kws_main\n [solution_type, int] [num_bins, int] [batch_size, int] [model_path, str]";
76 |     }
77 | 
78 |     // Input Arguments.
79 |     const int num_bins = std::stoi(argv[2]);  // num_mel_bins in config.yaml, i.e. the Fbank feature dimension.
80 |     const int batch_size = std::stoi(argv[3]);
81 |     if (batch_size < 4){
82 |         LOG(FATAL) << "batch_size should be greater than 3, it's equal to " << batch_size << " now";
83 |     }
84 |     const std::string model_path = argv[4];
85 | 
86 |     wenet::FeaturePipelineConfig feature_config(num_bins, 16000, mode_type);
87 |     g_feature_pipeline = std::make_shared<wenet::FeaturePipeline>(feature_config);
88 |     wekws::KeywordSpotting spotter(model_path, wekws::DECODE_PREFIX_BEAM_SEARCH, mode_type);
89 |     spotter.readToken(token_path);
90 |     if (mode_type == wenet::CTC_TYPE_MODEL) {
91 |         // set keyword
92 |         spotter.setKeyWord(key_word);
93 |     }
94 | 
95 |     signal(SIGINT, SigRoutine);
96 |     PaError err = Pa_Initialize();
97 |     PaStreamParameters params;
98 |     std::cout << err << " " << Pa_GetDeviceCount() << std::endl;
99 |     params.device = Pa_GetDefaultInputDevice();
100 |     if (params.device == paNoDevice) {
101 |         LOG(FATAL) << "Error: No default input device.";
102 |     }
103 |     params.channelCount = 1;
104 |     params.sampleFormat = paInt16;
105 |     params.suggestedLatency =
106 |         Pa_GetDeviceInfo(params.device)->defaultLowInputLatency;
107 |     params.hostApiSpecificStreamInfo = NULL;
108 |     PaStream *stream;
109 |     // Callback and spot pcm data every `interval` ms.
110 |     int interval = 500;
111 |     int frames_per_buffer = 16000 / 1000 * interval;
112 |     Pa_OpenStream(&stream, &params, NULL, 16000, frames_per_buffer, paClipOff,
113 |                   RecordCallback, NULL);
114 |     Pa_StartStream(stream);
115 |     LOG(INFO) << "=== Now recording!! Please speak into the microphone. ===";
116 | 
117 |     std::cout << std::setiosflags(std::ios::fixed) << std::setprecision(2);
118 |     auto start = std::chrono::high_resolution_clock::now();
119 | 
120 |     while (Pa_IsStreamActive(stream)) {
121 |         Pa_Sleep(interval);
122 |         std::vector<std::vector<float>> feats;
123 |         g_feature_pipeline->Read(batch_size, &feats);
124 |         std::vector<std::vector<float>> probs;
125 |         spotter.Forward(feats, &probs);
126 | 
127 |         // detect key-words
128 |         if (mode_type == wenet::CTC_TYPE_MODEL) {
129 |             float hitScoreThr = 0.1;  // threshold of the hit score.
130 |             spotter.decode_keywords(probs, hitScoreThr);
131 | 
132 |         } else {
133 |             for (int t = 0; t < probs.size(); t++) {
134 |                 std::cout << "keywords prob:";
135 |                 for (int i = 0; i < probs[t].size(); i++) {
136 |                     if (probs[t][i] > 0.8) {
137 |                         std::cout << " kw[" << i << "] " << probs[t][i];
138 |                     }
139 |                     // std::cout << " kw[" << i << "] " << probs[t][i];
140 |                 }
141 |                 std::cout << std::endl;
142 |             }
143 |         }
144 |         auto end = std::chrono::high_resolution_clock::now();
145 |         std::chrono::duration<double> diff = end - start;
146 |         std::cout << "Running time: " << diff.count() << " s\n";
147 |     }
148 |     Pa_CloseStream(stream);
149 |     Pa_Terminate();
150 | 
151 |     return 0;
152 | }
153 | 
--------------------------------------------------------------------------------
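Both streaming binaries slice the input by wall-clock interval. A quick sanity check of the sample/byte arithmetic used above and in the tester below (assuming the 16 kHz, 16-bit mono format used throughout):

```python
SAMPLE_RATE = 16000
BYTES_PER_SAMPLE = 2  # paInt16

def chunk_size(interval_ms: int):
    samples = SAMPLE_RATE // 1000 * interval_ms
    return samples, samples * BYTES_PER_SAMPLE

print(chunk_size(500))  # (8000, 16000): stream_kws_main's 500 ms buffer
print(chunk_size(100))  # (1600, 3200): the 100 ms feed described below
```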
16 | 
17 | 
18 | #include "frontend/feature_pipeline.h"
19 | #include "frontend/wav.h"
20 | #include "kws/keyword_spotting.h"
21 | #include "kws/utils.h"
22 | 
23 | using namespace wekws;
24 | 
25 | int main(int argc, char *argv[]) {
26 | 
27 |   // Input Arguments (argv[1..5] are used without further validation).
28 |   const int num_bins = std::stoi(argv[1]);  // num_mel_bins in config.yaml, i.e. the Fbank feature dimension.
29 |   const int batch_size = 2;  // fixed batch size for this test.
30 |   if (batch_size < 1) {
31 |     LOG(FATAL) << "batch_size should be greater than 0, but it is " << batch_size << " now";
32 |   }
33 |   const std::string model_path = argv[2];
34 |   const std::string key_word = argv[3];
35 |   const std::string token_path = "../../kws/tokens.txt";
36 |   const std::string test_dir = argv[4];
37 |   const int interval = std::stoi(argv[5]);  // how many ms of audio to feed per step.
38 | 
39 |   // Set up the config for converting raw waveform into mel-spectrogram features.
40 |   // Only CTC_TYPE_MODEL is supported here.
41 |   wekws::KeywordSpotting spotter(model_path, wekws::DECODE_PREFIX_BEAM_SEARCH, wenet::CTC_TYPE_MODEL);
42 |   spotter.readToken(token_path);
43 |   spotter.setKeyWord(key_word);
44 | 
45 |   std::vector<std::string> wavepath;
46 |   // Walk the directory tree and collect all wav files.
47 |   boost::filesystem::path directory(test_dir);
48 |   process_directory(directory, wavepath);
49 | 
50 |   wenet::FeaturePipelineConfig feature_config(num_bins, 16000, wenet::CTC_TYPE_MODEL);
51 |   wenet::FeaturePipeline feature_pipeline(feature_config);
52 |   feature_pipeline.set_input_finished();  // marked finished up front so Read() never blocks in this offline test.
53 | 
54 |   int TP = 0, FN = 0;
55 |   std::vector<std::string> errorCases;
56 |   for (const std::string &wav_path : wavepath) {
57 | 
58 |     // audio reader
59 |     wenet::WavReader wav_reader(wav_path);
60 |     int num_samples = wav_reader.num_samples();
61 |     std::vector<float> wav1(wav_reader.data(), wav_reader.data() + num_samples);
62 | 
63 |     int count = 0;
64 |     // Feed the audio piecewise: e.g. 100 ms per step = 100 * (16000 / 1000) = 1600 samples, i.e. 3200 bytes at 16 bit.
65 |     std::vector<float> wav;
66 |     spotter.reset_value();
67 |     spotter.stepClear();
68 |     int numBytes = 0;  // counts samples fed so far (despite the name).
69 |     bool flag = false;
70 |     // Simulate streaming, detect batch by batch
71 |     while (!wav1.empty()) {
72 |       wav.push_back(wav1.front());
73 |       wav1.erase(wav1.begin());
74 |       count += 1;
75 |       if (count < interval * 16 && !wav1.empty()) {  // 16 samples per ms at 16 kHz
76 |         continue;
77 |       } else {
78 |         count = 0;
79 |       }
80 |       numBytes += wav.size();
81 |       feature_pipeline.AcceptWaveform(wav);
82 |       wav.clear();
83 | 
84 |       while (true) {
85 |         std::vector<std::vector<float>> feats;
86 | 
87 |         bool ok = feature_pipeline.Read(batch_size, &feats);
88 |         std::vector<std::vector<float>> probs;
89 | 
90 |         spotter.Forward(feats, &probs);
91 |         std::cout << "feats.size= " << feats.size() << " probs.size=" << probs.size() << std::endl;
92 |         // `ok` turns false once the feature pipeline is drained.
93 |         spotter.decode_keywords(probs);
94 |         if (spotter.kwsInfo.state) flag = true;
95 |         if (!ok) break;
96 |       }
97 |     }
98 |     if (flag) {
99 |       TP += 1;
100 |       // keyword found in the predicted sequence.
101 | std::cout << "YES :" << wav_path << "\t" << numBytes << std::endl; 102 | } else { 103 | FN += 1; 104 | std::cout << "NO :" << wav_path << "\t" << numBytes << std::endl; 105 | errorCases.push_back(wav_path); 106 | } 107 | } 108 | 109 | // compute metric 110 | std::cout << "TP: " << TP << " FN: " << FN << std::endl; 111 | // save bad case 112 | std::string badcasePath = "test.txt"; 113 | writeVectorToFile(errorCases, badcasePath); 114 | 115 | return 0; 116 | }; -------------------------------------------------------------------------------- /onnxruntime/cmake/onnxruntime.cmake: -------------------------------------------------------------------------------- 1 | set(ONNX_VERSION "1.12.0") 2 | if(CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64") 3 | set(ONNX_URL "https://github.com/microsoft/onnxruntime/releases/download/v${ONNX_VERSION}/onnxruntime-linux-aarch64-${ONNX_VERSION}.tgz") 4 | set(URL_HASH "SHA256=5820d9f343df73c63b6b2b174a1ff62575032e171c9564bcf92060f46827d0ac") 5 | else() 6 | set(ONNX_URL "https://github.com/microsoft/onnxruntime/releases/download/v${ONNX_VERSION}/onnxruntime-linux-x64-${ONNX_VERSION}.tgz") 7 | set(URL_HASH "SHA256=5d503ce8540358b59be26c675e42081be14a3e833a5301926f555451046929c5") 8 | endif() 9 | 10 | FetchContent_Declare(onnxruntime 11 | URL ${ONNX_URL} 12 | URL_HASH ${URL_HASH} 13 | ) 14 | FetchContent_MakeAvailable(onnxruntime) 15 | include_directories(${onnxruntime_SOURCE_DIR}/include) 16 | link_directories(${onnxruntime_SOURCE_DIR}/lib) 17 | -------------------------------------------------------------------------------- /onnxruntime/cmake/portaudio.cmake: -------------------------------------------------------------------------------- 1 | 2 | FetchContent_Declare(portaudio 3 | URL https://github.com/PortAudio/portaudio/archive/refs/tags/v19.7.0.tar.gz 4 | URL_HASH SHA256=5af29ba58bbdbb7bbcefaaecc77ec8fc413f0db6f4c4e286c40c3e1b83174fa0 5 | ) 6 | FetchContent_MakeAvailable(portaudio) 7 | include_directories(${portaudio_SOURCE_DIR}/include) 8 | -------------------------------------------------------------------------------- /onnxruntime/frontend/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_library(frontend STATIC 2 | feature_pipeline.cc 3 | fft.cc 4 | ) 5 | -------------------------------------------------------------------------------- /onnxruntime/frontend/fbank.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2017 Personal (Binbin Zhang) 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
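// For orientation before the implementation (an illustrative note): the
// filterbank below is laid out on the mel scale,
//   mel(f) = 1127 * ln(1 + f / 700),   f(mel) = 700 * (exp(mel / 1127) - 1).
// With num_bins triangular filters spread evenly in mel between 20 Hz and
// sample_rate / 2, e.g. 80 bins at 16 kHz put filter edges roughly every
// (mel(8000) - mel(20)) / 81 = (2840.0 - 31.7) / 81 ≈ 34.7 mel apart.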
14 | 
15 | #ifndef FRONTEND_FBANK_H_
16 | #define FRONTEND_FBANK_H_
17 | 
18 | #include <cstring>
19 | #include <limits>
20 | #include <random>
21 | #include <utility>
22 | #include <vector>
23 | 
24 | #include "frontend/fft.h"
25 | #include "utils/log.h"
26 | 
27 | namespace wenet {
28 | 
29 | // This code is based on the kaldi Fbank implementation, please see
30 | // https://github.com/kaldi-asr/kaldi/blob/master/src/feat/feature-fbank.cc
31 | class Fbank {
32 |  public:
33 |   Fbank(int num_bins, int sample_rate, int frame_length, int frame_shift)
34 |       : num_bins_(num_bins),
35 |         sample_rate_(sample_rate),
36 |         frame_length_(frame_length),
37 |         frame_shift_(frame_shift),
38 |         use_log_(true),
39 |         remove_dc_offset_(true),
40 |         generator_(0),
41 |         distribution_(0, 1.0),
42 |         dither_(0.0) {
43 |     fft_points_ = UpperPowerOfTwo(frame_length_);
44 |     // generate bit reversal table and trigonometric function table
45 |     const int fft_points_4 = fft_points_ / 4;
46 |     bitrev_.resize(fft_points_);
47 |     sintbl_.resize(fft_points_ + fft_points_4);
48 |     make_sintbl(fft_points_, sintbl_.data());
49 |     make_bitrev(fft_points_, bitrev_.data());
50 | 
51 |     int num_fft_bins = fft_points_ / 2;
52 |     float fft_bin_width = static_cast<float>(sample_rate_) / fft_points_;
53 |     int low_freq = 20, high_freq = sample_rate_ / 2;
54 |     float mel_low_freq = MelScale(low_freq);
55 |     float mel_high_freq = MelScale(high_freq);
56 |     float mel_freq_delta = (mel_high_freq - mel_low_freq) / (num_bins + 1);
57 |     bins_.resize(num_bins_);
58 |     center_freqs_.resize(num_bins_);
59 |     // left_mel / center_mel / right_mel are the mel frequencies of the current filter's left edge, center, and right edge.
60 |     for (int bin = 0; bin < num_bins; ++bin) {
61 |       float left_mel = mel_low_freq + bin * mel_freq_delta,
62 |             center_mel = mel_low_freq + (bin + 1) * mel_freq_delta,
63 |             right_mel = mel_low_freq + (bin + 2) * mel_freq_delta;
64 |       center_freqs_[bin] = InverseMelScale(center_mel);
65 |       std::vector<float> this_bin(num_fft_bins);
66 |       int first_index = -1, last_index = -1;
67 |       for (int i = 0; i < num_fft_bins; ++i) {
68 |         float freq = (fft_bin_width * i);  // Center frequency of this fft
69 |                                            // bin.
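        // (Illustrative note.) Each FFT bin whose mel value falls inside
        // (left_mel, right_mel) receives a triangular weight: it rises as
        // (mel - left_mel) / (center_mel - left_mel) up to the center, then
        // falls as (right_mel - mel) / (right_mel - center_mel) -- exactly
        // what the branch below computes.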
70 | float mel = MelScale(freq); 71 | if (mel > left_mel && mel < right_mel) { 72 | float weight; 73 | if (mel <= center_mel) 74 | weight = (mel - left_mel) / (center_mel - left_mel); 75 | else 76 | weight = (right_mel - mel) / (right_mel - center_mel); 77 | this_bin[i] = weight; 78 | if (first_index == -1) first_index = i; 79 | last_index = i; 80 | } 81 | } 82 | CHECK(first_index != -1 && last_index >= first_index); 83 | bins_[bin].first = first_index; 84 | int size = last_index + 1 - first_index; 85 | bins_[bin].second.resize(size); 86 | for (int i = 0; i < size; ++i) { 87 | bins_[bin].second[i] = this_bin[first_index + i]; 88 | } 89 | } 90 | 91 | // NOTE(cdliang): add hamming window 92 | hamming_window_.resize(frame_length_); 93 | double a = M_2PI / (frame_length - 1); 94 | for (int i = 0; i < frame_length; i++) { 95 | double i_fl = static_cast(i); 96 | hamming_window_[i] = 0.54 - 0.46 * cos(a * i_fl); 97 | } 98 | } 99 | 100 | void set_use_log(bool use_log) { use_log_ = use_log; } 101 | 102 | void set_remove_dc_offset(bool remove_dc_offset) { 103 | remove_dc_offset_ = remove_dc_offset; 104 | } 105 | 106 | void set_dither(float dither) { dither_ = dither; } 107 | 108 | int num_bins() const { return num_bins_; } 109 | 110 | static inline float InverseMelScale(float mel_freq) { 111 | return 700.0f * (expf(mel_freq / 1127.0f) - 1.0f); 112 | } 113 | 114 | static inline float MelScale(float freq) { 115 | return 1127.0f * logf(1.0f + freq / 700.0f); 116 | } 117 | 118 | static int UpperPowerOfTwo(int n) { 119 | return static_cast(pow(2, ceil(log(n) / log(2)))); 120 | } 121 | 122 | // preemphasis 123 | void PreEmphasis(float coeff, std::vector* data) const { 124 | if (coeff == 0.0) return; 125 | for (int i = data->size() - 1; i > 0; i--) 126 | (*data)[i] -= coeff * (*data)[i - 1]; 127 | (*data)[0] -= coeff * (*data)[0]; 128 | } 129 | 130 | // add hamming window 131 | void Hamming(std::vector* data) const { 132 | CHECK(data->size() >= hamming_window_.size()); 133 | for (size_t i = 0; i < hamming_window_.size(); ++i) { 134 | (*data)[i] *= hamming_window_[i]; 135 | } 136 | } 137 | 138 | // Compute fbank feat, return num frames 139 | int Compute(const std::vector& wave, 140 | std::vector>* feat) { 141 | int num_samples = wave.size(); 142 | if (num_samples < frame_length_) return 0; 143 | int num_frames = 1 + ((num_samples - frame_length_) / frame_shift_); 144 | feat->resize(num_frames); 145 | std::vector fft_real(fft_points_, 0), fft_img(fft_points_, 0); 146 | std::vector power(fft_points_ / 2); 147 | for (int i = 0; i < num_frames; ++i) { 148 | std::vector data(wave.data() + i * frame_shift_, 149 | wave.data() + i * frame_shift_ + frame_length_); 150 | // optional add noise 151 | if (dither_ != 0.0) { 152 | for (size_t j = 0; j < data.size(); ++j) 153 | data[j] += dither_ * distribution_(generator_); 154 | } 155 | // optinal remove dc offset 156 | if (remove_dc_offset_) { 157 | float mean = 0.0; 158 | for (size_t j = 0; j < data.size(); ++j) mean += data[j]; 159 | mean /= data.size(); 160 | for (size_t j = 0; j < data.size(); ++j) data[j] -= mean; 161 | } 162 | 163 | PreEmphasis(0.97, &data); 164 | // Povey(&data); 165 | Hamming(&data); 166 | // copy data to fft_real 167 | memset(fft_img.data(), 0, sizeof(float) * fft_points_); 168 | memset(fft_real.data() + frame_length_, 0, 169 | sizeof(float) * (fft_points_ - frame_length_)); 170 | memcpy(fft_real.data(), data.data(), sizeof(float) * frame_length_); 171 | fft(bitrev_.data(), sintbl_.data(), fft_real.data(), fft_img.data(), 172 | fft_points_); 
173 | // power 174 | for (int j = 0; j < fft_points_ / 2; ++j) { 175 | power[j] = fft_real[j] * fft_real[j] + fft_img[j] * fft_img[j]; 176 | } 177 | 178 | (*feat)[i].resize(num_bins_); 179 | // cepstral coefficients, triangle filter array 180 | for (int j = 0; j < num_bins_; ++j) { 181 | float mel_energy = 0.0; 182 | int s = bins_[j].first; 183 | for (size_t k = 0; k < bins_[j].second.size(); ++k) { 184 | mel_energy += bins_[j].second[k] * power[s + k]; 185 | } 186 | // optional use log 187 | if (use_log_) { 188 | if (mel_energy < std::numeric_limits::epsilon()) 189 | mel_energy = std::numeric_limits::epsilon(); 190 | mel_energy = logf(mel_energy); 191 | } 192 | 193 | (*feat)[i][j] = mel_energy; 194 | // printf("%f ", mel_energy); 195 | } 196 | // printf("\n"); 197 | } 198 | return num_frames; 199 | } 200 | 201 | private: 202 | int num_bins_; 203 | int sample_rate_; 204 | int frame_length_, frame_shift_; 205 | int fft_points_; 206 | bool use_log_; 207 | bool remove_dc_offset_; 208 | std::vector center_freqs_; 209 | std::vector>> bins_; 210 | std::vector hamming_window_; 211 | std::default_random_engine generator_; 212 | std::normal_distribution distribution_; 213 | float dither_; 214 | 215 | // bit reversal table 216 | std::vector bitrev_; 217 | // trigonometric function table 218 | std::vector sintbl_; 219 | }; 220 | 221 | } // namespace wenet 222 | 223 | #endif // FRONTEND_FBANK_H_ 224 | -------------------------------------------------------------------------------- /onnxruntime/frontend/feature_pipeline.cc: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2017 Personal (Binbin Zhang) 2 | // Copyright (c) 2024 Yang Chen (cyang8050@163.com) 3 | // 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 
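// How AcceptWaveform() below reshapes features for the CTC model (an
// illustrative summary using the defaults from feature_pipeline.h):
//   - Fbank emits num_bins = 80-dim frames, 10 ms apart;
//   - left_context = 2 and right_context = 2 splice 5 consecutive frames
//     into one (2 + 1 + 2) * 80 = 400-dim vector per frame;
//   - downsampling = 3 keeps every 3rd spliced frame, so roughly a third of
//     the ~100 frames per second of audio reach the model.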
15 | 
16 | #include "frontend/feature_pipeline.h"
17 | 
18 | #include <algorithm>
19 | #include <utility>
20 | 
21 | namespace wenet {
22 | 
23 | FeaturePipeline::FeaturePipeline(const FeaturePipelineConfig &config)
24 |     : config_(config),
25 |       feature_dim_(config.num_bins),
26 |       fbank_(config.num_bins, config.sample_rate, config.frame_length,
27 |              config.frame_shift),
28 |       num_frames_(0),
29 |       input_finished_(false) {}
30 | 
31 | void FeaturePipeline::AcceptWaveform(const std::vector<float> &wav) {
32 |   std::vector<std::vector<float>> feats;
33 |   std::vector<float> waves;
34 |   waves.insert(waves.end(), remained_wav_.begin(), remained_wav_.end());
35 |   waves.insert(waves.end(), wav.begin(), wav.end());
36 |   int num_frames = fbank_.Compute(waves, &feats);  // feats.shape = (frames, mel_num_bins)
37 | 
38 |   if (config_.model_type == CTC_TYPE_MODEL) {
39 |     int left_context = config_.left_context, right_context = config_.right_context;
40 | 
41 |     // Feature preparation for CTC-loss models; see
42 |     // https://modelscope.cn/studios/thuduj12/KWS_Nihao_Xiaojing/file/view/master/stream_kws_ctc.py
43 |     // Splice mel_num_bins=80 frames into (2 + 1 + 2) * 80 = 400-dim vectors.
44 |     std::vector<std::vector<float>> feats_pad;
45 |     if (!feature_remained.empty()) {
46 |       feats_pad.insert(feats_pad.end(), feature_remained.begin(), feature_remained.end());
47 |       feats_pad.insert(feats_pad.end(), feats.begin(), feats.end());
48 |       feature_remained.clear();  // clear for updating later.
49 |     } else {
50 |       feats_pad = padFeatures(feats, left_context);
51 |     }
52 |     std::vector<std::vector<float>> feats_ctx = extractContext(feats_pad, left_context,
53 |                                                                right_context);
54 | 
55 |     // update feature remained, and feats
56 |     int feature_remained_size = left_context + right_context;
57 |     if (feature_remained_size <= static_cast<int>(feats.size())) {
58 |       int start_index = feats.size() - feature_remained_size;
59 |       feature_remained.assign(feats.begin() + start_index, feats.end());
60 |     }
61 | 
62 |     // Skip-sample the sequence to reduce redundant computation:
63 |     // slice feats_ctx with stride config_.downsampling.
64 |     std::vector<std::vector<float>> feats_down = slice(feats_ctx, 0, config_.downsampling);
65 | 
66 |     for (size_t i = 0; i < feats_down.size(); ++i) {
67 |       feature_queue_.Push(std::move(feats_down[i]));
68 |     }
69 |   } else {
70 |     for (size_t i = 0; i < feats.size(); ++i) {
71 |       feature_queue_.Push(std::move(feats[i]));
72 |     }
73 |   }
74 | 
75 |   num_frames_ += num_frames;
76 | 
77 |   int left_samples = waves.size() - config_.frame_shift * num_frames;
78 |   remained_wav_.resize(left_samples);
79 |   std::copy(waves.begin() + config_.frame_shift * num_frames, waves.end(),
80 |             remained_wav_.begin());
81 |   // We are still adding wave, notify input is not finished
82 |   finish_condition_.notify_one();
83 | }
84 | 
85 | void FeaturePipeline::AcceptWaveform(const std::vector<int16_t> &wav) {
86 |   std::vector<float> float_wav(wav.size());
87 |   for (size_t i = 0; i < wav.size(); i++) {
88 |     float_wav[i] = static_cast<float>(wav[i]);
89 |   }
90 |   this->AcceptWaveform(float_wav);
91 | }
92 | 
93 | void FeaturePipeline::set_input_finished() {
94 |   CHECK(!input_finished_);
95 |   {
96 |     std::lock_guard<std::mutex> lock(mutex_);
97 |     input_finished_ = true;
98 |     feature_remained.clear();
99 |   }
100 |   finish_condition_.notify_one();
101 | }
102 | 
103 | bool FeaturePipeline::ReadOne(std::vector<float> *feat) {
104 |   if (!feature_queue_.Empty()) {
105 |     *feat = feature_queue_.Pop();
106 |     return true;
107 |   } else {
108 |     std::unique_lock<std::mutex> lock(mutex_);
109 |     while (!input_finished_) {
110 |       // This will release the lock and wait for notify_one()
111 |       // from AcceptWaveform() or set_input_finished()
112 |       finish_condition_.wait(lock);
113 |       if (!feature_queue_.Empty()) {
114 |         *feat = feature_queue_.Pop();
115 |         return true;
116 |       }
117 |     }
118 |     CHECK(input_finished_);
119 |     // Double check queue.empty, see issue#893 for detailed discussions.
120 |     if (!feature_queue_.Empty()) {
121 |       *feat = feature_queue_.Pop();
122 |       return true;
123 |     } else {
124 |       return false;
125 |     }
126 |   }
127 | }
128 | 
129 | bool FeaturePipeline::Read(int num_frames,
130 |                            std::vector<std::vector<float>> *feats) {
131 |   feats->clear();
132 |   std::vector<float> feat;
133 |   while (feats->size() < static_cast<size_t>(num_frames)) {
134 |     if (ReadOne(&feat)) {
135 |       feats->push_back(std::move(feat));
136 |     } else {
137 |       return false;
138 |     }
139 |   }
140 |   return true;
141 | }
142 | 
143 | void FeaturePipeline::Reset() {
144 |   input_finished_ = false;
145 |   num_frames_ = 0;
146 |   remained_wav_.clear();
147 |   feature_queue_.Clear();
148 | }
149 | 
150 | std::vector<std::vector<float>>
151 | FeaturePipeline::padFeatures(const std::vector<std::vector<float>> &feats, int leftContext) {
152 |   // Rows are frames (time axis), columns are feature dims.
153 |   size_t numRows = feats.size();
154 |   size_t numCols = feats[0].size();
155 | 
156 |   // Number of frames after left-padding along the time axis.
157 |   size_t paddedRows = numRows + leftContext;
158 | 
159 |   std::vector<std::vector<float>> paddedFeats(paddedRows, std::vector<float>(numCols));
160 | 
161 |   // Edge padding: replicate the first frame leftContext times ...
162 |   for (int j = 0; j < leftContext; ++j) {
163 |     paddedFeats[j] = feats[0];
164 |   }
165 |   // ... then copy the original frames in after the padding.
166 |   for (size_t i = 0; i < numRows; ++i) {
167 |     paddedFeats[i + leftContext] = feats[i];
168 |   }
169 | 
170 |   return paddedFeats;
171 | }
172 | 
173 | 
174 | std::vector<std::vector<float>> FeaturePipeline::extractContext(const std::vector<std::vector<float>> &
175 |                                                                 featsPad, int leftContext, int rightContext) {
176 |   int ctxFrm = featsPad.size() - (leftContext + rightContext);
177 |   int ctxWin = leftContext + rightContext + 1;
178 |   int ctxDim = featsPad[0].size() * ctxWin;
179 | 
180 |   std::vector<std::vector<float>> featsCtx(ctxFrm, std::vector<float>(ctxDim, 0.0));
181 | 
182 |   for (int i = 0; i < ctxFrm; ++i) {
183 |     int start = i;
184 |     int end = i + ctxWin;
185 |     int index = 0;
186 | 
187 |     for (int j = start; j < end; ++j) {
188 |       for (size_t k = 0; k < featsPad[j].size(); ++k) {
189 |         featsCtx[i][index] = featsPad[j][k];
190 |         ++index;
191 |       }
192 |     }
193 |   }
194 | 
195 |   return featsCtx;
196 | }
197 | 
198 | std::vector<std::vector<float>> FeaturePipeline::slice(const std::vector<std::vector<float>> &data, int start, int step) {
199 |   std::vector<std::vector<float>> result;
200 |   for (size_t i = start; i < data.size(); i += step) {
201 |     result.push_back(data[i]);
202 |   }
203 |   return result;
204 | }
205 | }  // namespace wenet
-------------------------------------------------------------------------------- /onnxruntime/frontend/feature_pipeline.h: --------------------------------------------------------------------------------
1 | // Copyright (c) 2017 Personal (Binbin Zhang)
2 | // Copyright (c) 2024 Yang Chen (cyang8050@163.com)
3 | //
4 | // Licensed under the Apache License, Version 2.0 (the "License");
5 | // you may not use this file except in compliance with the License.
6 | // You may obtain a copy of the License at
7 | //
8 | // http://www.apache.org/licenses/LICENSE-2.0
9 | //
10 | // Unless required by applicable law or agreed to in writing, software
11 | // distributed under the License is distributed on an "AS IS" BASIS,
12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 15 | 16 | #ifndef FRONTEND_FEATURE_PIPELINE_H_ 17 | #define FRONTEND_FEATURE_PIPELINE_H_ 18 | 19 | #include 20 | #include 21 | #include 22 | #include 23 | 24 | #include "frontend/fbank.h" 25 | #include "utils/log.h" 26 | #include "utils/blocking_queue.h" 27 | 28 | namespace wenet { 29 | 30 | typedef enum { 31 | MAXPOOLING_TYPE_MODEL=0, 32 | CTC_TYPE_MODEL=1 33 | }MODEL_TYPE; 34 | 35 | struct FeaturePipelineConfig { 36 | int num_bins; 37 | int sample_rate; 38 | int frame_length; 39 | int frame_shift; 40 | int left_context; 41 | int right_context; 42 | int downsampling; 43 | MODEL_TYPE model_type; // 1:ctc 0: max-pooling 44 | 45 | FeaturePipelineConfig(int num_bins, int sample_rate, MODEL_TYPE model_type) 46 | : num_bins(num_bins), // 80 dim fbank. feature dim of mel-spectrogram. 47 | sample_rate(sample_rate), // 16k sample rate of audio 48 | model_type(model_type) { 49 | frame_length = sample_rate / 1000 * 25; // frame length 25ms, window_size 50 | frame_shift = sample_rate / 1000 * 10; // frame shift 10ms, window_shift 51 | left_context = 2; // context_expansion_conf in config.yaml. 52 | right_context = 2; 53 | downsampling = 3; 54 | } 55 | 56 | void Info() const { 57 | LOG(INFO) << "feature pipeline config" 58 | << " num_bins " << num_bins << " frame_length " << frame_length 59 | << " frame_shift " << frame_shift; 60 | } 61 | }; 62 | 63 | // Typically, FeaturePipeline is used in two threads: one thread A calls 64 | // AcceptWaveform() to add raw wav data and set_input_finished() to notice 65 | // the end of input wav, another thread B (decoder thread) calls Read() to 66 | // consume features.So a BlockingQueue is used to make this class thread safe. 67 | 68 | // The Read() is designed as a blocking method when there is no feature 69 | // in feature_queue_ and the input is not finished. 70 | 71 | class FeaturePipeline { 72 | public: 73 | explicit FeaturePipeline(const FeaturePipelineConfig &config); 74 | 75 | // The feature extraction is done in AcceptWaveform(). 76 | void AcceptWaveform(const std::vector &wav); 77 | 78 | void AcceptWaveform(const std::vector &wav); 79 | 80 | // Current extracted frames number. 81 | int num_frames() const { return num_frames_; } 82 | 83 | int feature_dim() const { return feature_dim_; } 84 | 85 | const FeaturePipelineConfig &config() const { return config_; } 86 | 87 | // The caller should call this method when speech input is end. 88 | // Never call AcceptWaveform() after calling set_input_finished() ! 89 | void set_input_finished(); 90 | 91 | bool input_finished() const { return input_finished_; } 92 | 93 | // Return False if input is finished and no feature could be read. 94 | // Return True if a feature is read. 95 | // This function is a blocking method. It will block the thread when 96 | // there is no feature in feature_queue_ and the input is not finished. 97 | bool ReadOne(std::vector *feat); 98 | 99 | // Read #num_frames frame features. 100 | // Return False if less then #num_frames features are read and the 101 | // input is finished. 102 | // Return True if #num_frames features are read. 103 | // This function is a blocking method when there is no feature 104 | // in feature_queue_ and the input is not finished. 
105 | bool Read(int num_frames, std::vector> *feats); 106 | 107 | void Reset(); 108 | 109 | bool IsLastFrame(int frame) const { 110 | return input_finished_ && (frame == num_frames_ - 1); 111 | } 112 | 113 | int NumQueuedFrames() const { return feature_queue_.Size(); } 114 | 115 | std::vector> padFeatures(const std::vector> &feats, int leftContext); 116 | 117 | std::vector> extractContext(const std::vector> & 118 | featsPad, int leftContext, int rightContext); 119 | 120 | std::vector> slice(const std::vector> &data, int start, int step); 121 | 122 | private: 123 | const FeaturePipelineConfig &config_; 124 | int feature_dim_; 125 | Fbank fbank_; 126 | 127 | BlockingQueue> feature_queue_; 128 | int num_frames_; 129 | bool input_finished_; 130 | 131 | // The feature extraction is done in AcceptWaveform(). 132 | // This wavefrom sample points are consumed by frame size. 133 | // The residual wavefrom sample points after framing are 134 | // kept to be used in next AcceptWaveform() calling. 135 | std::vector remained_wav_; 136 | 137 | // Used to block the Read when there is no feature in feature_queue_ 138 | // and the input is not finished. 139 | mutable std::mutex mutex_; 140 | std::condition_variable finish_condition_; 141 | 142 | std::vector> feature_remained; 143 | 144 | 145 | }; 146 | 147 | } // namespace wenet 148 | 149 | #endif // FRONTEND_FEATURE_PIPELINE_H_ 150 | -------------------------------------------------------------------------------- /onnxruntime/frontend/fft.cc: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2016 HR 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | #include "frontend/fft.h" 8 | 9 | namespace wenet { 10 | 11 | void make_sintbl(int n, float* sintbl) { 12 | int i, n2, n4, n8; 13 | float c, s, dc, ds, t; 14 | 15 | n2 = n / 2; 16 | n4 = n / 4; 17 | n8 = n / 8; 18 | t = sin(M_PI / n); 19 | dc = 2 * t * t; 20 | ds = sqrt(dc * (2 - dc)); 21 | t = 2 * dc; 22 | c = sintbl[n4] = 1; 23 | s = sintbl[0] = 0; 24 | for (i = 1; i < n8; ++i) { 25 | c -= dc; 26 | dc += t * c; 27 | s += ds; 28 | ds -= t * s; 29 | sintbl[i] = s; 30 | sintbl[n4 - i] = c; 31 | } 32 | if (n8 != 0) sintbl[n8] = sqrt(0.5); 33 | for (i = 0; i < n4; ++i) sintbl[n2 - i] = sintbl[i]; 34 | for (i = 0; i < n2 + n4; ++i) sintbl[i + n2] = -sintbl[i]; 35 | } 36 | 37 | void make_bitrev(int n, int* bitrev) { 38 | int i, j, k, n2; 39 | 40 | n2 = n / 2; 41 | i = j = 0; 42 | for (;;) { 43 | bitrev[i] = j; 44 | if (++i >= n) break; 45 | k = n2; 46 | while (k <= j) { 47 | j -= k; 48 | k /= 2; 49 | } 50 | j += k; 51 | } 52 | } 53 | 54 | // bitrev: bit reversal table 55 | // sintbl: trigonometric function table 56 | // x:real part 57 | // y:image part 58 | // n: fft length 59 | int fft(const int* bitrev, const float* sintbl, float* x, float* y, int n) { 60 | int i, j, k, ik, h, d, k2, n4, inverse; 61 | float t, s, c, dx, dy; 62 | 63 | /* preparation */ 64 | if (n < 0) { 65 | n = -n; 66 | inverse = 1; /* inverse transform */ 67 | } else { 68 | inverse = 0; 69 | } 70 | n4 = n / 4; 71 | if (n == 0) { 72 | return 0; 73 | } 74 | 75 | /* bit reversal */ 76 | for (i = 0; i < n; ++i) { 77 | j = bitrev[i]; 78 | if (i < j) { 79 | t = x[i]; 80 | x[i] = x[j]; 81 | x[j] = t; 82 | t = y[i]; 83 | y[i] = y[j]; 84 | y[j] = t; 85 | } 86 | } 87 | 88 | /* transformation */ 89 | for (k = 1; k < n; k = k2) { 90 | h = 0; 91 | k2 = k + k; 92 | d = n / k2; 93 | for (j = 0; j < k; ++j) { 94 | c = sintbl[h + n4]; 95 | if (inverse) 96 | s = -sintbl[h]; 97 | else 98 | s = sintbl[h]; 99 | for 
(i = j; i < n; i += k2) { 100 | ik = i + k; 101 | dx = s * y[ik] + c * x[ik]; 102 | dy = c * y[ik] - s * x[ik]; 103 | x[ik] = x[i] - dx; 104 | x[i] += dx; 105 | y[ik] = y[i] - dy; 106 | y[i] += dy; 107 | } 108 | h += d; 109 | } 110 | } 111 | if (inverse) { 112 | /* divide by n in case of the inverse transformation */ 113 | for (i = 0; i < n; ++i) { 114 | x[i] /= n; 115 | y[i] /= n; 116 | } 117 | } 118 | return 0; /* finished successfully */ 119 | } 120 | 121 | } // namespace wenet 122 | -------------------------------------------------------------------------------- /onnxruntime/frontend/fft.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2016 HR 2 | 3 | #ifndef FRONTEND_FFT_H_ 4 | #define FRONTEND_FFT_H_ 5 | 6 | #ifndef M_PI 7 | #define M_PI 3.1415926535897932384626433832795 8 | #endif 9 | #ifndef M_2PI 10 | #define M_2PI 6.283185307179586476925286766559005 11 | #endif 12 | 13 | namespace wenet { 14 | 15 | // Fast Fourier Transform 16 | 17 | void make_sintbl(int n, float* sintbl); 18 | 19 | void make_bitrev(int n, int* bitrev); 20 | 21 | int fft(const int* bitrev, const float* sintbl, float* x, float* y, int n); 22 | 23 | } // namespace wenet 24 | 25 | #endif // FRONTEND_FFT_H_ 26 | -------------------------------------------------------------------------------- /onnxruntime/frontend/wav.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2016 Personal (Binbin Zhang) 2 | // Created on 2016-08-15 3 | // 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 
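// Usage sketch for the reader defined below (illustrative; assumes a 16-bit
// mono 16 kHz file -- the same pattern stream_kws_testing.cc uses):
//
//   wenet::WavReader reader("example.wav");  // "example.wav" is a placeholder
//   std::vector<float> samples(reader.data(),
//                              reader.data() + reader.num_samples());
//
// The canonical header is 44 bytes; the reader also skips optional subchunks
// (e.g. "fact", "list") that may sit between "fmt " and "data".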
15 | 16 | #ifndef FRONTEND_WAV_H_ 17 | #define FRONTEND_WAV_H_ 18 | 19 | #include 20 | #include 21 | #include 22 | #include 23 | #include 24 | 25 | #include 26 | 27 | #include "utils/log.h" 28 | 29 | namespace wenet { 30 | 31 | struct WavHeader { 32 | char riff[4]; // "riff" 33 | unsigned int size; 34 | char wav[4]; // "WAVE" 35 | char fmt[4]; // "fmt " 36 | unsigned int fmt_size; 37 | uint16_t format; 38 | uint16_t channels; 39 | unsigned int sample_rate; 40 | unsigned int bytes_per_second; 41 | uint16_t block_size; 42 | uint16_t bit; 43 | char data[4]; // "data" 44 | unsigned int data_size; 45 | }; 46 | 47 | class WavReader { 48 | public: 49 | WavReader() : data_(nullptr) {} 50 | explicit WavReader(const std::string& filename) { Open(filename); } 51 | 52 | bool Open(const std::string& filename) { 53 | FILE* fp = fopen(filename.c_str(), "rb"); 54 | if (NULL == fp) { 55 | LOG(WARNING) << "Error in read " << filename; 56 | return false; 57 | } 58 | 59 | WavHeader header; 60 | fread(&header, 1, sizeof(header), fp); 61 | if (header.fmt_size < 16) { 62 | fprintf(stderr, 63 | "WaveData: expect PCM format data " 64 | "to have fmt chunk of at least size 16.\n"); 65 | return false; 66 | } else if (header.fmt_size > 16) { 67 | int offset = 44 - 8 + header.fmt_size - 16; 68 | fseek(fp, offset, SEEK_SET); 69 | fread(header.data, 8, sizeof(char), fp); 70 | } 71 | // check "riff" "WAVE" "fmt " "data" 72 | 73 | // Skip any subchunks between "fmt" and "data". Usually there will 74 | // be a single "fact" subchunk, but on Windows there can also be a 75 | // "list" subchunk. 76 | while (0 != strncmp(header.data, "data", 4)) { 77 | // We will just ignore the data in these chunks. 78 | fseek(fp, header.data_size, SEEK_CUR); 79 | // read next subchunk 80 | fread(header.data, 8, sizeof(char), fp); 81 | } 82 | 83 | num_channel_ = header.channels; 84 | sample_rate_ = header.sample_rate; 85 | bits_per_sample_ = header.bit; 86 | int num_data = header.data_size / (bits_per_sample_ / 8); 87 | data_ = new float[num_data]; 88 | num_samples_ = num_data / num_channel_; 89 | 90 | for (int i = 0; i < num_data; ++i) { 91 | switch (bits_per_sample_) { 92 | case 8: { 93 | char sample; 94 | fread(&sample, 1, sizeof(char), fp); 95 | data_[i] = static_cast(sample); 96 | break; 97 | } 98 | case 16: { 99 | int16_t sample; 100 | fread(&sample, 1, sizeof(int16_t), fp); 101 | data_[i] = static_cast(sample); 102 | break; 103 | } 104 | case 32: { 105 | int sample; 106 | fread(&sample, 1, sizeof(int), fp); 107 | data_[i] = static_cast(sample); 108 | break; 109 | } 110 | default: 111 | fprintf(stderr, "unsupported quantization bits"); 112 | exit(1); 113 | } 114 | } 115 | fclose(fp); 116 | return true; 117 | } 118 | 119 | int num_channel() const { return num_channel_; } 120 | int sample_rate() const { return sample_rate_; } 121 | int bits_per_sample() const { return bits_per_sample_; } 122 | int num_samples() const { return num_samples_; } 123 | 124 | ~WavReader() { 125 | if (data_ != NULL) delete[] data_; 126 | } 127 | 128 | const float* data() const { return data_; } 129 | 130 | 131 | 132 | 133 | private: 134 | int num_channel_; 135 | int sample_rate_; 136 | int bits_per_sample_; 137 | int num_samples_; // sample points per channel 138 | float* data_; 139 | }; 140 | 141 | class WavWriter { 142 | public: 143 | WavWriter(const float* data, int num_samples, int num_channel, 144 | int sample_rate, int bits_per_sample) 145 | : data_(data), 146 | num_samples_(num_samples), 147 | num_channel_(num_channel), 148 | sample_rate_(sample_rate), 149 
| bits_per_sample_(bits_per_sample) {} 150 | 151 | void Write(const std::string& filename) { 152 | FILE* fp = fopen(filename.c_str(), "w"); 153 | // init char 'riff' 'WAVE' 'fmt ' 'data' 154 | WavHeader header; 155 | char wav_header[44] = {0x52, 0x49, 0x46, 0x46, 0x00, 0x00, 0x00, 0x00, 0x57, 156 | 0x41, 0x56, 0x45, 0x66, 0x6d, 0x74, 0x20, 0x10, 0x00, 157 | 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 158 | 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 159 | 0x64, 0x61, 0x74, 0x61, 0x00, 0x00, 0x00, 0x00}; 160 | memcpy(&header, wav_header, sizeof(header)); 161 | header.channels = num_channel_; 162 | header.bit = bits_per_sample_; 163 | header.sample_rate = sample_rate_; 164 | header.data_size = num_samples_ * num_channel_ * (bits_per_sample_ / 8); 165 | header.size = sizeof(header) - 8 + header.data_size; 166 | header.bytes_per_second = 167 | sample_rate_ * num_channel_ * (bits_per_sample_ / 8); 168 | header.block_size = num_channel_ * (bits_per_sample_ / 8); 169 | 170 | fwrite(&header, 1, sizeof(header), fp); 171 | 172 | for (int i = 0; i < num_samples_; ++i) { 173 | for (int j = 0; j < num_channel_; ++j) { 174 | switch (bits_per_sample_) { 175 | case 8: { 176 | char sample = static_cast(data_[i * num_channel_ + j]); 177 | fwrite(&sample, 1, sizeof(sample), fp); 178 | break; 179 | } 180 | case 16: { 181 | int16_t sample = static_cast(data_[i * num_channel_ + j]); 182 | fwrite(&sample, 1, sizeof(sample), fp); 183 | break; 184 | } 185 | case 32: { 186 | int sample = static_cast(data_[i * num_channel_ + j]); 187 | fwrite(&sample, 1, sizeof(sample), fp); 188 | break; 189 | } 190 | } 191 | } 192 | } 193 | fclose(fp); 194 | } 195 | 196 | private: 197 | const float* data_; 198 | int num_samples_; // total float points in data_ 199 | int num_channel_; 200 | int sample_rate_; 201 | int bits_per_sample_; 202 | }; 203 | 204 | } // namespace wenet 205 | 206 | #endif // FRONTEND_WAV_H_ 207 | -------------------------------------------------------------------------------- /onnxruntime/kws/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_library(kws STATIC keyword_spotting.cc utils.cpp) 2 | -------------------------------------------------------------------------------- /onnxruntime/kws/keyword_spotting.cc: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2022 Binbin Zhang (binbzha@qq.com) 2 | // Copyright (c) 2024 Yang Chen (cyang8050@163.com) 3 | // 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 
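// Background for the decoder below (an illustrative summary): a CTC model
// emits, per frame, a distribution over tokens plus a blank. Prefix beam
// search keeps, for every candidate prefix, two scores -- s (the prefix was
// last extended via a blank) and ns (via a non-blank) -- because collapsing
// differs: "你-好-好" (with blanks between) collapses to "你好好", while
// "你好好" with no separating blank collapses to "你好". A prefix's total
// score is s + ns, and the reported hit score for a keyword is the square
// root of the product of its token probabilities (see execute_detection()).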
15 | 16 | 17 | #include "kws/keyword_spotting.h" 18 | 19 | #include 20 | #include 21 | #include 22 | #include 23 | #include 24 | #include 25 | #include 26 | #include 27 | 28 | namespace wekws { 29 | 30 | Ort::Env KeywordSpotting::env_ = Ort::Env(ORT_LOGGING_LEVEL_WARNING, ""); 31 | Ort::SessionOptions KeywordSpotting::session_options_ = Ort::SessionOptions(); 32 | 33 | static void print_vector(const std::vector &arr) { 34 | if (!arr.empty()) { 35 | std::cout << "prefix: "; 36 | for (auto it: arr) { 37 | std::cout << it << ","; 38 | } 39 | std::cout << std::endl; 40 | } 41 | } 42 | 43 | static bool PrefixScoreCompare( 44 | const std::pair, PrefixScore> &a, 45 | const std::pair, PrefixScore> &b) { 46 | return a.second.total_score() > b.second.total_score(); 47 | } 48 | 49 | KeywordSpotting::KeywordSpotting(const std::string &model_path, DECODE_TYPE decode_type, int model_type) { 50 | // 0. set decode type from {DECODE_GREEDY_SEARCH, DECODE_PREFIX_BEAM_SEARCH} 51 | mdecode_type = decode_type; 52 | mmodel_type = model_type; 53 | 54 | // 1. Load onnx runtime sessions 55 | session_ = std::make_shared(env_, model_path.c_str(), 56 | session_options_); 57 | // 2. Model info. Information can be view from netron. 58 | // pip install netron. netron [model_path] 59 | in_names_ = {"input", "cache"}; 60 | out_names_ = {"output", "r_cache"}; 61 | auto metadata = session_->GetModelMetadata(); 62 | Ort::AllocatorWithDefaultOptions allocator; 63 | cache_dim_ = std::stoi(metadata.LookupCustomMetadataMap("cache_dim", 64 | allocator)); 65 | cache_len_ = std::stoi(metadata.LookupCustomMetadataMap("cache_len", 66 | allocator)); 67 | std::cout << "Kws Model Info:" << std::endl 68 | << "\tcache_dim: " << cache_dim_ << std::endl 69 | << "\tcache_len: " << cache_len_ << std::endl; 70 | 71 | Reset(); 72 | 73 | 74 | } 75 | 76 | void KeywordSpotting::Reset() { 77 | Ort::MemoryInfo memory_info = 78 | Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); 79 | if(mmodel_type == 1){ // ctc model 80 | cache_.resize(cache_dim_ * cache_len_ * cache_4_, 0.0); 81 | const int64_t cache_shape[] = {1, cache_dim_, cache_len_, cache_4_}; 82 | cache_ort_ = Ort::Value::CreateTensor( 83 | memory_info, cache_.data(), cache_.size(), cache_shape, 4); 84 | reset_value(); 85 | }else{ // max pooling model 86 | cache_.resize(cache_dim_ * cache_len_ , 0.0); 87 | const int64_t cache_shape[] = {1, cache_dim_, cache_len_}; 88 | cache_ort_ = Ort::Value::CreateTensor( 89 | memory_info, cache_.data(), cache_.size(), cache_shape, 3); 90 | } 91 | } 92 | 93 | void KeywordSpotting::reset_value() { 94 | if (mdecode_type == DECODE_PREFIX_BEAM_SEARCH) { 95 | cur_hyps_.clear(); 96 | PrefixScore prefix_score; 97 | prefix_score.s = 1.0; 98 | prefix_score.ns = 0.0; 99 | std::vector empty; 100 | cur_hyps_[empty] = prefix_score; 101 | 102 | activated = false; // none 103 | 104 | } else if (mdecode_type == DECODE_GREEDY_SEARCH) { 105 | gd_cur_hyps.clear(); 106 | } 107 | } 108 | 109 | void KeywordSpotting::Forward( 110 | const std::vector> &feats, 111 | std::vector> *prob) { 112 | prob->clear(); 113 | if (feats.size() == 0) return; 114 | Ort::MemoryInfo memory_info = 115 | Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); 116 | // 1. 
Prepare input 117 | int num_frames = feats.size(); 118 | int feature_dim = feats[0].size(); 119 | std::vector slice_feats; 120 | for (int i = 0; i < feats.size(); i++) { 121 | slice_feats.insert(slice_feats.end(), feats[i].begin(), feats[i].end()); 122 | } 123 | const int64_t feats_shape[3] = {1, num_frames, feature_dim}; 124 | Ort::Value feats_ort = Ort::Value::CreateTensor( 125 | memory_info, slice_feats.data(), slice_feats.size(), feats_shape, 3); 126 | // 2. Ort forward 127 | std::vector inputs; 128 | inputs.emplace_back(std::move(feats_ort)); 129 | inputs.emplace_back(std::move(cache_ort_)); 130 | // ort_outputs.size() == 2 131 | std::vector ort_outputs = session_->Run( 132 | Ort::RunOptions{nullptr}, in_names_.data(), inputs.data(), 133 | inputs.size(), out_names_.data(), out_names_.size()); 134 | 135 | // 3. Update cache 136 | cache_ort_ = std::move(ort_outputs[1]); 137 | 138 | // 4. Get keyword prob 139 | float *data = ort_outputs[0].GetTensorMutableData(); 140 | auto type_info = ort_outputs[0].GetTensorTypeAndShapeInfo(); 141 | int num_outputs = type_info.GetShape()[1]; 142 | int output_dim = type_info.GetShape()[2]; 143 | prob->resize(num_outputs); 144 | for (int i = 0; i < num_outputs; i++) { 145 | (*prob)[i].resize(output_dim); 146 | memcpy((*prob)[i].data(), data + i * output_dim, 147 | sizeof(float) * output_dim); 148 | } 149 | } 150 | 151 | void KeywordSpotting::readToken(const std::string &tokenFile) { 152 | std::ifstream fin(tokenFile); 153 | 154 | if (fin.is_open()) { 155 | std::string line; 156 | while (std::getline(fin, line)) { 157 | if(mmodel_type==1){ 158 | std::string token; 159 | int value; 160 | std::istringstream iss(line); 161 | if (iss >> token >> value) { 162 | mvocab[token] = value - 1; 163 | } 164 | }else{ 165 | mmaxpooling_keywords.push_back(line); 166 | } 167 | } 168 | fin.close(); 169 | } else { 170 | std::cerr << "Error: Unable to open the token file." << std::endl; 171 | } 172 | 173 | } 174 | 175 | void KeywordSpotting::setKeyWord(const std::string &keyWord) { 176 | /*keyWord : key word to wakeup. 177 | * */ 178 | mkey_word = keyWord; 179 | mkeyword_set.insert(0); // insert 0 for blank token of ctc. 180 | for (int idx = 0; idx < keyWord.size(); idx += 3) { // 3byte for chinese char with utf8. 181 | std::string token = keyWord.substr(idx, 3); 182 | int toekn_idx = mvocab.at(token); 183 | if (mvocab.count(token) > 0) { 184 | if (mkeyword_set.count(toekn_idx) == 0) { 185 | mkeyword_set.insert(toekn_idx); 186 | } 187 | mkeyword_token.push_back(toekn_idx); 188 | } else { 189 | std::cerr << "Can not find" << " " << keyWord << " " << "in vocab. Please check."; 190 | } 191 | } 192 | } 193 | 194 | bool KeywordSpotting::isKeyword(int index) { 195 | return mkeyword_set.count(index) > 0; 196 | } 197 | 198 | void KeywordSpotting::UpdateHypotheses(const std::vector, PrefixScore>> &hpys) { 199 | cur_hyps_.clear(); 200 | for (auto &item: hpys) { 201 | // std::vector prefix = item.first; 202 | if (item.first.empty()) { 203 | PrefixScore prefix_score; 204 | prefix_score.s = 1.0; 205 | prefix_score.ns = 0.0; 206 | std::vector empty; 207 | cur_hyps_[empty] = prefix_score; 208 | } else { 209 | // filter illegal prefix case. 
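          // (Illustrative note.) A prefix longer than the keyword token
          // sequence can never match it, so such hypotheses are dropped here;
          // this also bounds the size of the hypothesis set between frames.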
210 | if(item.first.size() > mkeyword_token.size()) { 211 | continue; 212 | } 213 | cur_hyps_[item.first] = item.second; 214 | } 215 | } 216 | // assert cur_hyps_ is not empty() 217 | if (!cur_hyps_.empty()){ 218 | cur_hyps_[std::vector()] = PrefixScore{1.0, 0.0}; 219 | } 220 | 221 | } 222 | 223 | void KeywordSpotting::decode_keywords(std::vector> &probs, float hitScoreThr) { 224 | /*decode keyword. 225 | */ 226 | if (mdecode_type == DECODE_GREEDY_SEARCH) { 227 | //std::cout << "DECODE_GREEDY_SEARCH" << std::endl; 228 | decode_with_greedy_search(mGTimeStep, probs); 229 | mGTimeStep += 1; 230 | 231 | } else if (mdecode_type == DECODE_PREFIX_BEAM_SEARCH) { 232 | // std::cout << "DECODE_PREFIX_BEAM_SEARCH" << std::endl; 233 | for (const auto &prob: probs) { 234 | 235 | decode_ctc_prefix_beam_search(mGTimeStep, prob); 236 | mGTimeStep += 1; 237 | execute_detection(hitScoreThr); 238 | if (activated) { 239 | reset_value(); 240 | break; 241 | } 242 | } 243 | } else { 244 | std::cerr << "Not implement yet now."; 245 | } 246 | } 247 | 248 | void KeywordSpotting::decode_with_greedy_search(int offset, std::vector> &probs) { 249 | 250 | // find index with max-prob in each time step. 251 | for (int i = 0; i < probs.size(); i++) { 252 | std::cout << "frame " << std::setw(3) << offset + i; 253 | auto maxElement = std::max_element(probs[i].begin(), probs[i].end()); 254 | int maxIndex = std::distance(probs[i].begin(), maxElement); 255 | std::cout << " maxIndex " << std::setw(4) << maxIndex << " prob " << probs[i][maxIndex]; 256 | Token token = {offset + i, maxIndex, probs[i][maxIndex]}; 257 | alignments.emplace_back(token); 258 | std::cout << std::endl; 259 | } 260 | 261 | // find hypotheses with token index. just consider one path. 262 | // it's not update prob when meeting same token, now. 263 | std::unordered_set seenIds; 264 | for (const auto &token: alignments) { 265 | if (token.id != 0 && isKeyword(token.id)) { 266 | if (seenIds.count(token.id) == 0) { 267 | // not see token in current hyp. 268 | gd_cur_hyps.push_back(token); 269 | seenIds.insert(token.id); 270 | } //update prob 271 | } else { 272 | seenIds.clear(); 273 | } 274 | } 275 | alignments.clear(); 276 | } 277 | 278 | void KeywordSpotting::decode_ctc_prefix_beam_search(int stepT, const std::vector &probv) { 279 | /* Decoding ctc sequence with prefix beam search. 280 | * ref: https://distill.pub/2017/ctc/ 281 | * python implement 282 | * https://modelscope.cn/studios/thuduj12/KWS_Nihao_Xiaojing/file/view/master/stream_kws_ctc.py 283 | * https://robin1001.github.io/2020/12/11/ctc-search 284 | * */ 285 | 286 | // std::cout << "stepT=" << std::setw(3) << stepT << std::endl; 287 | if (probv.size() == 0) return; 288 | std::unordered_map, PrefixScore, PrefixHash> next_hyps; 289 | 290 | // 1. First beam prune, only select topk candidates 291 | std::vector topk_probs; 292 | std::vector topk_index; 293 | TopK(probv, opts_.first_beam_size, &topk_probs, &topk_index); 294 | 295 | // filter prob score that is too small. 
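  // (Illustrative note.) Pruning happens in two stages: TopK above keeps the
  // first_beam_size most likely tokens of the frame, and the loop below
  // additionally requires a probability above 0.05 and -- once a keyword is
  // set -- membership in the keyword token set (plus blank). This keeps the
  // per-frame hypothesis expansion very small.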
296 | std::vector filter_probs; 297 | std::vector filter_index; 298 | for (int i = 0; i < opts_.first_beam_size; i++) { 299 | int idx = topk_index[i]; 300 | float prob = probv[idx]; 301 | 302 | if (!mkeyword_set.empty()) { 303 | if (prob > 0.05 && isKeyword(idx)) { 304 | filter_probs.push_back(prob); 305 | filter_index.push_back(idx); 306 | } 307 | } else { 308 | if (prob > 0.05) { 309 | filter_probs.push_back(prob); 310 | filter_index.push_back(idx); 311 | } 312 | } 313 | } 314 | 315 | // handle prefix beam search 316 | if (!filter_index.empty()) { 317 | for (int i = 0; i < filter_index.size(); i++) { 318 | int tokenId = filter_index[i]; // token index of vocab 319 | float ps = probv[tokenId]; // prob of token 320 | std::cout << "stepT=" << std::setw(3) << stepT << " tokenid=" << std::setw(4) << tokenId \ 321 | << " proposed i=" << i << " prob=" << std::setprecision(3) << ps << std::endl; 322 | for (const auto &it: cur_hyps_) {//Fixing bug that can't be wakeup in stream-mode" 323 | const std::vector &prefix = it.first; 324 | const PrefixScore &prefix_score = it.second; 325 | print_vector(prefix); 326 | if (tokenId == opts_.blank) { 327 | // handle ending with blank token. eg 你好问 + ε ->你好问 328 | PrefixScore &next_score = next_hyps[prefix]; 329 | next_score.s = next_score.s + prefix_score.s * ps + prefix_score.ns * ps; 330 | next_score.nodes = prefix_score.nodes; // keep the nodes 331 | 332 | } else if (!prefix.empty() && tokenId == prefix.back()) { 333 | if (!(std::abs(prefix_score.ns - 0.0) <= 1e-6)) { 334 | // 处理: 你好-好->你好 . 消除alignment中两个blank之间的重复token. 335 | PrefixScore &next_score1 = next_hyps[prefix]; 336 | // update prob of same token. 337 | std::vector next_nodes(prefix_score.nodes); // copy current nodes 338 | if (!next_nodes.empty() && ps > next_nodes.back().prob) { 339 | next_nodes.back().prob = ps; 340 | next_nodes.back().timeStep = stepT; 341 | } 342 | next_score1.ns = next_score1.ns + prefix_score.ns * ps; 343 | next_score1.nodes = next_nodes; 344 | } 345 | if (!(std::abs(prefix_score.s - 0.0) <= 1e-6)) { 346 | // 处理: 你好-好->你好好 . 保留输出序列中的重复字符. 347 | std::vector next_prefix(prefix); 348 | next_prefix.push_back(tokenId); 349 | PrefixScore &next_score2 = next_hyps[next_prefix]; 350 | next_score2.ns = next_score2.ns + prefix_score.s * ps; 351 | Token curToken = {stepT, tokenId, ps}; 352 | // update nodes from current nodes 353 | std::vector next_nodes(prefix_score.nodes); // copy current nodes 354 | next_nodes.push_back(curToken); 355 | next_score2.nodes = next_nodes; 356 | } 357 | 358 | } else { 359 | //std::cout << "##Not see Token" << std::endl; 360 | 361 | std::vector next_prefix(prefix); 362 | next_prefix.push_back(tokenId); 363 | PrefixScore &next_score3 = next_hyps[next_prefix]; 364 | 365 | if (!next_score3.nodes.empty()) { 366 | // update prob of same token 367 | if (ps > next_score3.nodes.back().prob) { 368 | next_score3.nodes.pop_back(); 369 | Token curToken = {stepT, tokenId, ps}; 370 | next_score3.nodes.push_back(curToken); 371 | next_score3.ns = prefix_score.ns; 372 | next_score3.s = prefix_score.s; 373 | } 374 | } else { 375 | std::vector next_nodes(prefix_score.nodes); // copy current nodes 376 | Token curToken = {stepT, tokenId, ps}; 377 | next_nodes.push_back(curToken); 378 | next_score3.nodes = next_nodes; 379 | next_score3.ns = next_score3.ns + prefix_score.s * ps + prefix_score.ns * ps; 380 | } 381 | } 382 | } 383 | } 384 | 385 | // 3 second beam prune. 
keep topK 386 | std::vector, PrefixScore>> arr(next_hyps.begin(), 387 | next_hyps.end()); 388 | int second_beam_size = 389 | std::min(static_cast(arr.size()), opts_.second_beam_size); 390 | 391 | std::nth_element(arr.begin(), arr.begin() + second_beam_size, arr.end(), 392 | PrefixScoreCompare); 393 | arr.resize(second_beam_size); 394 | std::sort(arr.begin(), arr.end(), PrefixScoreCompare); 395 | 396 | // update 397 | UpdateHypotheses(arr); 398 | } 399 | 400 | } 401 | 402 | void KeywordSpotting::execute_detection(float hitScoreThr) { 403 | /* 对当前输出的prfix串和关键词进行对比,判断是否唤醒. 404 | * */ 405 | 406 | if (mdecode_type == wekws::DECODE_PREFIX_BEAM_SEARCH) { 407 | for (const auto &it: cur_hyps_) { 408 | const std::vector &prefix = it.first; 409 | const std::vector &nodes = it.second.nodes; 410 | if (!prefix.empty() && prefix.size() == mkeyword_token.size()) { 411 | int num = 0; 412 | for (auto i = 0; i < prefix.size(); i++) { 413 | num += (prefix[i] != mkeyword_token[i]) ? 0 : 1; 414 | kwsInfo.hit_score *= nodes[i].prob; 415 | if (i == 0) kwsInfo.start_frame = nodes[i].timeStep; 416 | if (i == nodes.size() - 1) kwsInfo.end_frame= nodes[i].timeStep; 417 | } 418 | activated = (num==mkeyword_token.size()) ? true : false; 419 | } 420 | kwsInfo.hit_score = std::sqrt(kwsInfo.hit_score); 421 | activated = (kwsInfo.hit_score > hitScoreThr) ? activated : false; 422 | kwsInfo.state = activated; 423 | if (activated == true) { 424 | 425 | std::cout << "keyword=" << mkey_word 426 | << " hitscore=" << kwsInfo.hit_score << " hitScoreThr=" << hitScoreThr 427 | << " start T=" << kwsInfo.start_frame 428 | << " end T=" << kwsInfo.end_frame << std::endl; 429 | //print_vector(prefix); 430 | break; 431 | } 432 | } 433 | } else { 434 | // std::cout << "cur_hyps size: " << cur_hyps.size() << " kws size: " << this->mkws_ids.size() < 21 | #include 22 | #include 23 | #include 24 | #include 25 | 26 | #include "onnxruntime_cxx_api.h" // NOLINT 27 | #include "kws/utils.h" 28 | 29 | namespace wekws { 30 | 31 | struct Token { 32 | int timeStep; // token time step 33 | int id; // token id of vocab 34 | float prob; // token prob 35 | }; 36 | 37 | struct CtcPrefixBeamSearchOptions { 38 | int blank = 0; // blank id of vocab list. 39 | int first_beam_size = 3; 40 | int second_beam_size = 3; 41 | }; 42 | 43 | struct PrefixHash { 44 | size_t operator()(const std::vector& prefix) const { 45 | size_t hash_code = 0; 46 | // here we use KB&DR hash code 47 | for (int id : prefix) { 48 | hash_code = id + 31 * hash_code; 49 | } 50 | return hash_code; 51 | } 52 | }; 53 | 54 | struct PrefixScore { // for one prefix. 55 | float s = 0.0; // blank ending score 56 | float ns = 0.0; // none blank ending score 57 | std::vector nodes; 58 | float total_score() const { return ns + s; } 59 | }; 60 | 61 | struct KeyWord { // for keyword. 62 | float hit_score = 1.0; 63 | int start_frame = 0; 64 | int end_frame = 0; 65 | bool state = false; // is activated or not. 66 | }; 67 | 68 | // Define decoding type. 
69 | typedef enum { 70 | DECODE_GREEDY_SEARCH=0, 71 | DECODE_PREFIX_BEAM_SEARCH=1, 72 | }DECODE_TYPE; 73 | 74 | class KeywordSpotting { 75 | public: 76 | explicit KeywordSpotting(const std::string &model_path, DECODE_TYPE decode_type, int model_type); 77 | 78 | void Reset(); 79 | 80 | void reset_value(); 81 | 82 | static void InitEngineThreads(int num_threads) { 83 | session_options_.SetIntraOpNumThreads(num_threads); 84 | session_options_.SetInterOpNumThreads(num_threads); 85 | } 86 | 87 | void Forward(const std::vector> &feats, 88 | std::vector> *prob); 89 | 90 | // function to load vocab from token.txt 91 | void readToken(const std::string& tokenFile) ; 92 | 93 | // set keyword 94 | void setKeyWord(const std::string& keyWord); 95 | 96 | void decode_keywords(std::vector>& probs, float hitScoreThr=0.0); 97 | 98 | // decoding alignments to predict sequence using greedy search. 99 | void decode_with_greedy_search(int offset, std::vector>& probs); 100 | 101 | // decoding alignments to predict sequence using prefix beam search. 102 | void decode_ctc_prefix_beam_search(int offset, const std::vector &prob); 103 | 104 | // find keyword 105 | void execute_detection(float hitScoreThr=0.1); 106 | 107 | // Token is keyword or not. 108 | bool isKeyword(int index); 109 | 110 | //update current hypotheses from proposed extensions. 111 | void UpdateHypotheses(const std::vector, PrefixScore>>& hpys); 112 | 113 | // maxpooling keywords 114 | std::vector mmaxpooling_keywords; 115 | 116 | //time stemp reset 117 | void stepClear(); 118 | 119 | //keyword info; 120 | KeyWord kwsInfo; 121 | 122 | 123 | private: 124 | // onnx runtime session 125 | static Ort::Env env_; 126 | static Ort::SessionOptions session_options_; 127 | std::shared_ptr session_ = nullptr; 128 | 129 | // model node names 130 | std::vector in_names_; 131 | std::vector out_names_; 132 | 133 | // meta info 134 | int cache_dim_ = 0; 135 | int cache_len_ = 0; 136 | int cache_4_ = 4; 137 | 138 | // cache info 139 | Ort::Value cache_ort_{nullptr}; 140 | std::vector cache_; 141 | 142 | // set mdoel type. 143 | int mmodel_type; 144 | 145 | //set decoder type. 146 | DECODE_TYPE mdecode_type; 147 | 148 | // vocab {token:index} 149 | std::unordered_map mvocab; 150 | // keyword string 151 | std::string mkey_word ; 152 | // keyword index set 153 | std::unordered_set mkeyword_set; 154 | // keyword index list. handle same token in keyword. 155 | std::vector mkeyword_token; 156 | 157 | // CTC alignments. 158 | std::vector alignments; 159 | // set of hypotheses with greed search. 160 | std::vector gd_cur_hyps; 161 | // set of hypotheses with prefix beam search. 162 | std::unordered_map, PrefixScore, PrefixHash> cur_hyps_; 163 | int total_frames=0;// frame offset, for absolute time 164 | 165 | //ctc prefix beam search 166 | const CtcPrefixBeamSearchOptions opts_={0, 3, 10}; 167 | 168 | // silence time, 1s audio = 99frames melFbank. with default frame_shift(10ms) and frame_length(25ms). 169 | // Now we set silenceFrames = 3s * 99 = 297. 170 | //int mSilenceFrames = 297; 171 | 172 | //global Time step. 
173 | int mGTimeStep = 0;
174 |
175 | bool activated = false;
176 | };
177 |
178 |
179 | } // namespace wekws
180 |
181 | #endif  // KWS_KEYWORD_SPOTTING_H_
182 |
--------------------------------------------------------------------------------
/onnxruntime/kws/maxpooling_keyword.txt:
--------------------------------------------------------------------------------
1 | 嗨小问
2 | 你好问问
--------------------------------------------------------------------------------
/onnxruntime/kws/utils.cpp:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2024 Yang Chen (cyang8050@163.com)
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | //   http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | // reference code from:
16 | // https://github.com/wenet-e2e/wenet/blob/main/runtime/core/utils/utils.cc
17 | //
18 |
19 | #include "kws/utils.h"
20 |
21 |
22 | namespace wekws {
23 |
24 | template <typename T>
25 | struct ValueComp {
26 |   bool operator()(const std::pair<T, int32_t>& lhs,
27 |                   const std::pair<T, int32_t>& rhs) const {
28 |     return lhs.first > rhs.first ||
29 |            (lhs.first == rhs.first && lhs.second < rhs.second);
30 |   }
31 | };
32 |
33 | // We refer to the pytorch topk implementation
34 | // https://github.com/pytorch/pytorch/blob/master/caffe2/operators/top_k.cc
35 | template <typename T>
36 | void TopK(const std::vector<T>& data, int32_t k, std::vector<T>* values,
37 |           std::vector<int32_t>* indices) {
38 |   std::vector<std::pair<T, int32_t>> heap_data;
39 |   int n = data.size();
40 |   for (int32_t i = 0; i < k && i < n; ++i) {
41 |     heap_data.emplace_back(data[i], i);
42 |   }
43 |   std::priority_queue<std::pair<T, int32_t>,
44 |                       std::vector<std::pair<T, int32_t>>, ValueComp<T>>
45 |       pq(ValueComp<T>(), std::move(heap_data));
46 |   for (int32_t i = k; i < n; ++i) {
47 |     if (pq.top().first < data[i]) {
48 |       pq.pop();
49 |       pq.emplace(data[i], i);
50 |     }
51 |   }
52 |
53 |   values->resize(std::min(k, n));
54 |   indices->resize(std::min(k, n));
55 |   int32_t cur = values->size() - 1;
56 |   while (!pq.empty()) {
57 |     const auto& item = pq.top();
58 |     (*values)[cur] = item.first;
59 |     (*indices)[cur] = item.second;
60 |     pq.pop();
61 |     cur -= 1;
62 |   }
63 | }
64 |
65 | template void TopK<float>(const std::vector<float>& data, int32_t k,
66 |                           std::vector<float>* values,
67 |                           std::vector<int32_t>* indices);
68 |
69 |
70 | // Read a 16-bit PCM audio file into a vector of float samples.
71 | void read_pcm(const std::string& file_path, std::vector<float>& pcm_float) {
72 |   std::ifstream pcm_file(file_path, std::ios::binary | std::ios::ate);
73 |
74 |   if (!pcm_file.is_open()) {
75 |     throw std::runtime_error("Failed to open file:" + file_path);
76 |   }
77 |
78 |   // Get the file size.
79 |   std::streampos file_size = pcm_file.tellg();
80 |   pcm_file.seekg(0, std::ios::beg);
81 |
82 |   // Read the raw PCM bytes.
83 |   std::vector<char> pcm_data(file_size);
84 |   pcm_file.read(pcm_data.data(), file_size);
85 |   pcm_file.close();
86 |
87 |   // Convert the 16-bit PCM samples to float.
88 |   const int16_t* pcm_data_ptr = reinterpret_cast<const int16_t*>(pcm_data.data());
89 |   int sample_count = file_size / sizeof(int16_t);
90 |
91 |   for (int i = 0; i < sample_count; ++i) {
92 |     pcm_float.push_back(static_cast<float>(pcm_data_ptr[i]));
93 |   }
94 | }
95 |
96 | // Recursively collect the .wav files in a directory and its subdirectories.
97 | void process_directory(const boost::filesystem::path& dirpath,
98 |                        std::vector<std::string>& wavePaths) {
99 |   for (boost::filesystem::directory_iterator it(dirpath);
100 |        it != boost::filesystem::directory_iterator(); ++it) {
101 |     const boost::filesystem::path& path = it->path();
102 |     if (boost::filesystem::is_regular_file(path) && path.extension() == ".wav") {
103 |       // Record the WAV file path.
104 |       wavePaths.push_back(path.string());
105 |     } else if (boost::filesystem::is_directory(path)) {
106 |       // Recurse into the subdirectory.
107 |       process_directory(path, wavePaths);
108 |     }
109 |   }
110 | }
111 |
112 | void writeVectorToFile(const std::vector<std::string>& data,
113 |                        const std::string& filename) {
114 |   std::ofstream file(filename);
115 |
116 |   if (file.is_open()) {
117 |     for (const auto& line : data) {
118 |       file << line << std::endl;
119 |     }
120 |     file.close();
121 |   } else {
122 |     throw std::runtime_error("Failed to open file:" + filename);
123 |   }
124 | }
125 |
126 | }  // namespace wekws
--------------------------------------------------------------------------------
/onnxruntime/kws/utils.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
2 | // Copyright (c) 2024 Yang Chen (cyang8050@163.com)
3 | //
4 | // Licensed under the Apache License, Version 2.0 (the "License");
5 | // you may not use this file except in compliance with the License.
6 | // You may obtain a copy of the License at
7 | //
8 | //   http://www.apache.org/licenses/LICENSE-2.0
9 | //
10 | // Unless required by applicable law or agreed to in writing, software
11 | // distributed under the License is distributed on an "AS IS" BASIS,
12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | // See the License for the specific language governing permissions and
14 | // limitations under the License.
15 |
16 | #ifndef UTILS_UTILS_H_
17 | #define UTILS_UTILS_H_
18 |
19 | #include <algorithm>
20 | #include <cstdint>
21 | #include <fstream>
22 | #include <queue>
23 | #include <stdexcept>
24 | #include <string>
25 | #include <utility>
26 | #include <vector>
27 | #include <boost/filesystem.hpp>  // apt-get install libboost-all-dev
28 |
29 | namespace wekws {
30 |
31 | template <typename T>
32 | void TopK(const std::vector<T>& data, int32_t k, std::vector<T>* values,
33 |           std::vector<int32_t>* indices);
34 |
35 | void read_pcm(const std::string& file_path, std::vector<float>& pcm_float);
36 |
37 | void process_directory(const boost::filesystem::path& dirpath,
38 |                        std::vector<std::string>& wavePaths);
39 |
40 | void writeVectorToFile(const std::vector<std::string>& data,
41 |                        const std::string& filename);
42 | }  // namespace wekws
43 |
44 | #endif  // UTILS_UTILS_H_
45 |
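The helpers above are easy to exercise from a small driver. Below is a minimal, hypothetical sketch (the score values are made up, and the PCM path points at one of the repository's sample files); it assumes the include paths used by the onnxruntime CMake targets:

```cpp
#include <iostream>
#include <vector>

#include "kws/utils.h"

int main() {
  // Take the top-2 entries of a posterior-like score vector.
  std::vector<float> scores = {0.10f, 0.70f, 0.05f, 0.15f};
  std::vector<float> values;
  std::vector<int32_t> indices;
  wekws::TopK(scores, 2, &values, &indices);
  std::cout << "top1: score=" << values[0] << " index=" << indices[0] << std::endl;

  // Load a raw 16-bit PCM file as float samples.
  std::vector<float> pcm;
  wekws::read_pcm("audio/000af5671fdbaa3e55c5e2bd0bdf8cdd_hi.pcm", pcm);
  std::cout << "loaded " << pcm.size() << " samples" << std::endl;
  return 0;
}
```

Note that `TopK` is only explicitly instantiated for `float` in utils.cpp, so a call with another element type would need an additional instantiation there.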
--------------------------------------------------------------------------------
/onnxruntime/utils/blocking_queue.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | //   http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | #ifndef UTILS_BLOCKING_QUEUE_H_
16 | #define UTILS_BLOCKING_QUEUE_H_
17 |
18 | #include <condition_variable>
19 | #include <limits>
20 | #include <mutex>
21 | #include <queue>
22 | #include <utility>
23 |
24 | namespace wenet {
25 |
26 | #define WENET_DISALLOW_COPY_AND_ASSIGN(Type) \
27 |   Type(const Type&) = delete;               \
28 |   Type& operator=(const Type&) = delete;
29 |
30 | template <typename T>
31 | class BlockingQueue {
32 |   /* A thread-safe blocking queue: Push() blocks while the queue is full and
33 |    * Pop() blocks while it is empty, using condition variables to suspend and
34 |    * wake threads. This supports the producer-consumer pattern, where producer
35 |    * threads insert data and consumer threads remove it, each side waiting
36 |    * until the other frees a slot or supplies an element. Implemented with the
37 |    * C++11 std::mutex, std::condition_variable and std::queue primitives.
38 |    */
39 |  public:
40 |   explicit BlockingQueue(size_t capacity = std::numeric_limits<size_t>::max())
41 |       : capacity_(capacity) {}
42 |
43 |   void Push(const T& value) {
44 |     {
45 |       std::unique_lock<std::mutex> lock(mutex_);
46 |       while (queue_.size() >= capacity_) {
47 |         not_full_condition_.wait(lock);
48 |       }
49 |       queue_.push(value);
50 |     }
51 |     not_empty_condition_.notify_one();
52 |   }
53 |
54 |   void Push(T&& value) {
55 |     {
56 |       std::unique_lock<std::mutex> lock(mutex_);
57 |       while (queue_.size() >= capacity_) {
58 |         not_full_condition_.wait(lock);
59 |       }
60 |       queue_.push(std::move(value));
61 |     }
62 |     not_empty_condition_.notify_one();
63 |   }
64 |
65 |   T Pop() {
66 |     std::unique_lock<std::mutex> lock(mutex_);
67 |     while (queue_.empty()) {
68 |       not_empty_condition_.wait(lock);
69 |     }
70 |     T t(std::move(queue_.front()));
71 |     queue_.pop();
72 |     not_full_condition_.notify_one();
73 |     return t;
74 |   }
75 |
76 |   bool Empty() const {
77 |     std::lock_guard<std::mutex> lock(mutex_);
78 |     return queue_.empty();
79 |   }
80 |
81 |   size_t Size() const {
82 |     std::lock_guard<std::mutex> lock(mutex_);
83 |     return queue_.size();
84 |   }
85 |
86 |   void Clear() {
87 |     while (!Empty()) {
88 |       Pop();
89 |     }
90 |   }
91 |
92 |  private:
93 |   size_t capacity_;
94 |   mutable std::mutex mutex_;
95 |   std::condition_variable not_full_condition_;
96 |   std::condition_variable not_empty_condition_;
97 |   std::queue<T> queue_;
98 |
99 |  public:
100 |   WENET_DISALLOW_COPY_AND_ASSIGN(BlockingQueue);
101 | };
102 |
103 | }  // namespace wenet
104 |
105 | #endif  // UTILS_BLOCKING_QUEUE_H_
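As the class comment notes, the queue is built for producer-consumer use (likely how the streaming binaries hand data between threads, though that is not shown here). A minimal, hypothetical sketch of the pattern, where the small capacity and the `-1` sentinel are purely illustrative conventions:

```cpp
#include <iostream>
#include <thread>

#include "utils/blocking_queue.h"

int main() {
  // Small capacity so the producer actually blocks when the consumer lags.
  wenet::BlockingQueue<int> queue(2);

  std::thread producer([&queue]() {
    for (int i = 0; i < 5; ++i) {
      queue.Push(i);  // blocks while the queue is full
    }
    queue.Push(-1);   // sentinel: tells the consumer to stop
  });

  std::thread consumer([&queue]() {
    while (true) {
      int v = queue.Pop();  // blocks while the queue is empty
      if (v == -1) break;
      std::cout << "consumed " << v << std::endl;
    }
  });

  producer.join();
  consumer.join();
  return 0;
}
```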
--------------------------------------------------------------------------------
/onnxruntime/utils/log.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2022 Binbin Zhang (binbzha@qq.com)
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | //   http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 |
16 | #ifndef UTILS_LOG_H_
17 | #define UTILS_LOG_H_
18 |
19 | #include <stdlib.h>
20 |
21 | #include <iostream>
22 | #include <sstream>
23 |
24 | namespace wenet {
25 |
26 | const int INFO = 0, WARNING = 1, ERROR = 2, FATAL = 3;
27 |
28 | class Logger {
29 |  public:
30 |   Logger(int severity, const char* func, const char* file, int line) {
31 |     severity_ = severity;
32 |     switch (severity) {
33 |       case INFO:
34 |         ss_ << "INFO (";
35 |         break;
36 |       case WARNING:
37 |         ss_ << "WARNING (";
38 |         break;
39 |       case ERROR:
40 |         ss_ << "ERROR (";
41 |         break;
42 |       case FATAL:
43 |         ss_ << "FATAL (";
44 |         break;
45 |       default:
46 |         severity_ = FATAL;
47 |         ss_ << "FATAL (";
48 |     }
49 |     ss_ << func << "():" << file << ':' << line << ") ";
50 |   }
51 |
52 |   ~Logger() {
53 |     std::cerr << ss_.str() << std::endl << std::flush;
54 |     if (severity_ == FATAL) {
55 |       abort();
56 |     }
57 |   }
58 |
59 |   template <typename T>
60 |   Logger& operator<<(const T& val) {
61 |     ss_ << val;
62 |     return *this;
63 |   }
64 |
65 |  private:
66 |   int severity_;
67 |   std::ostringstream ss_;
68 | };
69 |
70 | #define LOG(severity) ::wenet::Logger( \
71 |     ::wenet::severity, __func__, __FILE__, __LINE__)
72 |
73 | #define CHECK(test)                                                   \
74 |   do {                                                                \
75 |     if (!(test)) {                                                    \
76 |       std::cerr << "CHECK (" << __func__ << "():" << __FILE__ << ":"  \
77 |                 << __LINE__ << ") " << #test << std::endl;            \
78 |       exit(-1);                                                       \
79 |     }                                                                 \
80 |   } while (0)
81 |
82 | }  // namespace wenet
83 |
84 | #endif  // UTILS_LOG_H_
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | pyaml
3 | onnxruntime==1.12.1
--------------------------------------------------------------------------------
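For completeness, the `LOG`/`CHECK` macros in log.h are used glog-style; a minimal, hypothetical sketch:

```cpp
#include "utils/log.h"

int main(int argc, char* argv[]) {
  // Logger streams like std::cerr and flushes in its destructor.
  LOG(INFO) << "got " << (argc - 1) << " command-line argument(s)";
  // A failed CHECK prints the expression and its call site, then exits.
  CHECK(argc >= 1);
  if (argc < 2) {
    LOG(WARNING) << "no input wav given";
  }
  // LOG(FATAL) would abort() after printing its message.
  return 0;
}
```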