├── .gitignore ├── LICENSE ├── README.md ├── README_zh.md ├── egs └── aishell1 │ └── s5 │ ├── avg.sh │ ├── base_decode.log │ ├── base_train.log │ ├── config_base.yaml │ ├── config_lm_lstm.yaml │ ├── config_lst.yaml │ ├── decode_test.sh │ ├── local │ ├── aishell_data_prep.sh │ └── download_and_untar.sh │ ├── path.sh │ ├── prep_data.sh │ ├── run.sh │ ├── score.sh │ └── train.sh ├── figs ├── dec_enc_att.png ├── enc_att.png └── loss.png ├── src ├── avg_last_ckpts.py ├── data.py ├── data_test.py ├── decode.py ├── decoder_layers.py ├── encoder_layers.py ├── encoder_layers_test.py ├── lm_layers.py ├── lm_train.py ├── metric.py ├── models.py ├── modules.py ├── prepare_data.py ├── schedule.py ├── sp_layers.py ├── sp_layers_test.py ├── stat_grapheme.py ├── stat_length.py ├── testdata │ ├── 100-121669-0000.flac │ ├── 100-121669-0000.wav │ ├── 103-1240-0005.flac │ ├── BAC009S0764W0121.wav │ ├── lexicon.txt │ ├── test.json │ ├── tokens.txt │ └── train_chars.txt ├── third_party │ ├── kaldi_io.py │ ├── kaldi_signal.py │ ├── transformer.py │ └── wavfile.py ├── train.py ├── trainer.py ├── utils.py └── utils_test.py └── tools ├── combine_data.sh ├── filter_scp.pl ├── int2sym.pl ├── parse_options.sh ├── run.pl ├── sclite ├── spk2utt_to_utt2spk.pl └── utt2spk_to_spk2utt.pl /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 | 
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 | 
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 | 
176 | END OF TERMS AND CONDITIONS
177 | 
178 | APPENDIX: How to apply the Apache License to your work.
179 | 
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 | 
189 | Copyright [yyyy] [name of copyright owner]
190 | 
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 | 
195 | http://www.apache.org/licenses/LICENSE-2.0
196 | 
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # OpenASR
2 | 
3 | A PyTorch-based end-to-end speech recognition system. The main architecture is [Speech-Transformer](https://ieeexplore.ieee.org/abstract/document/8462506/).
4 | 
5 | [中文说明](https://github.com/by2101/OpenASR/blob/master/README_zh.md)
6 | 
7 | ## Features
8 | 
9 | 1. **Minimal Dependency**. The system does not depend on external software for feature extraction or decoding.
Users only need to install the PyTorch deep learning framework.
10 | 2. **Good Performance**. The system includes advanced algorithms, such as Label Smoothing, SpecAug, and LST, and achieves good performance on AISHELL-1. The baseline CER on the AISHELL-1 test set is 6.6%, which is better than ESPNet.
11 | 3. **Modular Design**. We divided the system into several modules, such as trainer, metric, schedule, and models, which makes it easy to extend the system and add features.
12 | 4. **End2End**. Feature extraction and tokenization are performed online. The system directly processes wave files, so the procedure is much simplified.
13 | 
14 | ## Dependency
15 | * python >= 3.6
16 | * pytorch >= 1.1
17 | * pyyaml >= 5.1
18 | * tensorflow and tensorboardX for visualization (if you do not need to visualize the results, you can set TENSORBOARD_LOGGING to 0 in src/utils.py)
19 | 
20 | ## Usage
21 | We use KALDI-style example organization. Each example directory includes top-level shell scripts, a data directory, and an exp directory. We provide an AISHELL-1 example. The path is ROOT/egs/aishell1/s5.
22 | 
23 | ### Data Preparation
24 | The data preparation script is prep_data.sh. It automatically downloads the AISHELL-1 dataset and formats it into a KALDI-style data directory. Then, it generates json files and the grapheme vocabulary. You can set `corpusdir` to choose where the dataset is stored.
25 | 
26 |     bash prep_data.sh
27 | 
28 | It then generates the data directory and the exp directory.
29 | 
30 | ### Train Models
31 | We use yaml files for parameter configuration. We provide 3 examples.
32 | 
33 |     config_base.yaml # baseline ASR system
34 |     config_lm_lstm.yaml # LSTM language model
35 |     config_lst.yaml # training ASR with LST
36 | 
37 | Run the train.sh script to train the baseline system.
38 | 
39 |     bash train.sh
40 | 
41 | ### Model Averaging
42 | Average the last checkpoints to improve performance.
43 | 
44 |     bash avg.sh
45 | 
46 | ### Decoding and Scoring
47 | Run the decode_test.sh script to decode the test set, then run score.sh to compute the CER.
48 | 
49 |     bash decode_test.sh
50 |     bash score.sh data/test/text exp/base/decode_test_avg-last10
51 | 
52 | ## Visualization
53 | We provide TensorboardX-based visualization. The event files are stored in $expdir/log. You can use tensorboard to visualize the training procedure.
54 | 
55 |     tensorboard --logdir=$expdir --bind_all
56 | 
57 | Then you can watch the training procedure in a browser (http://localhost:6006).
58 | 
59 | Examples:
60 | 
61 | ![per token loss in batch](https://github.com/by2101/OpenASR/raw/master/figs/loss.png)
62 | 
63 | ![encoder attention](https://github.com/by2101/OpenASR/raw/master/figs/enc_att.png)
64 | 
65 | ![encoder-decoder attention](https://github.com/by2101/OpenASR/raw/master/figs/dec_enc_att.png)
66 | 
67 | 
68 | ## Acknowledgement
69 | This system is implemented with PyTorch. We use the wave-reading code from SciPy, and the SCTK software for scoring. Thanks to Dan Povey's team for their KALDI software; the ASR concepts and the example organization are learned from KALDI. Thanks also to the Google Lingvo team; the modular design is learned from Lingvo.
70 | 
71 | ## Bib
72 |     @article{bai2019learn,
73 |         title={Learn Spelling from Teachers: Transferring Knowledge from Language Models to Sequence-to-Sequence Speech Recognition},
74 |         author={Bai, Ye and Yi, Jiangyan and Tao, Jianhua and Tian, Zhengkun and Wen, Zhengqi},
75 |         year={2019}
76 |     }
77 | 
78 | ## References
79 | Dong, Linhao, Shuang Xu, and Bo Xu. "Speech-transformer: a no-recurrence sequence-to-sequence model for speech recognition." 2018 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP). IEEE, 2018.
80 | Zhou, Shiyu, et al. "Syllable-based sequence-to-sequence speech recognition with the transformer in mandarin chinese." arXiv preprint arXiv:1804.10752 (2018).
81 | 
--------------------------------------------------------------------------------
/README_zh.md:
--------------------------------------------------------------------------------
1 | # OpenASR
2 | 
3 | 基于PyTorch的端到端语音识别系统。主要结构使用 [Speech-Transformer](https://ieeexplore.ieee.org/abstract/document/8462506/).
4 | 
5 | [README](https://github.com/by2101/OpenASR/blob/master/README.md)
6 | 
7 | ## 主要特性
8 | 
9 | 1. **最小依赖**. 系统不依赖其它额外的软件来提取特征或是解码。 用户只需要安装PyTorch即可。
10 | 2. **性能优良**. 系统集成了多个算法,包括Label Smoothing, SpecAugmentation, LST 等。在AISHELL-1数据集上,基线系统CER为6.6%,好于ESPNet。
11 | 3. **模块化设计**. 系统分为trainer, metric, schedule等模块,方便进一步扩展。
12 | 4. **端到端实现**. 特征提取和token划分采用在线实现。系统可以直接处理wav文件,整个流程大大简化。
13 | 
14 | ## 依赖
15 | python >= 3.6
16 | pytorch >= 1.1
17 | pyyaml >= 5.1
18 | tensorflow 和 tensorboardX (如果不需要可视化,可以将src/utils.py中TENSORBOARD_LOGGING变量设为0)
19 | 
20 | ## 使用方法
21 | 我们采用KALDI风格的例子。例子的目录包括一些高层脚本,data目录和exp目录。我们提供了一个AISHELL-1的例子,位于ROOT/egs/aishell1/s5.
22 | 
23 | ### 数据准备
24 | 数据准备的脚本是prep_data.sh。它会自动地下载AISHELL-1数据集,并将数据整理成KALDI风格的data目录。然后,它会生成json数据文件和字表。你可以设置`corpusdir` 来改变存储数据的目录。
25 | 
26 |     bash prep_data.sh
27 | 
28 | 
29 | ### 训练模型
30 | 我们采用yaml文件来配置参数。我们提供3个例子。
31 | 
32 |     config_base.yaml # 基线 ASR
33 |     config_lm_lstm.yaml # LSTM 语言模型
34 |     config_lst.yaml # 采用LST训练的ASR
35 | 
36 | 运行 train.sh 脚本训练基线系统。
37 | 
38 |     bash train.sh
39 | 
40 | ### 模型平均
41 | 我们采用模型平均来提高性能。
42 | 
43 |     bash avg.sh
44 | 
45 | ### 解码和打分
46 | 运行 decode_test.sh 解码测试集。然后运行score.sh计算CER。
47 | 
48 |     bash decode_test.sh
49 |     bash score.sh data/test/text exp/base/decode_test_avg-last10
50 | 
51 | ## 可视化
52 | 我们提供基于TensorboardX的可视化。event文件保存在$expdir/log。你可以通过tensorboard来观察训练过程。
53 | 
54 |     tensorboard --logdir=$expdir --bind_all
55 | 
56 | 然后你就可以在浏览器中观察 (http://localhost:6006).
57 | 
58 | 例子:
59 | ![per token loss in batch](https://github.com/by2101/OpenASR/raw/master/figs/loss.png)
60 | ![encoder attention](https://github.com/by2101/OpenASR/raw/master/figs/enc_att.png)
61 | ![encoder-decoder attention](https://github.com/by2101/OpenASR/raw/master/figs/dec_enc_att.png)
62 | 
63 | 
64 | ## 致谢
65 | 系统是基于PyTorch实现的。我们采用了SciPy里的读取wav文件的代码。我们使用了SCTK来计算CER。感谢Dan Povey团队和他们的KALDI,ASR的概念,例子的组织是从KALDI里学到的。感谢Google的Lingvo团队,模块化设计从Lingvo里学了很多。
66 | 
67 | 
68 | ## 引用
69 | Dong, Linhao, Shuang Xu, and Bo Xu. "Speech-transformer: a no-recurrence sequence-to-sequence model for speech recognition." 2018 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP). IEEE, 2018.
70 | Zhou, Shiyu, et al. "Syllable-based sequence-to-sequence speech recognition with the transformer in mandarin chinese." arXiv preprint arXiv:1804.10752 (2018).
71 | 
--------------------------------------------------------------------------------
/egs/aishell1/s5/avg.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | source path.sh
3 | expdir=exp/base
4 | 
5 | python $MAIN_ROOT/src/avg_last_ckpts.py \
6 |     $expdir \
7 |     10
8 | 
9 | 
10 | 
--------------------------------------------------------------------------------
/egs/aishell1/s5/base_decode.log:
--------------------------------------------------------------------------------
1 | 2020-03-30 21:50:35,127 - /home/by2101/OpenASR/egs/aishell1/s5/../../../src/decode.py[line:64] - INFO: Load package from exp/base/ep-0045.pt.
2 | 2020-03-30 21:50:36,707 - /home/by2101/OpenASR/egs/aishell1/s5/../../../src/decode.py[line:71] - INFO: 3 | Model info: 4 | Model( 5 | (splayer): SPLayer() 6 | (encoder): Transformer( 7 | (sub): Conv2dSubsampleV2( 8 | (conv): Sequential( 9 | (subsample/conv0): Conv2d(1, 32, kernel_size=(3, 3), stride=(2, 1)) 10 | (subsample/relu0): ReLU() 11 | (subsample/conv1): Conv2d(32, 32, kernel_size=(3, 3), stride=(2, 1)) 12 | (subsample/relu1): ReLU() 13 | ) 14 | (affine): Linear(in_features=2432, out_features=512, bias=True) 15 | ) 16 | (pe): PositionalEncoding() 17 | (dropout): Dropout(p=0.1) 18 | (transformer_encoder): TransformerEncoder( 19 | (layers): ModuleList( 20 | (0): TransformerEncoderLayer( 21 | (self_attn): MultiheadAttention( 22 | (out_proj): Linear(in_features=512, out_features=512, bias=True) 23 | ) 24 | (linear1): Linear(in_features=512, out_features=4096, bias=True) 25 | (dropout): Dropout(p=0.1) 26 | (linear2): Linear(in_features=2048, out_features=512, bias=True) 27 | (norm1): LayerNorm(torch.Size([512]), eps=1e-05, elementwise_affine=True) 28 | (norm2): LayerNorm(torch.Size([512]), eps=1e-05, elementwise_affine=True) 29 | (dropout1): Dropout(p=0.1) 30 | (dropout2): Dropout(p=0.1) 31 | ) 32 | (1): TransformerEncoderLayer( 33 | (self_attn): MultiheadAttention( 34 | (out_proj): Linear(in_features=512, out_features=512, bias=True) 35 | ) 36 | (linear1): Linear(in_features=512, out_features=4096, bias=True) 37 | (dropout): Dropout(p=0.1) 38 | (linear2): Linear(in_features=2048, out_features=512, bias=True) 39 | (norm1): LayerNorm(torch.Size([512]), eps=1e-05, elementwise_affine=True) 40 | (norm2): LayerNorm(torch.Size([512]), eps=1e-05, elementwise_affine=True) 41 | (dropout1): Dropout(p=0.1) 42 | (dropout2): Dropout(p=0.1) 43 | ) 44 | (2): TransformerEncoderLayer( 45 | (self_attn): MultiheadAttention( 46 | (out_proj): Linear(in_features=512, out_features=512, bias=True) 47 | ) 48 | (linear1): Linear(in_features=512, out_features=4096, bias=True) 49 | (dropout): Dropout(p=0.1) 50 | (linear2): Linear(in_features=2048, out_features=512, bias=True) 51 | (norm1): LayerNorm(torch.Size([512]), eps=1e-05, elementwise_affine=True) 52 | (norm2): LayerNorm(torch.Size([512]), eps=1e-05, elementwise_affine=True) 53 | (dropout1): Dropout(p=0.1) 54 | (dropout2): Dropout(p=0.1) 55 | ) 56 | (3): TransformerEncoderLayer( 57 | (self_attn): MultiheadAttention( 58 | (out_proj): Linear(in_features=512, out_features=512, bias=True) 59 | ) 60 | (linear1): Linear(in_features=512, out_features=4096, bias=True) 61 | (dropout): Dropout(p=0.1) 62 | (linear2): Linear(in_features=2048, out_features=512, bias=True) 63 | (norm1): LayerNorm(torch.Size([512]), eps=1e-05, elementwise_affine=True) 64 | (norm2): LayerNorm(torch.Size([512]), eps=1e-05, elementwise_affine=True) 65 | (dropout1): Dropout(p=0.1) 66 | (dropout2): Dropout(p=0.1) 67 | ) 68 | (4): TransformerEncoderLayer( 69 | (self_attn): MultiheadAttention( 70 | (out_proj): Linear(in_features=512, out_features=512, bias=True) 71 | ) 72 | (linear1): Linear(in_features=512, out_features=4096, bias=True) 73 | (dropout): Dropout(p=0.1) 74 | (linear2): Linear(in_features=2048, out_features=512, bias=True) 75 | (norm1): LayerNorm(torch.Size([512]), eps=1e-05, elementwise_affine=True) 76 | (norm2): LayerNorm(torch.Size([512]), eps=1e-05, elementwise_affine=True) 77 | (dropout1): Dropout(p=0.1) 78 | (dropout2): Dropout(p=0.1) 79 | ) 80 | (5): TransformerEncoderLayer( 81 | (self_attn): MultiheadAttention( 82 | (out_proj): Linear(in_features=512, out_features=512, 
bias=True) 83 | ) 84 | (linear1): Linear(in_features=512, out_features=4096, bias=True) 85 | (dropout): Dropout(p=0.1) 86 | (linear2): Linear(in_features=2048, out_features=512, bias=True) 87 | (norm1): LayerNorm(torch.Size([512]), eps=1e-05, elementwise_affine=True) 88 | (norm2): LayerNorm(torch.Size([512]), eps=1e-05, elementwise_affine=True) 89 | (dropout1): Dropout(p=0.1) 90 | (dropout2): Dropout(p=0.1) 91 | ) 92 | ) 93 | (norm): LayerNorm(torch.Size([512]), eps=1e-05, elementwise_affine=True) 94 | ) 95 | ) 96 | (decoder): TransformerDecoder( 97 | (emb): Embedding(4233, 512) 98 | (pe): PositionalEncoding() 99 | (dropout): Dropout(p=0.1) 100 | (transformer_block): TransformerDecoder( 101 | (layers): ModuleList( 102 | (0): TransformerDecoderLayer( 103 | (self_attn): MultiheadAttention( 104 | (out_proj): Linear(in_features=512, out_features=512, bias=True) 105 | ) 106 | (multihead_attn): MultiheadAttention( 107 | (out_proj): Linear(in_features=512, out_features=512, bias=True) 108 | ) 109 | (linear1): Linear(in_features=512, out_features=4096, bias=True) 110 | (dropout): Dropout(p=0.1) 111 | (linear2): Linear(in_features=2048, out_features=512, bias=True) 112 | (norm1): LayerNorm(torch.Size([512]), eps=1e-05, elementwise_affine=True) 113 | (norm2): LayerNorm(torch.Size([512]), eps=1e-05, elementwise_affine=True) 114 | (norm3): LayerNorm(torch.Size([512]), eps=1e-05, elementwise_affine=True) 115 | (dropout1): Dropout(p=0.1) 116 | (dropout2): Dropout(p=0.1) 117 | (dropout3): Dropout(p=0.1) 118 | ) 119 | (1): TransformerDecoderLayer( 120 | (self_attn): MultiheadAttention( 121 | (out_proj): Linear(in_features=512, out_features=512, bias=True) 122 | ) 123 | (multihead_attn): MultiheadAttention( 124 | (out_proj): Linear(in_features=512, out_features=512, bias=True) 125 | ) 126 | (linear1): Linear(in_features=512, out_features=4096, bias=True) 127 | (dropout): Dropout(p=0.1) 128 | (linear2): Linear(in_features=2048, out_features=512, bias=True) 129 | (norm1): LayerNorm(torch.Size([512]), eps=1e-05, elementwise_affine=True) 130 | (norm2): LayerNorm(torch.Size([512]), eps=1e-05, elementwise_affine=True) 131 | (norm3): LayerNorm(torch.Size([512]), eps=1e-05, elementwise_affine=True) 132 | (dropout1): Dropout(p=0.1) 133 | (dropout2): Dropout(p=0.1) 134 | (dropout3): Dropout(p=0.1) 135 | ) 136 | (2): TransformerDecoderLayer( 137 | (self_attn): MultiheadAttention( 138 | (out_proj): Linear(in_features=512, out_features=512, bias=True) 139 | ) 140 | (multihead_attn): MultiheadAttention( 141 | (out_proj): Linear(in_features=512, out_features=512, bias=True) 142 | ) 143 | (linear1): Linear(in_features=512, out_features=4096, bias=True) 144 | (dropout): Dropout(p=0.1) 145 | (linear2): Linear(in_features=2048, out_features=512, bias=True) 146 | (norm1): LayerNorm(torch.Size([512]), eps=1e-05, elementwise_affine=True) 147 | (norm2): LayerNorm(torch.Size([512]), eps=1e-05, elementwise_affine=True) 148 | (norm3): LayerNorm(torch.Size([512]), eps=1e-05, elementwise_affine=True) 149 | (dropout1): Dropout(p=0.1) 150 | (dropout2): Dropout(p=0.1) 151 | (dropout3): Dropout(p=0.1) 152 | ) 153 | (3): TransformerDecoderLayer( 154 | (self_attn): MultiheadAttention( 155 | (out_proj): Linear(in_features=512, out_features=512, bias=True) 156 | ) 157 | (multihead_attn): MultiheadAttention( 158 | (out_proj): Linear(in_features=512, out_features=512, bias=True) 159 | ) 160 | (linear1): Linear(in_features=512, out_features=4096, bias=True) 161 | (dropout): Dropout(p=0.1) 162 | (linear2): Linear(in_features=2048, 
out_features=512, bias=True) 163 | (norm1): LayerNorm(torch.Size([512]), eps=1e-05, elementwise_affine=True) 164 | (norm2): LayerNorm(torch.Size([512]), eps=1e-05, elementwise_affine=True) 165 | (norm3): LayerNorm(torch.Size([512]), eps=1e-05, elementwise_affine=True) 166 | (dropout1): Dropout(p=0.1) 167 | (dropout2): Dropout(p=0.1) 168 | (dropout3): Dropout(p=0.1) 169 | ) 170 | (4): TransformerDecoderLayer( 171 | (self_attn): MultiheadAttention( 172 | (out_proj): Linear(in_features=512, out_features=512, bias=True) 173 | ) 174 | (multihead_attn): MultiheadAttention( 175 | (out_proj): Linear(in_features=512, out_features=512, bias=True) 176 | ) 177 | (linear1): Linear(in_features=512, out_features=4096, bias=True) 178 | (dropout): Dropout(p=0.1) 179 | (linear2): Linear(in_features=2048, out_features=512, bias=True) 180 | (norm1): LayerNorm(torch.Size([512]), eps=1e-05, elementwise_affine=True) 181 | (norm2): LayerNorm(torch.Size([512]), eps=1e-05, elementwise_affine=True) 182 | (norm3): LayerNorm(torch.Size([512]), eps=1e-05, elementwise_affine=True) 183 | (dropout1): Dropout(p=0.1) 184 | (dropout2): Dropout(p=0.1) 185 | (dropout3): Dropout(p=0.1) 186 | ) 187 | (5): TransformerDecoderLayer( 188 | (self_attn): MultiheadAttention( 189 | (out_proj): Linear(in_features=512, out_features=512, bias=True) 190 | ) 191 | (multihead_attn): MultiheadAttention( 192 | (out_proj): Linear(in_features=512, out_features=512, bias=True) 193 | ) 194 | (linear1): Linear(in_features=512, out_features=4096, bias=True) 195 | (dropout): Dropout(p=0.1) 196 | (linear2): Linear(in_features=2048, out_features=512, bias=True) 197 | (norm1): LayerNorm(torch.Size([512]), eps=1e-05, elementwise_affine=True) 198 | (norm2): LayerNorm(torch.Size([512]), eps=1e-05, elementwise_affine=True) 199 | (norm3): LayerNorm(torch.Size([512]), eps=1e-05, elementwise_affine=True) 200 | (dropout1): Dropout(p=0.1) 201 | (dropout2): Dropout(p=0.1) 202 | (dropout3): Dropout(p=0.1) 203 | ) 204 | ) 205 | ) 206 | (output_affine): Linear(in_features=512, out_features=4233, bias=True) 207 | ) 208 | ) 209 | 2020-03-30 21:50:36,707 - /home/by2101/OpenASR/src/models.py[line:250] - INFO: Restore model states... 210 | 2020-03-30 21:50:48,003 - /home/by2101/OpenASR/egs/aishell1/s5/../../../src/decode.py[line:90] - INFO: Start feedforward... 
211 | 2020-03-30 21:50:49,489 - /home/by2101/OpenASR/egs/aishell1/s5/../../../src/decode.py[line:117] - INFO: 212 | Results for BAC009S0764W0121: 213 | top1: 甚 至 出 现 交 易 几 乎 停 止 的 情 况 score: -1.8161354065 214 | top2: 甚 至 出 现 交 易 几 乎 停 滞 的 情 况 score: -2.8520746231 215 | top3: 甚 至 出 现 交 易 几 乎 停 止 了 情 况 score: -3.8272104263 216 | top4: 甚 至 出 现 交 易 几 乎 调 整 的 情 况 score: -5.5626416206 217 | top5: 甚 至 出 现 交 易 几 乎 停 滞 了 情 况 score: -6.0889887810 218 | 219 | 2020-03-30 21:50:49,489 - /home/by2101/OpenASR/egs/aishell1/s5/../../../src/decode.py[line:119] - INFO: Prossesed 1 utterances in 1.485 s 220 | 2020-03-30 21:50:50,488 - /home/by2101/OpenASR/egs/aishell1/s5/../../../src/decode.py[line:117] - INFO: 221 | Results for BAC009S0764W0122: 222 | top1: 一 二 线 城 市 虽 然 也 处 于 调 整 中 score: -1.3934774399 223 | top2: 一 二 线 城 市 孙 杨 也 处 于 调 整 中 score: -3.4023485184 224 | top3: 一 二 线 城 市 孙 阳 也 处 于 调 整 中 score: -6.3469929695 225 | top4: 一 二 线 城 市 依 然 也 处 于 调 整 中 score: -6.8223943710 226 | top5: 一 二 线 城 市 虽 然 也 出 于 调 整 中 score: -7.0899319649 227 | 228 | 2020-03-30 21:50:50,488 - /home/by2101/OpenASR/egs/aishell1/s5/../../../src/decode.py[line:119] - INFO: Prossesed 2 utterances in 2.484 s 229 | 2020-03-30 21:50:51,489 - /home/by2101/OpenASR/egs/aishell1/s5/../../../src/decode.py[line:117] - INFO: 230 | Results for BAC009S0764W0123: 231 | top1: 但 因 为 聚 集 了 过 多 公 共 资 源 score: -1.5715856552 232 | top2: 但 因 为 聚 集 了 过 多 公 共 思 源 score: -3.6406021118 233 | top3: 但 因 为 聚 集 了 过 多 公 共 四 元 score: -4.1958231926 234 | top4: 但 因 为 聚 集 了 过 多 公 共 司 员 score: -4.7456755638 235 | top5: 但 因 为 拒 绝 了 过 多 公 共 资 源 score: -4.7964982986 236 | 237 | 2020-03-30 21:50:51,490 - /home/by2101/OpenASR/egs/aishell1/s5/../../../src/decode.py[line:119] - INFO: Prossesed 3 utterances in 3.486 s 238 | 2020-03-30 21:50:52,477 - /home/by2101/OpenASR/egs/aishell1/s5/../../../src/decode.py[line:117] - INFO: 239 | Results for BAC009S0764W0124: 240 | top1: 为 了 规 避 三 四 线 城 市 明 显 过 剩 的 市 场 风 险 score: -1.5112810135 241 | top2: 为 了 规 避 三 四 线 城 市 明 显 过 胜 的 市 场 风 险 score: -6.5478034019 242 | top3: 为 了 规 避 三 四 线 城 市 明 显 过 剩 的 市 场 风 现 score: -6.7623558044 243 | top4: 为 了 规 闭 三 四 线 城 市 明 显 过 剩 的 市 场 风 险 score: -6.7998075485 244 | top5: 为 了 规 避 三 次 线 城 市 明 显 过 剩 的 市 场 风 险 score: -6.9653940201 245 | 246 | 2020-03-30 21:50:52,478 - /home/by2101/OpenASR/egs/aishell1/s5/../../../src/decode.py[line:119] - INFO: Prossesed 4 utterances in 4.474 s 247 | 2020-03-30 21:50:53,670 - /home/by2101/OpenASR/egs/aishell1/s5/../../../src/decode.py[line:117] - INFO: 248 | Results for BAC009S0764W0125: 249 | top1: 标 杆 房 企 必 然 调 整 市 场 战 略 score: -1.2184476852 250 | top2: 标 杆 房 企 必 然 调 整 市 场 策 略 score: -5.4294738770 251 | top3: 标 杆 房 企 必 然 调 整 市 场 的 战 略 score: -5.6316819191 252 | top4: 标 准 房 企 必 然 调 整 市 场 战 略 score: -5.8972325325 253 | top5: 标 杆 房 企 必 然 调 整 市 场 大 跃 score: -7.0259995461 254 | 255 | 2020-03-30 21:50:53,671 - /home/by2101/OpenASR/egs/aishell1/s5/../../../src/decode.py[line:119] - INFO: Prossesed 5 utterances in 5.667 s 256 | 2020-03-30 21:50:54,698 - /home/by2101/OpenASR/egs/aishell1/s5/../../../src/decode.py[line:117] - INFO: 257 | Results for BAC009S0764W0126: 258 | top1: 因 此 土 地 储 备 至 关 重 要 score: -1.5152788162 259 | top2: 因 此 土 地 储 备 直 观 重 要 score: -2.5729951859 260 | top3: 因 此 土 地 储 备 之 关 重 要 score: -3.3197784424 261 | top4: 因 此 土 地 储 备 置 关 重 要 score: -5.1346631050 262 | top5: 因 此 土 地 储 备 之 官 重 要 score: -6.1023716927 263 | 264 | 2020-03-30 21:50:54,698 - /home/by2101/OpenASR/egs/aishell1/s5/../../../src/decode.py[line:119] - INFO: Prossesed 6 utterances in 6.695 s 265 | 
2020-03-30 21:50:55,901 - /home/by2101/OpenASR/egs/aishell1/s5/../../../src/decode.py[line:117] - INFO: 266 | Results for BAC009S0764W0127: 267 | top1: 中 原 地 产 首 席 分 析 师 张 大 伟 说 score: -1.1358175278 268 | top2: 中 元 地 产 首 席 分 析 师 张 大 伟 说 score: -6.1328477859 269 | top3: 中 原 地 产 手 机 分 析 师 张 大 伟 说 score: -6.6685914993 270 | top4: 中 原 地 场 首 席 分 析 师 张 大 伟 说 score: -6.8094315529 271 | top5: 中 原 地 产 首 席 分 析 师 张 大 玮 说 score: -7.0437431335 272 | 273 | 2020-03-30 21:50:55,901 - /home/by2101/OpenASR/egs/aishell1/s5/../../../src/decode.py[line:119] - INFO: Prossesed 7 utterances in 7.898 s 274 | 2020-03-30 21:50:57,289 - /home/by2101/OpenASR/egs/aishell1/s5/../../../src/decode.py[line:117] - INFO: 275 | Results for BAC009S0764W0128: 276 | top1: 一 线 城 市 土 地 供 应 量 减 少 score: -0.8210105896 277 | top2: 一 线 城 市 土 地 供 用 量 减 少 score: -5.3575639725 278 | top3: 一 线 城 市 土 地 供 应 链 减 少 score: -6.8985385895 279 | top4: 一 些 城 市 土 地 供 应 量 减 少 score: -6.9726257324 280 | top5: 一 线 城 市 土 地 供 需 量 减 少 score: -6.9961733818 281 | 282 | 2020-03-30 21:50:57,289 - /home/by2101/OpenASR/egs/aishell1/s5/../../../src/decode.py[line:119] - INFO: Prossesed 8 utterances in 9.286 s 283 | 2020-03-30 21:50:58,541 - /home/by2101/OpenASR/egs/aishell1/s5/../../../src/decode.py[line:117] - INFO: 284 | Results for BAC009S0764W0129: 285 | top1: 也 助 推 了 土 地 市 场 的 火 爆 score: -0.8998661041 286 | top2: 也 注 推 了 土 地 市 场 的 火 爆 score: -4.7260103226 287 | top3: 也 助 推 了 土 地 市 场 的 活 动 score: -4.8095049858 288 | top4: 也 助 推 了 土 地 市 场 的 活 报 score: -6.2755641937 289 | top5: 也 助 推 了 土 地 市 场 的 活 爆 score: -6.3123712540 290 | 291 | 2020-03-30 21:50:58,541 - /home/by2101/OpenASR/egs/aishell1/s5/../../../src/decode.py[line:119] - INFO: Prossesed 9 utterances in 10.538 s 292 | 2020-03-30 21:50:59,809 - /home/by2101/OpenASR/egs/aishell1/s5/../../../src/decode.py[line:117] - INFO: 293 | Results for BAC009S0764W0130: 294 | top1: 北 京 仅 新 增 住 宅 土 地 供 应 时 松 score: -4.4282484055 295 | top2: 北 京 仅 新 增 住 宅 土 地 供 应 十 宗 score: -5.1933073997 296 | top3: 北 京 简 新 增 住 宅 土 地 供 应 时 松 score: -6.4350633621 297 | top4: 北 京 仅 新 增 住 宅 土 地 供 应 时 宗 score: -6.4487853050 298 | top5: 北 京 减 薪 增 住 宅 土 地 供 应 时 松 score: -6.8076214790 299 | 300 | 2020-03-30 21:50:59,810 - /home/by2101/OpenASR/egs/aishell1/s5/../../../src/decode.py[line:119] - INFO: Prossesed 10 utterances in 11.806 s 301 | 2020-03-30 21:51:00,935 - /home/by2101/OpenASR/egs/aishell1/s5/../../../src/decode.py[line:117] - INFO: 302 | Results for BAC009S0764W0131: 303 | top1: 开 发 边 界 将 作 为 城 市 发 展 的 刚 性 约 定 score: -1.4182643890 304 | top2: 开 发 边 际 将 作 为 城 市 发 展 的 刚 性 约 定 score: -6.1457824707 305 | top3: 开 发 边 解 将 作 为 城 市 发 展 的 刚 性 约 定 score: -6.6899409294 306 | top4: 开 发 边 界 将 作 为 城 市 发 展 的 刚 性 月 定 score: -7.1161942482 307 | top5: 开 发 边 界 将 作 为 城 市 发 展 的 纲 性 约 定 score: -7.2151556015 308 | 309 | 2020-03-30 21:51:00,935 - /home/by2101/OpenASR/egs/aishell1/s5/../../../src/decode.py[line:119] - INFO: Prossesed 11 utterances in 12.932 s 310 | 2020-03-30 21:51:01,966 - /home/by2101/OpenASR/egs/aishell1/s5/../../../src/decode.py[line:117] - INFO: 311 | Results for BAC009S0764W0132: 312 | top1: 不 得 超 越 界 限 盲 目 扩 张 score: -0.4009504318 313 | top2: 不 得 超 越 借 限 盲 目 扩 张 score: -5.7239694595 314 | top3: 不 的 超 越 界 限 盲 目 扩 张 score: -6.0082988739 315 | top4: 不 德 超 越 界 限 盲 目 扩 张 score: -6.0330848694 316 | top5: 不 得 超 越 借 线 盲 目 扩 张 score: -6.4361486435 317 | 318 | 2020-03-30 21:51:01,966 - /home/by2101/OpenASR/egs/aishell1/s5/../../../src/decode.py[line:119] - INFO: Prossesed 12 utterances in 13.963 s 319 | decode_test.sh: line 17: 7269 Terminated 
CUDA_VISIBLE_DEVICES="1" python -W ignore::UserWarning $MAIN_ROOT/src/decode.py --feed-batchsize 1 --nbest 5 --use_gpu True $expdir/${ep}.pt exp/aishell1_train_chars.txt data/test "file" $decode_dir/hyp.trn 320 | -------------------------------------------------------------------------------- /egs/aishell1/s5/config_base.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | trainset: exp/train.json 3 | devset: exp/dev.json 4 | vocab_path: "exp/aishell1_train_chars.txt" 5 | maxlen: 60 6 | fetchworker_num: 12 7 | model: 8 | signal: 9 | feature_type: fbank 10 | sample_rate: 16000 11 | num_mel_bins: 80 12 | use_energy: False 13 | spec_aug: 14 | freq_mask_num: 2 15 | freq_mask_width: 27 16 | time_mask_num: 2 17 | time_mask_width: 40 18 | encoder: 19 | type: Transformer 20 | sub: 21 | type: ConvV2 22 | layer_num: 2 23 | input_dim: 80 24 | d_model: 512 25 | nhead: 8 26 | dim_feedforward: 2048 27 | activation: "glu" 28 | num_layers: 6 29 | dropout_rate: 0.1 30 | decoder: 31 | type: TransformerDecoder 32 | vocab_size: -1 # derived by tokenizer 33 | d_model: 512 34 | nhead: 8 35 | num_layers: 6 36 | encoder_dim: 512 37 | dim_feedforward: 2048 38 | activation: "glu" 39 | dropout_rate: 0.1 40 | training: 41 | batch_time: 150 42 | multi_gpu: False 43 | exp_dir: exp/base 44 | print_inteval: 10 45 | num_epoch: 80 46 | accumulate_grad_batch: 8 47 | init_lr: 1.0 48 | optimtype: adam 49 | grad_max_norm: 50. 50 | label_smooth: 0.1 51 | lr_scheduler: 52 | type: "warmup_transformer" 53 | warmup_step: 16000 54 | d_model: 512 55 | -------------------------------------------------------------------------------- /egs/aishell1/s5/config_lm_lstm.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | trainset: exp/train_text 3 | devset: exp/dev_text 4 | vocab_path: "exp/aishell1_train_chars.txt" 5 | maxlen: 70 6 | fetchworker_num: 12 7 | model: 8 | type: lstm 9 | vocab_size: -1 # derived by tokenizer 10 | hidden_size: 1024 11 | num_layers: 2 12 | dropout_rate: 0.1 13 | training: 14 | batch_size: 20 15 | multi_gpu: True 16 | exp_dir: exp/lm_lstm 17 | print_inteval: 10 18 | vis_atten: False 19 | num_epoch: 20 20 | accumulate_grad_batch: 1 21 | label_smooth: 0. 22 | init_lr: 0.1 23 | optimtype: sgd 24 | grad_max_norm: 50. 
28 |     lr_scheduler:
29 |         type: "bob"
30 |         decay_coef: 0.5
31 |         tolerate: 0.1
32 | 
33 | 
34 | 
35 | 
36 | 
37 | 
38 | 
39 | 
40 | 
41 | 
42 | 
43 | 
44 | 
45 | 
46 | 
47 | 
48 | 
--------------------------------------------------------------------------------
/egs/aishell1/s5/config_lst.yaml:
--------------------------------------------------------------------------------
1 | data:
2 |     trainset: exp/train.json
3 |     devset: exp/dev.json
4 |     vocab_path: "exp/aishell1_train_chars.txt"
5 |     maxlen: 60
6 |     fetchworker_num: 12
7 | model:
8 |     signal:
9 |         feature_type: fbank
10 |         sample_rate: 16000
11 |         num_mel_bins: 80
12 |         use_energy: False
13 |     spec_aug:
14 |         freq_mask_num: 2
15 |         freq_mask_width: 27
16 |         time_mask_num: 2
17 |         time_mask_width: 40
18 |     encoder:
19 |         type: Transformer
20 |         sub:
21 |             type: ConvV2
22 |             layer_num: 2
23 |         input_dim: 80
24 |         d_model: 512
25 |         nhead: 8
26 |         dim_feedforward: 2048
27 |         activation: "glu"
28 |         num_layers: 6
29 |         dropout_rate: 0.1
30 |     decoder:
31 |         type: TransformerDecoder
32 |         vocab_size: -1 # derived by tokenizer
33 |         d_model: 512
34 |         nhead: 8
35 |         num_layers: 6
36 |         encoder_dim: 512
37 |         dim_feedforward: 2048
38 |         activation: "glu"
39 |         dropout_rate: 0.1
40 | training:
41 |     batch_time: 150
42 |     multi_gpu: False
43 |     exp_dir: exp/exp1
44 |     print_inteval: 10
45 |     num_epoch: 80
46 |     accumulate_grad_batch: 8
47 |     init_lr: 1.0
48 |     optimtype: adam
49 |     grad_max_norm: 50.
50 |     label_smooth: 0.1
51 |     lst:
52 |         lm_path: "exp/lm_lstm/ep-0030.pt"
53 |         lst_w: 0.1
54 |         lst_t: 5.0
55 |     lr_scheduler:
56 |         type: "warmup_transformer"
57 |         warmup_step: 12000
58 |         d_model: 512
59 | 
--------------------------------------------------------------------------------
/egs/aishell1/s5/decode_test.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | source path.sh
3 | expdir=exp/base
4 | ep=avg-last10
5 | decode_dir=$expdir/decode_test_${ep}
6 | mkdir -p $decode_dir
7 | 
8 | CUDA_VISIBLE_DEVICES="1" \
9 | python -W ignore::UserWarning $MAIN_ROOT/src/decode.py \
10 |     --feed-batchsize 40 \
11 |     --nbest 5 \
12 |     --use_gpu True \
13 |     $expdir/${ep}.pt \
14 |     exp/aishell1_train_chars.txt \
15 |     data/test \
16 |     "file" \
17 |     $decode_dir/hyp.trn
18 | 
19 | 
20 | 
--------------------------------------------------------------------------------
/egs/aishell1/s5/local/aishell_data_prep.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | # Copyright 2017 Xingyu Na
4 | # Apache 2.0
5 | 
6 | #. ./path.sh || exit 1;
7 | 
8 | if [ $# != 2 ]; then
9 |   echo "Usage: $0 <audio-path> <text-path>"
10 |   echo " $0 /export/a05/xna/data/data_aishell/wav /export/a05/xna/data/data_aishell/transcript"
11 |   exit 1;
12 | fi
13 | 
14 | aishell_audio_dir=$1
15 | aishell_text=$2/aishell_transcript_v0.8.txt
16 | 
17 | train_dir=data/local/train
18 | dev_dir=data/local/dev
19 | test_dir=data/local/test
20 | tmp_dir=data/local/tmp
21 | 
22 | mkdir -p $train_dir
23 | mkdir -p $dev_dir
24 | mkdir -p $test_dir
25 | mkdir -p $tmp_dir
26 | 
27 | # data directory check
28 | if [ ! -d $aishell_audio_dir ] || [ ! -f $aishell_text ]; then
29 |   echo "Error: $0 requires two directory arguments"
30 |   exit 1;
31 | fi
32 | 
33 | # find wav audio files for train, dev and test resp.
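34 | # Assumed AISHELL-1 layout: $aishell_audio_dir/{train,dev,test}/<speaker>/<utt>.wav,
35 | # e.g. wav/test/S0764/BAC009S0764W0121.wav; the speaker ID is the second-to-last path field.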
36 | find $aishell_audio_dir -iname "*.wav" > $tmp_dir/wav.flist
37 | n=`cat $tmp_dir/wav.flist | wc -l`
38 | [ $n -ne 141925 ] && \
39 |   echo Warning: expected 141925 data files, found $n
40 | 
41 | grep -i "wav/train" $tmp_dir/wav.flist > $train_dir/wav.flist || exit 1;
42 | grep -i "wav/dev" $tmp_dir/wav.flist > $dev_dir/wav.flist || exit 1;
43 | grep -i "wav/test" $tmp_dir/wav.flist > $test_dir/wav.flist || exit 1;
44 | 
45 | rm -r $tmp_dir
46 | 
47 | # Transcriptions preparation
48 | for dir in $train_dir $dev_dir $test_dir; do
49 |   echo Preparing $dir transcriptions
50 |   sed -e 's/\.wav//' $dir/wav.flist | awk -F '/' '{print $NF}' > $dir/utt.list
51 |   sed -e 's/\.wav//' $dir/wav.flist | awk -F '/' '{i=NF-1;printf("%s %s\n",$NF,$i)}' > $dir/utt2spk_all
52 |   paste -d' ' $dir/utt.list $dir/wav.flist > $dir/wav.scp_all
53 |   perl $TOOLS_ROOT/filter_scp.pl -f 1 $dir/utt.list $aishell_text > $dir/transcripts.txt
54 |   awk '{print $1}' $dir/transcripts.txt > $dir/utt.list
55 |   perl $TOOLS_ROOT/filter_scp.pl -f 1 $dir/utt.list $dir/utt2spk_all | sort -u > $dir/utt2spk
56 |   perl $TOOLS_ROOT/filter_scp.pl -f 1 $dir/utt.list $dir/wav.scp_all | sort -u > $dir/wav.scp
57 |   sort -u $dir/transcripts.txt > $dir/text
58 |   perl $TOOLS_ROOT/utt2spk_to_spk2utt.pl $dir/utt2spk > $dir/spk2utt
59 | done
60 | 
61 | mkdir -p data/train data/dev data/test
62 | 
63 | for f in spk2utt utt2spk wav.scp text; do
64 |   cp $train_dir/$f data/train/$f || exit 1;
65 |   cp $dev_dir/$f data/dev/$f || exit 1;
66 |   cp $test_dir/$f data/test/$f || exit 1;
67 | done
68 | 
69 | echo "$0: AISHELL data preparation succeeded"
70 | exit 0;
71 | 
--------------------------------------------------------------------------------
/egs/aishell1/s5/local/download_and_untar.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | # Copyright 2014 Johns Hopkins University (author: Daniel Povey)
4 | # 2017 Xingyu Na
5 | # Apache 2.0
6 | 
7 | remove_archive=false
8 | 
9 | if [ "$1" == --remove-archive ]; then
10 |   remove_archive=true
11 |   shift
12 | fi
13 | 
14 | if [ $# -ne 3 ]; then
15 |   echo "Usage: $0 [--remove-archive] <data-base> <url-base> <corpus-part>"
16 |   echo "e.g.: $0 /export/a05/xna/data www.openslr.org/resources/33 data_aishell"
17 |   echo "With --remove-archive it will remove the archive after successfully un-tarring it."
18 |   echo "<corpus-part> can be one of: data_aishell, resource_aishell."
19 | fi
20 | 
21 | data=$1
22 | url=$2
23 | part=$3
24 | 
25 | if [ ! -d "$data" ]; then
26 |   echo "$0: no such directory $data"
27 |   exit 1;
28 | fi
29 | 
30 | part_ok=false
31 | list="data_aishell resource_aishell"
32 | for x in $list; do
33 |   if [ "$part" == $x ]; then part_ok=true; fi
34 | done
35 | if ! $part_ok; then
36 |   echo "$0: expected <corpus-part> to be one of $list, but got '$part'"
37 |   exit 1;
38 | fi
39 | 
40 | if [ -z "$url" ]; then
41 |   echo "$0: empty URL base."
42 |   exit 1;
43 | fi
44 | 
45 | if [ -f $data/$part/.complete ]; then
46 |   echo "$0: data part $part was already successfully extracted, nothing to do."
47 |   exit 0;
48 | fi
49 | 
50 | # sizes of the archive files in bytes.
51 | sizes="15582913665 1246920"
52 | 
53 | if [ -f $data/$part.tgz ]; then
54 |   size=$(/bin/ls -l $data/$part.tgz | awk '{print $5}')
55 |   size_ok=false
56 |   for s in $sizes; do if [ $s == $size ]; then size_ok=true; fi; done
57 |   if ! $size_ok; then
58 |     echo "$0: removing existing file $data/$part.tgz because its size in bytes $size"
59 |     echo "does not equal the size of one of the archives."
60 |     rm $data/$part.tgz
61 |   else
62 |     echo "$data/$part.tgz exists and appears to be complete."
63 |   fi
64 | fi
65 | 
66 | if [ ! -f $data/$part.tgz ]; then
67 |   if ! which wget >/dev/null; then
68 |     echo "$0: wget is not installed."
69 |     exit 1;
70 |   fi
71 |   full_url=$url/$part.tgz
72 |   echo "$0: downloading data from $full_url. This may take some time, please be patient."
73 | 
74 |   cd $data
75 |   if ! wget --no-check-certificate $full_url; then
76 |     echo "$0: error executing wget $full_url"
77 |     exit 1;
78 |   fi
79 | fi
80 | 
81 | cd $data
82 | 
83 | if ! tar -xvzf $part.tgz; then
84 |   echo "$0: error un-tarring archive $data/$part.tgz"
85 |   exit 1;
86 | fi
87 | 
88 | touch $data/$part/.complete
89 | 
90 | if [ $part == "data_aishell" ]; then
91 |   cd $data/$part/wav
92 |   for wav in ./*.tar.gz; do
93 |     echo "Extracting wav from $wav"
94 |     tar -zxf $wav && rm $wav
95 |   done
96 | fi
97 | 
98 | echo "$0: Successfully downloaded and un-tarred $data/$part.tgz"
99 | 
100 | if $remove_archive; then
101 |   echo "$0: removing $data/$part.tgz file since --remove-archive option was supplied."
102 |   rm $data/$part.tgz
103 | fi
104 | 
105 | exit 0;
106 | 
--------------------------------------------------------------------------------
/egs/aishell1/s5/path.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | export MAIN_ROOT=$PWD/../../..
4 | export TOOLS_ROOT=$MAIN_ROOT/tools
5 | export SRC_ROOT=$MAIN_ROOT/src
6 | export PYTHONPATH=$MAIN_ROOT/src
7 | 
8 | 
--------------------------------------------------------------------------------
/egs/aishell1/s5/prep_data.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | source path.sh
4 | 
5 | corpusdir=/data1/Corpora/aishell
6 | url=www.openslr.org/resources/33
7 | 
8 | echo "============================================================================"
9 | echo "Step 1: Download AISHELL1 dataset, and prepare KALDI style data directories."
10 | echo "============================================================================"
11 | bash local/download_and_untar.sh $corpusdir $url data_aishell
12 | bash local/aishell_data_prep.sh $corpusdir/data_aishell/wav $corpusdir/data_aishell/transcript
13 | 
14 | echo "============================================================================"
15 | echo "Step 2: Format data to json file for training Seq2Seq models."
16 | echo "============================================================================"
17 | echo "Remove space in transcripts"
18 | for x in train dev test; do
19 |     if [ ! -f data/$x/text.org ]; then
20 |         cp data/$x/text data/$x/text.org
21 |     fi
22 |     cat data/$x/text.org | awk '{printf($1" "); for(i=2;i<=NF;i++){printf($i)}; printf("\n")}' > data/$x/text
23 | done
24 | echo "Prepare json format data files."
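25 | # prepare_data.py (called below) writes one json record per utterance; src/data.py
26 | # expects at least the following keys per entry (values here are illustrative):
27 | #   {"utt": "BAC009S0764W0121", "path": ".../S0764/BAC009S0764W0121.wav",
28 | #    "duration": 4.2, "transcript": "甚至出现交易几乎停止的情况"}
29 | # ("duration" drives the length-sorted, time-based batching; the rest feed WaveCollate.)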
30 | mkdir -p exp
31 | for x in train dev test; do
32 |     python $MAIN_ROOT/src/prepare_data.py --tag file data/$x exp/${x}.json
33 | done
34 | echo "Generate vocabulary"
35 | python $MAIN_ROOT/src/stat_grapheme.py data/train/text exp/aishell1_train_chars.txt
36 | 
37 | 
38 | 
--------------------------------------------------------------------------------
/egs/aishell1/s5/run.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | 
4 | bash prep_data.sh
5 | bash train.sh
6 | bash avg.sh
7 | bash decode_test.sh
8 | bash score.sh data/test/text exp/base/decode_test_avg-last10
9 | 
10 | 
11 | 
12 | 
13 | 
--------------------------------------------------------------------------------
/egs/aishell1/s5/score.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | source path.sh
4 | 
5 | ref=$1
6 | dir=$2
7 | 
8 | cat $ref | python3 -c \
9 | "
10 | import sys
11 | for line in sys.stdin:
12 |     utt,txt = line.strip().split(' ', 1)
13 |     txt = ' '.join(list(txt))
14 |     print('{} ({})'.format(txt, utt))
15 | " > ${dir}/ref.trn
16 | 
17 | $MAIN_ROOT/tools/sclite -r ${dir}/ref.trn trn -h ${dir}/hyp.trn trn -c NOASCII -i wsj -o all stdout > ${dir}/result.txt
18 | 
19 | echo "CER (or TER) result written to ${dir}/result.txt"
20 | grep -e Avg -m 2 ${dir}/result.txt
21 | 
22 | 
--------------------------------------------------------------------------------
/egs/aishell1/s5/train.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | source path.sh
4 | export CUDA_VISIBLE_DEVICES=0
5 | 
6 | sys_tag="base"
7 | if [ $# != 0 ]; then
8 |     sys_tag=$1
9 | fi
10 | 
11 | 
12 | if [ "$sys_tag" == "base" ]; then
13 | 
14 |     echo "Training a baseline transformer ASR system..."
15 |     python $MAIN_ROOT/src/train.py config_base.yaml 2>&1 | tee base.log
16 | 
17 | elif [ "$sys_tag" == "lm" ]; then
18 |     cat data/train/text | cut -d" " -f2- > exp/train_text
19 |     cat data/dev/text | cut -d" " -f2- > exp/dev_text
20 |     python $MAIN_ROOT/src/lm_train.py config_lm_lstm.yaml 2>&1 | tee lm.log
21 | 
22 | elif [ "$sys_tag" == "lst" ]; then
23 |     echo "Training an ASR system with LST..."
24 |     # Assumed invocation, mirroring the base branch; train.py reads the lst section of config_lst.yaml.
25 |     python $MAIN_ROOT/src/train.py config_lst.yaml 2>&1 | tee lst.log
26 | else
27 | 
28 |     echo "The sys_tag should be base, lm or lst."
29 |     exit 1
30 | fi
31 | 
--------------------------------------------------------------------------------
/figs/dec_enc_att.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/by2101/OpenASR/c5213d68304a270a0448b2d53adc72b57f4efdb3/figs/dec_enc_att.png
--------------------------------------------------------------------------------
/figs/enc_att.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/by2101/OpenASR/c5213d68304a270a0448b2d53adc72b57f4efdb3/figs/enc_att.png
--------------------------------------------------------------------------------
/figs/loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/by2101/OpenASR/c5213d68304a270a0448b2d53adc72b57f4efdb3/figs/loss.png
--------------------------------------------------------------------------------
/src/avg_last_ckpts.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2020 Ye Bai by1993@qq.com
3 | 
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 | 
8 | http://www.apache.org/licenses/LICENSE-2.0
9 | 
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
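16 | # Usage (see egs/aishell1/s5/avg.sh): python avg_last_ckpts.py <expdir> <num>
17 | # e.g. "python avg_last_ckpts.py exp/base 10" averages the parameters of the 10
18 | # newest ep-*.pt checkpoints and writes them to <expdir>/avg-last10.pt.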
19 | 
20 | 
21 | import sys
22 | import os
23 | import argparse
24 | import logging
25 | import yaml
26 | import torch
27 | 
28 | logging.basicConfig(
29 |     level=logging.INFO,
30 |     format='%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s')
31 | 
32 | def get_args():
33 |     parser = argparse.ArgumentParser(description="""
34 |     Usage: avg_last_ckpts.py <expdir> <num>""")
35 |     parser.add_argument("expdir", help="The directory containing the checkpoints.")
36 |     parser.add_argument("num", type=int, help="The number of models to average")
37 |     args = parser.parse_args()
38 |     return args
39 | 
40 | if __name__ == "__main__":
41 |     args = get_args()
42 |     fnckpts = [t for t in os.listdir(args.expdir) if t.startswith("ep-") and t.endswith(".pt")]
43 |     fnckpts.sort()
44 |     fnckpts.reverse()
45 |     fnckpts = fnckpts[:args.num]
46 |     logging.info("Average checkpoints:\n{}".format("\n".join(fnckpts)))
47 |     pkg = torch.load(os.path.join(args.expdir, fnckpts[0]), map_location=lambda storage, loc: storage)
48 |     for k in pkg["model"]:
49 |         if k.endswith("_state"):
50 |             for key in pkg["model"][k].keys():
51 |                 pkg["model"][k][key] = torch.zeros_like(pkg["model"][k][key])
52 | 
53 |     for fn in fnckpts:
54 |         pkg_tmp = torch.load(os.path.join(args.expdir, fn), map_location=lambda storage, loc: storage)
55 |         logging.info("Loading {}".format(os.path.join(args.expdir, fn)))
56 |         for k in pkg["model"]:
57 |             if k.endswith("_state"):
58 |                 for key in pkg["model"][k].keys():
59 |                     pkg["model"][k][key] += pkg_tmp["model"][k][key]/len(fnckpts)
60 |     fn_save = os.path.join(args.expdir, "avg-last{}.pt".format(len(fnckpts)))
61 |     logging.info("Save averaged model to {}.".format(fn_save))
62 |     torch.save(pkg, fn_save)
63 | 
64 | 
65 | 
66 | 
67 | 
68 | 
69 | 
70 | 
71 | 
72 | 
--------------------------------------------------------------------------------
/src/data.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2020 Ye Bai by1993@qq.com
3 | 
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 | 
8 | http://www.apache.org/licenses/LICENSE-2.0
9 | 
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """ 16 | 17 | import os 18 | import logging 19 | import json 20 | import numpy as np 21 | import torch 22 | import torch.utils.data as data 23 | from torch.utils.data.sampler import Sampler 24 | 25 | import utils 26 | 27 | IGNORE_ID = -1 28 | 29 | SOS_SYM = "" 30 | EOS_SYM = "" 31 | UNK_SYM = "" 32 | SPECIAL_SYM_SET = {SOS_SYM, EOS_SYM, UNK_SYM} 33 | 34 | class CharTokenizer(object): 35 | def __init__(self, fn_vocab): 36 | with open(fn_vocab, 'r') as f: 37 | units = f.read().strip().split('\n') 38 | units = [UNK_SYM, SOS_SYM, EOS_SYM] + units 39 | self.unit2id = {k:v for v,k in enumerate(units)} 40 | self.id2unit = units 41 | 42 | def to_id(self, unit): 43 | return self.unit2id[unit] 44 | 45 | def to_unit(self, id): 46 | return self.id2unit[id] 47 | 48 | def encode(self, textline): 49 | return [self.unit2id[char] 50 | if char in self.unit2id 51 | else self.unit2id[UNK_SYM] 52 | for char in list(textline.strip())] 53 | 54 | def decode(self, ids, split_token=True, remove_special_sym=True): 55 | syms = [self.id2unit[i] for i in ids] 56 | if remove_special_sym: 57 | syms = [sym for sym in syms if sym not in SPECIAL_SYM_SET] 58 | if split_token: 59 | return " ".join(syms) 60 | return "".join(syms) 61 | 62 | def unit_num(self): 63 | return len(self.unit2id) 64 | 65 | 66 | def gen_casual_targets(idslist, maxlen, sos_id, eos_id): 67 | ids_with_sym_list = [[sos_id]+ids+[eos_id] for ids in idslist] 68 | B = len(idslist) 69 | padded_rawids = -torch.ones(B, maxlen+1).long() 70 | 71 | for b, ids in enumerate(ids_with_sym_list): 72 | if len(ids) > maxlen: 73 | logging.warn("ids length {} vs. maxlen {}, cut it.".format(len(ids), maxlen)) 74 | l = min(len(ids), maxlen) 75 | padded_rawids[b, :l] = torch.tensor(ids).long()[:l] 76 | paddings = (padded_rawids==-1).long() 77 | padded_rawids = padded_rawids*(1-paddings) + eos_id*paddings # modify -1 to eos_id 78 | 79 | labels = padded_rawids[:, 1:] 80 | ids = padded_rawids[:, :-1] 81 | paddings = paddings[:, 1:] # the padding is for labels 82 | 83 | return ids, labels, paddings 84 | 85 | 86 | class TextLineByLineDataset(data.Dataset): 87 | def __init__(self, fn): 88 | super(TextLineByLineDataset, self).__init__() 89 | with open(fn, 'r') as f: 90 | self.data = f.read().strip().split('\n') 91 | 92 | def __getitem__(self, index): 93 | return self.data[index] 94 | 95 | def __len__(self): 96 | return len(self.data) 97 | 98 | 99 | class SpeechDataset(data.Dataset): 100 | def __init__(self, data_json_path, reverse=False): 101 | super(SpeechDataset, self).__init__() 102 | with open(data_json_path, 'rb') as f: 103 | data = json.load(f) 104 | self.data = sorted(data, key=lambda x: float(x["duration"])) 105 | if reverse: 106 | self.data.reverse() 107 | 108 | def __getitem__(self, index): 109 | return self.data[index] 110 | 111 | def __len__(self): 112 | return len(self.data) 113 | 114 | 115 | class KaldiDataset(data.Dataset): 116 | def __init__(self, data_dir, tag="file"): 117 | super(KaldiDataset, self).__init__() 118 | self.data = [] 119 | if os.path.exists(os.path.join(data_dir, 'feats.scp')): 120 | p = os.path.join(data_dir, 'feats.scp') 121 | elif os.path.exists(os.path.join(data_dir, 'wav.scp')): 122 | p = os.path.join(data_dir, 'wav.scp') 123 | else: 124 | raise ValueError("None of feats.scp or wav.scp exist.") 125 | with open(p, 'r') as f: 126 | for line in f: 127 | utt, path = line.strip().split(' ', 1) 128 | path = "{}:{}".format(tag, path) 129 | d = (utt, path) 130 | self.data.append(d) 131 | 132 | def __getitem__(self, index): 133 | return 
self.data[index] 134 | 135 | def __len__(self): 136 | return len(self.data) 137 | 138 | 139 | class TimeBasedSampler(Sampler): 140 | def __init__(self, dataset, duration=200, ngpu=1, shuffle=False): # 200s 141 | self.dataset = dataset 142 | self.dur = duration 143 | self.shuffle = shuffle 144 | 145 | batchs = [] 146 | batch = [] 147 | batch_dur = 0. 148 | for idx in range(len(self.dataset)): 149 | batch.append(idx) 150 | batch_dur += self.dataset[idx]["duration"] 151 | if batch_dur >= self.dur and len(batch)%ngpu==0: 152 | # Keep the number of batches equal across GPUs. 153 | batchs.append(batch) 154 | batch = [] 155 | batch_dur = 0. 156 | if batch: 157 | if len(batch)%ngpu==0: 158 | batchs.append(batch) 159 | else: 160 | b = len(batch) 161 | if b >= ngpu: batchs.append(batch[:b//ngpu*ngpu]) # keep a multiple of ngpu utterances; drop the remainder 162 | self.batchs = batchs 163 | 164 | def __iter__(self): 165 | if self.shuffle: 166 | np.random.shuffle(self.batchs) 167 | for b in self.batchs: 168 | yield b 169 | 170 | def __len__(self): 171 | return len(self.batchs) 172 | 173 | 174 | def load_wave_batch(paths): 175 | waveforms = [] 176 | lengths = [] 177 | for path in paths: 178 | sample_rate, waveform = utils.load_wave(path) 179 | waveform = torch.from_numpy(waveform) 180 | waveforms.append(waveform) 181 | lengths.append(waveform.shape[0]) 182 | max_length = max(lengths) 183 | padded_waveforms = torch.zeros(len(lengths), max_length) 184 | for i in range(len(lengths)): 185 | padded_waveforms[i, :lengths[i]] += waveforms[i] 186 | return padded_waveforms, torch.tensor(lengths).long() 187 | 188 | 189 | def load_feat_batch(paths): 190 | features = [] 191 | lengths = [] 192 | for path in paths: 193 | feature = utils.load_feat(path) 194 | feature = torch.from_numpy(feature) 195 | features.append(feature) 196 | lengths.append(feature.shape[0]) 197 | max_length = max(lengths) 198 | dim = feature.shape[1] 199 | padded_features = torch.zeros(len(lengths), max_length, dim) 200 | for i in range(len(lengths)): 201 | padded_features[i, :lengths[i], :] += features[i] 202 | return padded_features, torch.tensor(lengths).long() 203 | 204 | 205 | class TextCollate(object): 206 | def __init__(self, tokenizer, maxlen): 207 | self.tokenizer = tokenizer 208 | self.maxlen = maxlen 209 | return 210 | 211 | def __call__(self, batch): 212 | timer = utils.Timer() 213 | timer.tic() 214 | rawids_list = [self.tokenizer.encode(t) for t in batch] 215 | ids, labels, paddings = gen_casual_targets(rawids_list, self.maxlen, 216 | self.tokenizer.to_id(SOS_SYM), self.tokenizer.to_id(EOS_SYM)) 217 | logging.debug("Text Processing Time: {}s".format(timer.toc())) 218 | return ids, labels, paddings 219 | 220 | 221 | class WaveCollate(object): 222 | def __init__(self, tokenizer, maxlen): 223 | self.tokenizer = tokenizer 224 | self.maxlen = maxlen 225 | return 226 | 227 | def __call__(self, batch): 228 | utts = [d["utt"] for d in batch] 229 | paths = [d["path"] for d in batch] 230 | trans = [d["transcript"] for d in batch] 231 | timer = utils.Timer() 232 | timer.tic() 233 | padded_waveforms, wave_lengths = load_wave_batch(paths) 234 | logging.debug("Wave Loading Time: {}s".format(timer.toc())) 235 | timer.tic() 236 | rawids_list = [self.tokenizer.encode(t) for t in trans] 237 | ids, labels, paddings = gen_casual_targets(rawids_list, self.maxlen, 238 | self.tokenizer.to_id(SOS_SYM), self.tokenizer.to_id(EOS_SYM)) 239 | logging.debug("Transcription Processing Time: {}s".format(timer.toc())) 240 | return utts, padded_waveforms, wave_lengths, ids, labels, paddings 241 | 242 | 243 | class
FeatureCollate(object): 244 | def __init__(self, tokenizer, maxlen): 245 | self.tokenizer = tokenizer 246 | self.maxlen = maxlen 247 | return 248 | 249 | def __call__(self, batch): 250 | utts = [d["utt"] for d in batch] 251 | paths = [d["path"] for d in batch] 252 | trans = [d["transcript"] for d in batch] 253 | timer = utils.Timer() 254 | timer.tic() 255 | padded_features, feature_lengths = load_feat_batch(paths) 256 | logging.debug("Feature Loading Time: {}s".format(timer.toc())) 257 | timer.tic() 258 | rawids_list = [self.tokenizer.encode(t) for t in trans] 259 | ids, labels, paddings = gen_casual_targets(rawids_list, self.maxlen, 260 | self.tokenizer.to_id(SOS_SYM), self.tokenizer.to_id(EOS_SYM)) 261 | logging.debug("Transcription Processing Time: {}s".format(timer.toc())) 262 | return utts, padded_features, feature_lengths, ids, labels, paddings 263 | 264 | 265 | def kaldi_wav_collate(batch): 266 | utts = [d[0] for d in batch] 267 | paths = [d[1] for d in batch] 268 | padded_data, lengths = load_wave_batch(paths) 269 | return utts, padded_data, lengths 270 | 271 | 272 | def kaldi_feat_collate(batch): 273 | utts = [d[0] for d in batch] 274 | paths = [d[1] for d in batch] 275 | padded_data, lengths = load_feat_batch(paths) 276 | return utts, padded_data, lengths 277 | -------------------------------------------------------------------------------- /src/data_test.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2020 Ye Bai by1993@qq.com 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | import os 18 | import sys 19 | import numpy as np 20 | import torch.utils.data 21 | import data 22 | import pdb 23 | 24 | os.chdir(os.path.abspath(os.path.dirname(__file__))) 25 | 26 | def test_dataset(): 27 | dataset = data.SpeechDataset("testdata/test.json") 28 | print(dataset[0]) 29 | 30 | def test_dataloader(): 31 | dataset = data.SpeechDataset("testdata/test.json") 32 | sampler = data.TimeBasedSampler(dataset, 5) 33 | tokenizer = data.CharTokenizer("testdata/train_chars.txt") 34 | collate = data.WaveCollate(tokenizer, 60) 35 | dataloader = torch.utils.data.DataLoader(dataset, 36 | batch_sampler=sampler, collate_fn=collate, shuffle=False) 37 | dataiter = iter(dataloader) 38 | batch = next(dataiter) 39 | utts, padded_waveforms, wave_lengths, ids, labels, paddings = batch 40 | 41 | print(utts[0]) 42 | print(ids[0]) 43 | print(labels[0]) 44 | print(paddings[0]) 45 | 46 | 47 | if __name__ == "__main__": 48 | test_dataloader() 49 | test_dataset() 50 | -------------------------------------------------------------------------------- /src/decode.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2020 Ye Bai by1993@qq.com 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 
6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | import sys 17 | import os 18 | import argparse 19 | import logging 20 | import yaml 21 | import numpy as np 22 | import torch 23 | 24 | logging.basicConfig( 25 | level=logging.DEBUG, 26 | format='%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s') 27 | 28 | from third_party import kaldi_io as kio 29 | import utils 30 | import data 31 | import sp_layers 32 | import encoder_layers 33 | import decoder_layers 34 | import models 35 | 36 | 37 | def get_args(): 38 | parser = argparse.ArgumentParser(description=""" 39 | Usage: decode.py <model_pkg> <vocab_file> <data_dir> <scptag> <output>""") 40 | parser.add_argument("model_pkg", help="path to model package.") 41 | parser.add_argument("vocab_file", help="path to vocabulary file.") 42 | parser.add_argument("data_dir", help="data directory") 43 | parser.add_argument("scptag", help="tag of wav.scp. unused for feats.scp") 44 | parser.add_argument("output", help="output") 45 | parser.add_argument("--feed-batchsize", type=int, default=20, help="batch_size") 46 | parser.add_argument("--nbest", type=int, default=13, help="nbest") 47 | parser.add_argument("--maxlen", type=int, default=80, help="max_length") 48 | parser.add_argument("--use_gpu", type=utils.str2bool, default=False, help="whether to use gpu.") 49 | args = parser.parse_args() 50 | return args 51 | 52 | 53 | 54 | if __name__ == "__main__": 55 | timer = utils.Timer() 56 | timer.tic() 57 | args = get_args() 58 | 59 | if args.output.strip() == "-": 60 | fd = sys.stdout # write text, matching the str writes below 61 | else: 62 | fd = open(args.output, 'w', encoding="utf8") 63 | 64 | logging.info("Load package from {}.".format(args.model_pkg)) 65 | pkg = torch.load(args.model_pkg, map_location=lambda storage, loc: storage) 66 | splayer = sp_layers.SPLayer(pkg["model"]["splayer_config"]) 67 | encoder = encoder_layers.Transformer(pkg["model"]["encoder_config"]) 68 | decoder = decoder_layers.TransformerDecoder(pkg["model"]["decoder_config"]) 69 | 70 | model = models.Model(splayer, encoder, decoder) 71 | logging.info("\nModel info:\n{}".format(model)) 72 | model.restore(pkg["model"]) 73 | if args.use_gpu: 74 | model = model.cuda() 75 | model.eval() 76 | if args.vocab_file.endswith("wpm"): 77 | tokenizer = data.WpmTokenizer(args.vocab_file) 78 | else: 79 | tokenizer = data.CharTokenizer(args.vocab_file) 80 | test_set = data.KaldiDataset(args.data_dir, tag=args.scptag) 81 | 82 | if os.path.exists(os.path.join(args.data_dir, 'wav.scp')): 83 | offline = False 84 | test_loader = torch.utils.data.DataLoader(test_set, 85 | collate_fn=data.kaldi_wav_collate, shuffle=False, batch_size=args.feed_batchsize) 86 | elif os.path.exists(os.path.join(args.data_dir, 'feats.scp')): 87 | offline = True 88 | test_loader = torch.utils.data.DataLoader(test_set, 89 | collate_fn=data.kaldi_feat_collate, shuffle=False, batch_size=args.feed_batchsize) 90 | logging.info("Start feedforward...") 91 | 92 | tot_timer = utils.Timer() 93 | tot_utt = 0 94 | tot_timer.tic() 95 | for utts, padded_waveforms, wave_lengths in test_loader: 96 | wave_time = 0 97 | if not offline: 98 | wave_time += wave_lengths.sum().numpy()/model.splayer.sample_rate 99 | else: 100 |
wave_time += wave_lengths.sum().numpy()/100. # by default, 100 frames correspond to 1 second. 101 | if next(model.parameters()).is_cuda: 102 | padded_waveforms = padded_waveforms.cuda() 103 | wave_lengths = wave_lengths.cuda() 104 | with torch.no_grad(): 105 | target_ids, scores = model.decode(padded_waveforms, wave_lengths, nbest_keep=args.nbest, maxlen=args.maxlen) 106 | all_ids_batch = target_ids.cpu().numpy() 107 | all_score_batch = scores.cpu().numpy() 108 | for i in range(all_ids_batch.shape[0]): 109 | utt = utts[i] 110 | msg = "Results for {}:\n".format(utt) 111 | for h in range(all_ids_batch.shape[1]): 112 | hyp = tokenizer.decode(all_ids_batch[i, h]) 113 | score = all_score_batch[i, h] 114 | msg += "top{}: {} score: {:.10f}\n".format(h+1, hyp, score) 115 | if h == 0: 116 | fd.write("{} ({})\n".format(hyp, utt)) 117 | logging.info("\n"+msg) 118 | tot_utt += len(utts) 119 | logging.info("Processed {} utterances in {:.3f} s".format(tot_utt, tot_timer.toc())) 120 | tot_time = tot_timer.toc() 121 | logging.info("Decoded {} utterances. The time cost is {:.2f} min." 122 | " Avg time cost is {:.2f} s per utt.".format(tot_utt, tot_time/60., tot_time/tot_utt)) 123 | 124 | 125 | 126 | 127 | 128 | 129 | 130 | 131 | 132 | -------------------------------------------------------------------------------- /src/decoder_layers.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2020 Ye Bai by1993@qq.com 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License.
15 | """ 16 | 17 | import logging 18 | import math 19 | import chardet 20 | import torch 21 | import torch.nn as nn 22 | import torch.nn.functional as F 23 | from torch.nn import init 24 | from torch.nn.modules.normalization import LayerNorm 25 | 26 | from third_party import transformer 27 | import utils 28 | import modules 29 | 30 | 31 | class TransformerDecoder(nn.Module): 32 | def __init__(self, config): 33 | super(TransformerDecoder, self).__init__() 34 | self.config = config 35 | 36 | self.d_model = config["d_model"] 37 | self.nhead = config["nhead"] 38 | self.num_layers = config["num_layers"] 39 | self.encoder_dim = config["encoder_dim"] 40 | self.dim_feedforward = config["dim_feedforward"] 41 | self.vocab_size = config["vocab_size"] 42 | self.dropout_rate = config["dropout_rate"] 43 | self.activation = config["activation"] 44 | 45 | self.emb = nn.Embedding(self.vocab_size, self.d_model) 46 | self.emb_scale = self.d_model ** 0.5 47 | self.pe = modules.PositionalEncoding(self.d_model) 48 | self.dropout = nn.Dropout(self.dropout_rate) 49 | 50 | transformer_decoder_layer = transformer.TransformerDecoderLayer( 51 | d_model=self.d_model, 52 | nhead=self.nhead, 53 | dim_feedforward=self.dim_feedforward, 54 | dropout=self.dropout_rate, 55 | activation=self.activation) 56 | self.transformer_block = transformer.TransformerDecoder(transformer_decoder_layer, 57 | self.num_layers) 58 | 59 | self.output_affine = nn.Linear(self.d_model, self.vocab_size) 60 | nn.init.xavier_normal_(self.output_affine.weight) 61 | self.emb.weight = self.output_affine.weight # tying weight 62 | 63 | def forward(self, encoder_outputs, encoder_output_lengths, decoder_inputs, decoder_input_lengths, return_atten=False): 64 | 65 | B, T_e, D_e = encoder_outputs.shape 66 | encoder_outputs = encoder_outputs.permute(1, 0, 2) # [S, B, D_e] 67 | 68 | _, T_d = decoder_inputs.shape 69 | 70 | memory_key_padding_mask = utils.get_transformer_padding_byte_masks( 71 | B, T_e, encoder_output_lengths).to(encoder_outputs.device) 72 | tgt_key_padding_mask = utils.get_transformer_padding_byte_masks( 73 | B, T_d, decoder_input_lengths).to(encoder_outputs.device) 74 | casual_masks = utils.get_transformer_casual_masks(T_d).to(encoder_outputs.device) 75 | 76 | outputs = self.emb(decoder_inputs) * self.emb_scale 77 | outputs = self.pe(outputs) 78 | outputs = self.dropout(outputs) 79 | outputs = outputs.permute(1, 0, 2) 80 | 81 | if return_atten: 82 | outputs, decoder_atten_tuple_list = self.transformer_block(outputs, encoder_outputs, 83 | memory_mask=None, memory_key_padding_mask=memory_key_padding_mask, 84 | tgt_key_padding_mask=tgt_key_padding_mask, tgt_mask=casual_masks, 85 | return_atten=True) 86 | else: 87 | outputs = self.transformer_block(outputs, encoder_outputs, 88 | memory_mask=None, memory_key_padding_mask=memory_key_padding_mask, 89 | tgt_key_padding_mask=tgt_key_padding_mask, tgt_mask=casual_masks, 90 | return_atten=False) 91 | outputs = outputs.permute(1, 0, 2) 92 | outputs = self.output_affine(outputs) 93 | 94 | if return_atten: 95 | return outputs, decoder_atten_tuple_list 96 | return outputs 97 | -------------------------------------------------------------------------------- /src/encoder_layers.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2020 Ye Bai by1993@qq.com 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 
6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | import math 18 | from collections import OrderedDict 19 | import logging 20 | import torch 21 | import torch.nn as nn 22 | import torch.nn.functional as F 23 | from torch.nn.init import xavier_uniform_ 24 | from third_party import transformer 25 | import utils 26 | import modules 27 | from torch.nn.modules.normalization import LayerNorm 28 | 29 | 30 | class Conv1dSubsample(torch.nn.Module): 31 | # the same as stack frames 32 | def __init__(self, input_dim, d_model, context_width, subsample): 33 | super(Conv1dSubsample, self).__init__() 34 | self.subsample = subsample 35 | self.context_width = context_width 36 | self.conv = nn.Conv1d(input_dim, d_model, context_width, stride=subsample) 37 | self.conv_norm = LayerNorm(d_model) 38 | 39 | 40 | def forward(self, feats, feat_lengths): 41 | outputs = self.conv(feats.permute(0, 2, 1)) 42 | outputs = outputs.permute(0, 2, 1) 43 | outputs = self.conv_norm(outputs) 44 | output_lengths = ((feat_lengths - 1*(self.context_width-1)-1)/self.subsample + 1).long() 45 | return outputs, output_lengths 46 | 47 | 48 | class Conv2dSubsample(torch.nn.Module): 49 | # Follow ESPNet configuration 50 | def __init__(self, input_dim, d_model): 51 | super(Conv2dSubsample, self).__init__() 52 | self.conv = torch.nn.Sequential( 53 | torch.nn.Conv2d(1, 32, 3, 2), 54 | torch.nn.ReLU(), 55 | torch.nn.Conv2d(32, 32, 3, 2), 56 | torch.nn.ReLU() 57 | ) 58 | self.affine = torch.nn.Linear(32 * (((input_dim - 1) // 2 - 1) // 2), d_model) 59 | 60 | def forward(self, feats, feat_lengths): 61 | outputs = feats.unsqueeze(1) # [B, C, T, D] 62 | outputs = self.conv(outputs) 63 | B, C, T, D = outputs.size() 64 | outputs = outputs.permute(0, 2, 1, 3).contiguous().view(B, T, C*D) 65 | outputs = self.affine(outputs) 66 | output_lengths = (((feat_lengths-1) / 2 - 1) / 2).long() 67 | return outputs, output_lengths 68 | 69 | 70 | class Conv2dSubsampleV2(torch.nn.Module): 71 | def __init__(self, input_dim, d_model, layer_num=2): 72 | super(Conv2dSubsampleV2, self).__init__() 73 | assert layer_num >= 1 74 | self.layer_num = layer_num 75 | layers = [("subsample/conv0", torch.nn.Conv2d(1, 32, 3, (2, 1))), 76 | ("subsample/relu0", torch.nn.ReLU())] 77 | for i in range(layer_num-1): 78 | layers += [ 79 | ("subsample/conv{}".format(i+1), torch.nn.Conv2d(32, 32, 3, (2, 1))), 80 | ("subsample/relu{}".format(i+1), torch.nn.ReLU()) 81 | ] 82 | layers = OrderedDict(layers) 83 | self.conv = torch.nn.Sequential(layers) 84 | self.affine = torch.nn.Linear(32 * (input_dim-2*layer_num), d_model) 85 | 86 | def forward(self, feats, feat_lengths): 87 | outputs = feats.unsqueeze(1) # [B, C, T, D] 88 | outputs = self.conv(outputs) 89 | B, C, T, D = outputs.size() 90 | outputs = outputs.permute(0, 2, 1, 3).contiguous().view(B, T, C*D) 91 | outputs = self.affine(outputs) 92 | output_lengths = feat_lengths 93 | for _ in range(self.layer_num): 94 | output_lengths = ((output_lengths-1) / 2).long() 95 | return outputs, output_lengths 96 | 97 | 98 | class Transformer(torch.nn.Module): 99 | def __init__(self, config): 100 | super(Transformer, self).__init__() 101 |
self.config = config 102 | 103 | self.input_dim = config["input_dim"] 104 | self.d_model = config["d_model"] 105 | self.nhead = config["nhead"] 106 | self.dim_feedforward = config["dim_feedforward"] 107 | self.num_layers = config["num_layers"] 108 | self.dropout_rate = config["dropout_rate"] 109 | self.activation = config["activation"] 110 | self.subconf = config["sub"] 111 | if self.subconf["type"] == "ConvV1": 112 | self.sub = Conv2dSubsample(self.input_dim, self.d_model) 113 | elif self.subconf["type"] == "ConvV2": 114 | self.sub = Conv2dSubsampleV2(self.input_dim, self.d_model, self.subconf["layer_num"]) 115 | elif self.subconf["type"] == "Stack": 116 | self.context_width = config["context_width"] 117 | self.subsample = config["subsample"] 118 | self.sub = Conv1dSubsample(self.input_dim, self.d_model, self.context_width, self.subsample) 119 | 120 | self.scale = self.d_model ** 0.5 121 | 122 | self.pe = modules.PositionalEncoding(self.d_model) 123 | self.dropout = nn.Dropout(self.dropout_rate) 124 | encoder_norm = LayerNorm(self.d_model) 125 | encoder_layer = transformer.TransformerEncoderLayer(d_model=self.d_model, 126 | nhead=self.nhead, dim_feedforward=self.dim_feedforward, 127 | dropout=self.dropout_rate, activation=self.activation) 128 | self.transformer_encoder = transformer.TransformerEncoder(encoder_layer, self.num_layers, encoder_norm) 129 | 130 | def forward(self, feats, feat_lengths, return_atten=False): 131 | outputs, output_lengths = self.sub(feats, feat_lengths) 132 | outputs = self.dropout(self.pe(outputs)) 133 | 134 | B, T, D_o = outputs.shape 135 | src_key_padding_mask = utils.get_transformer_padding_byte_masks(B, T, output_lengths).to(outputs.device) 136 | outputs = outputs.permute(1, 0, 2) 137 | if return_atten: 138 | outputs, self_atten_list = self.transformer_encoder(outputs, 139 | src_key_padding_mask=src_key_padding_mask, 140 | return_atten=True) 141 | else: 142 | outputs = self.transformer_encoder(outputs, 143 | src_key_padding_mask=src_key_padding_mask, 144 | return_atten=False) 145 | outputs = outputs.permute(1, 0, 2) 146 | if return_atten: 147 | return outputs, output_lengths, self_atten_list 148 | return outputs, output_lengths -------------------------------------------------------------------------------- /src/encoder_layers_test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import encoder_layers 3 | 4 | def test_Conv2dSubsampleV2(): 5 | layer = encoder_layers.Conv2dSubsampleV2(80, 512, 3) 6 | feats = torch.rand(3, 3000, 80) 7 | lengths = torch.tensor([100, 2899, 3000]).long() 8 | outputs, output_lengths = layer(feats, lengths) 9 | print("outputs.shape", outputs.shape) 10 | print("input_lengths", lengths) 11 | print("output_lengths", output_lengths) 12 | 13 | if __name__ == "__main__": 14 | test_Conv2dSubsampleV2() 15 | 16 | 17 | 18 | -------------------------------------------------------------------------------- /src/lm_layers.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2020 Ye Bai by1993@qq.com 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 
6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | import logging 18 | import math 19 | import chardet 20 | import torch 21 | import torch.nn as nn 22 | import torch.nn.init as init 23 | import torch.nn.functional as F 24 | from torch.nn import init 25 | from torch.nn.modules.normalization import LayerNorm 26 | 27 | from third_party import transformer 28 | import modules 29 | import utils 30 | 31 | class LSTM(nn.Module): 32 | def __init__(self, config): 33 | super(LSTM, self).__init__() 34 | self.config = config 35 | 36 | self.vocab_size = config["vocab_size"] 37 | self.hidden_size = config["hidden_size"] 38 | self.num_layers = config["num_layers"] 39 | self.dropout_rate = config["dropout_rate"] 40 | self.emb = nn.Embedding(self.vocab_size, self.hidden_size) 41 | self.rnn = nn.LSTM(self.hidden_size, self.hidden_size, num_layers=self.num_layers, dropout=self.dropout_rate, batch_first=True) 42 | self.dropout1 = nn.Dropout(self.dropout_rate) 43 | self.dropout2 = nn.Dropout(self.dropout_rate) 44 | self.output_affine = nn.Linear(self.hidden_size, self.vocab_size, bias=False) 45 | self.emb.weight = self.output_affine.weight 46 | 47 | def forward(self, ids, lengths=None): 48 | outputs = self.emb(ids) 49 | outputs = self.dropout1(outputs) 50 | outputs, (h, c) = self.rnn(outputs) 51 | outputs = self.dropout2(outputs) 52 | outputs = self.output_affine(outputs) 53 | return outputs 54 | 55 | def reset_parameters(self): 56 | for name, param in self.named_parameters(): 57 | if 'weight_ih' in name: 58 | init.xavier_uniform_(param.data) 59 | elif 'weight_hh' in name: 60 | init.orthogonal_(param.data) 61 | elif 'bias' in name: 62 | param.data.fill_(0) 63 | init.uniform_(self.emb.weight, a=-0.01, b=0.01) 64 | 65 | 66 | class TransformerLM(nn.Module): 67 | def __init__(self, config): 68 | super(TransformerLM, self).__init__() 69 | self.config = config 70 | 71 | self.vocab_size = config["vocab_size"] 72 | self.d_model = config["d_model"] 73 | self.nhead = config["nhead"] 74 | self.num_layers = config["num_layers"] 75 | self.dim_feedforward = config["dim_feedforward"] 76 | self.activation = config["activation"] 77 | self.dropout_rate = config["dropout_rate"] 78 | 79 | 80 | self.dropout = nn.Dropout(self.dropout_rate) 81 | self.scale = self.d_model ** 0.5 82 | self.pe = modules.PositionalEncoding(self.d_model) 83 | self.emb = nn.Embedding(self.vocab_size, self.d_model) 84 | encoder_layer = transformer.TransformerEncoderLayer(d_model=self.d_model, 85 | nhead=self.nhead, dim_feedforward=self.dim_feedforward, 86 | dropout=self.dropout_rate, activation=self.activation) 87 | self.transformer_encoder = transformer.TransformerEncoder(encoder_layer, self.num_layers) 88 | self.output_affine = nn.Linear(self.d_model, self.vocab_size, bias=False) 89 | self.emb.weight = self.output_affine.weight 90 | 91 | def forward(self, ids, lengths, return_atten=False): 92 | B, T = ids.shape 93 | 94 | key_padding_mask = utils.get_transformer_padding_byte_masks( 95 | B, T, lengths).to(ids.device) 96 | casual_masks = utils.get_transformer_casual_masks(T).to(ids.device) 97 | 98 | outputs = self.emb(ids) * self.scale 99 | outputs = self.pe(outputs) 
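# The token ids above have been embedded, scaled by sqrt(d_model), and combined
# with sinusoidal positional encodings. Below, dropout is applied and the tensor
# is permuted to the sequence-first [T, B, D] layout that the third_party
# transformer modules expect; casual_masks restricts each position to attend
# only to earlier positions, so this encoder stack behaves as an autoregressive
# language model.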
100 | outputs = self.dropout(outputs) 101 | outputs = outputs.permute(1, 0, 2) 102 | 103 | outputs, self_atten_list = self.transformer_encoder(outputs, 104 | mask=casual_masks, 105 | src_key_padding_mask=key_padding_mask, 106 | return_atten=True) 107 | outputs = self.output_affine(outputs.permute(1, 0, 2)) # back to batch-first: [B, T, vocab_size] 108 | if return_atten: 109 | return outputs, self_atten_list 110 | return outputs 111 | 112 | 113 | -------------------------------------------------------------------------------- /src/lm_train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2020 Ye Bai by1993@qq.com 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | import sys 18 | import os 19 | import argparse 20 | import logging 21 | import yaml 22 | import torch 23 | 24 | if "LAS_LOG_LEVEL" in os.environ: 25 | LOG_LEVEL = os.environ["LAS_LOG_LEVEL"] 26 | else: 27 | LOG_LEVEL = "INFO" 28 | if LOG_LEVEL == "DEBUG": 29 | logging.basicConfig( 30 | level=logging.DEBUG, 31 | format='%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s') 32 | else: 33 | logging.basicConfig( 34 | level=logging.INFO, 35 | format='%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s') 36 | 37 | import utils 38 | import data 39 | import lm_layers 40 | import models 41 | from trainer import LmTrainer 42 | 43 | 44 | def get_args(): 45 | parser = argparse.ArgumentParser(description=""" 46 | Usage: lm_train.py <config>""") 47 | parser.add_argument("config", help="path to config file") 48 | parser.add_argument('--continue-training', type=utils.str2bool, default=False, 49 | help='Continue training from last_model.pt.') 50 | args = parser.parse_args() 51 | return args 52 | 53 | 54 | if __name__ == "__main__": 55 | timer = utils.Timer() 56 | args = get_args() 57 | timer.tic() 58 | with open(args.config) as f: 59 | config = yaml.load(f, Loader=yaml.FullLoader) 60 | dataconfig = config["data"] 61 | trainingconfig = config["training"] 62 | modelconfig = config["model"] 63 | 64 | training_set = data.TextLineByLineDataset(dataconfig["trainset"]) 65 | valid_set = data.TextLineByLineDataset(dataconfig["devset"]) 66 | if "vocab_path" in dataconfig: 67 | tokenizer = data.CharTokenizer(dataconfig["vocab_path"]) 68 | else: 69 | raise ValueError("Unknown tokenizer.") 70 | 71 | 72 | modelconfig['vocab_size'] = tokenizer.unit_num() 73 | collate = data.TextCollate(tokenizer, dataconfig["maxlen"]) 74 | 75 | ngpu = 1 76 | if "multi_gpu" in trainingconfig and trainingconfig["multi_gpu"] == True: 77 | ngpu = torch.cuda.device_count() 78 | 79 | tr_loader = torch.utils.data.DataLoader(training_set, 80 | collate_fn=collate, batch_size=trainingconfig['batch_size'], shuffle=True, num_workers=dataconfig["fetchworker_num"]) 81 | cv_loader = torch.utils.data.DataLoader(valid_set, 82 | collate_fn=collate, batch_size=trainingconfig['batch_size'], shuffle=False, num_workers=dataconfig["fetchworker_num"]) 83 | 84 | if modelconfig["type"] == "lstm": 85 | lmlayer = lm_layers.LSTM(modelconfig) 86 |
else: 87 | raise ValueError("Unknown model") 88 | 89 | model = models.LM(lmlayer) 90 | logging.info("\nModel info:\n{}".format(model)) 91 | 92 | if args.continue_training: 93 | logging.info("Load package from {}.".format(os.path.join(trainingconfig["exp_dir"], "last-ckpt.pt"))) 94 | pkg = torch.load(os.path.join(trainingconfig["exp_dir"], "last-ckpt.pt")) 95 | model.restore(pkg["model"]) 96 | 97 | if "multi_gpu" in trainingconfig and trainingconfig["multi_gpu"] == True: 98 | logging.info("Let's use {} GPUs!".format(torch.cuda.device_count())) 99 | model = torch.nn.DataParallel(model) 100 | 101 | model = model.cuda() 102 | 103 | trainer = LmTrainer(model, trainingconfig, tr_loader, cv_loader) 104 | 105 | if args.continue_training: 106 | logging.info("Restore trainer states...") 107 | trainer.restore(pkg) 108 | logging.info("Start training...") 109 | trainer.train() 110 | logging.info("Total time: {:.4f} mins".format(timer.toc()/60.)) 111 | 112 | -------------------------------------------------------------------------------- /src/metric.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2020 Ye Bai by1993@qq.com 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | import logging 18 | import torch 19 | import torch.nn as nn 20 | import utils 21 | 22 | class MetricSummarizer(object): 23 | def __init__(self): 24 | self.metrics = {} # name: (loss, weight) 25 | self.metric_names = [] 26 | self.summarized = {} 27 | 28 | def register_metric(self, name, display=False, visual=False, optim=False): 29 | self.metric_names.append({ 30 | "name": name, 31 | "display": display, 32 | "visual": visual, 33 | "optim": optim, 34 | }) 35 | 36 | def reset_metrics(self): 37 | del self.metrics 38 | del self.summarized 39 | self.metrics = {} 40 | for item in self.metric_names: 41 | self.metrics[item["name"]] = None 42 | self.summarized = {} 43 | 44 | def get_metric_by_name(self, name): 45 | return self.metrics[name] 46 | 47 | def update_metric(self, name, loss, weight=1.0): 48 | if name in self.metrics: 49 | self.metrics[name] = (loss, weight) 50 | else: 51 | raise ValueError("The metric {} is not registered.".format(name)) 52 | 53 | def summarize(self): 54 | self.summarized = {} # name: torch.Tensor 55 | for key in self.metrics.keys(): 56 | if self.metrics[key] is None: 57 | logging.warn("{} is not updated. Skip it.".format(key)) 58 | continue 59 | item = self.metrics[key] 60 | self.summarized[key] = item[0] * item[1] 61 | 62 | def collect_loss(self): 63 | loss = 0 64 | for item in self.metric_names: 65 | key = item['name'] 66 | if item["optim"] == True: 67 | v = self.metrics[key] 68 | loss += v[0] * v[1] 69 | return loss 70 | 71 | def fetch_scalers(self, use="display"): 72 | fetched = [] 73 | for item in self.metric_names: 74 | if item[use] == True: 75 | if item["name"] not in self.summarized: 76 | logging.warn("{} is not summarized.
Skip it.".format(item["name"])) 77 | continue 78 | fetched.append( 79 | (item["name"], self.summarized[item["name"]])) 80 | return fetched 81 | 82 | def display_msg(self, fetched, max_item_one_line=3): 83 | msglist = [] 84 | msglists = [] 85 | cnt = 0 86 | for name, value in fetched: 87 | if isinstance(value, torch.Tensor): 88 | msglist.append("{}: {:.7f}".format(name, value.item())) 89 | else: 90 | msglist.append("{}: {:.7f}".format(name, value)) 91 | cnt += 1 92 | if cnt == max_item_one_line: 93 | msglists.append(msglist) 94 | msglist = [] 95 | cnt = 0 96 | if msglist: 97 | msglists.append(msglist) 98 | l = [] 99 | for msglist in msglists: 100 | l.append(" | ".join(msglist)) 101 | msg = "\n".join(l) 102 | return msg 103 | 104 | def visualize_scalers(self, fetched, step): 105 | for name, value in fetched: 106 | if isinstance(value, torch.Tensor): 107 | utils.visualizer.add_scalar(name, value.item(), step) 108 | else: 109 | utils.visualizer.add_scalar(name, value, step) 110 | -------------------------------------------------------------------------------- /src/models.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2020 Ye Bai by1993@qq.com 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | """ 16 | 17 | import logging 18 | import numpy as np 19 | import torch 20 | import torch.nn.functional as F 21 | from torch.nn.init import xavier_uniform_ 22 | import utils 23 | import lm_layers 24 | import pdb 25 | 26 | 27 | class Model(torch.nn.Module): 28 | def __init__(self, splayer, encoder, decoder, lm=None): 29 | super(Model, self).__init__() 30 | self.splayer = splayer 31 | self.encoder = encoder 32 | self.decoder = decoder 33 | self._reset_parameters() 34 | 35 | self.lm = lm # this must be set after parameter initialization 36 | 37 | 38 | def forward(self, batch_wave, lengths, target_ids, target_labels=None, target_paddings=None, label_smooth=0., lst_w=0., lst_t=1.0, return_atten=False): 39 | target_lengths = torch.sum(1-target_paddings, dim=-1).long() 40 | logits, atten_info = self.get_logits(batch_wave, lengths, 41 | target_ids, target_lengths, return_atten=True) 42 | losses = self._compute_cross_entropy_losses(logits, target_labels, target_paddings) 43 | loss = torch.sum(losses) 44 | if label_smooth > 0: 45 | loss = loss*(1-label_smooth) + self._uniform_label_smooth(logits, target_paddings)*label_smooth 46 | if lst_w > 0.: 47 | loss = loss*(1-lst_w) + self._lst(logits, target_ids, target_paddings, T=lst_t)*lst_w 48 | if return_atten: 49 | return loss, atten_info 50 | return loss 51 | 52 | 53 | def _uniform_label_smooth(self, logits, paddings): 54 | log_probs = F.log_softmax(logits, dim=-1) 55 | nlabel = log_probs.shape[-1] 56 | ent_uniform = -torch.sum(log_probs, dim=-1)/nlabel 57 | return torch.sum(ent_uniform*(1-paddings).float()) 58 | 59 | 60 | def _lst(self, logits, target_ids, target_paddings, T=5.0): 61 | with torch.no_grad(): 62 | self.lm.eval() 63 | lengths = torch.sum(1-target_paddings, dim=-1).long() 64 | teacher_probs = self.lm.get_probs(target_ids, lengths, T=T) 65 | logprobs = torch.log_softmax(logits, dim=-1) 66 | losses = -torch.sum(teacher_probs * logprobs, dim=-1) 67 | return torch.sum(losses*(1-target_paddings).float()) 68 | 69 | 70 | def _compute_cross_entropy_losses(self, logits, labels, paddings): 71 | B, T, V = logits.shape 72 | losses = F.cross_entropy(logits.view(-1, V), labels.view(-1), reduction="none").view(B, T) * (1-paddings).float() 73 | return losses 74 | 75 | 76 | def _compute_wers(self, hyps, labels): 77 | raise NotImplementedError() 78 | 79 | 80 | def _sample_nbest(self, encoder_output, encoder_output_lengths, nbest_keep=4,): 81 | self._beam_search(encoder_outputs, encoder_output_lengths, nbest_keep, sosid, maxlen) 82 | raise NotImplementedError() 83 | 84 | 85 | def _compute_mwer_loss(self): 86 | raise NotImplementedError() 87 | 88 | 89 | def get_logits(self, batch_wave, lengths, target_ids, target_lengths, return_atten=False): 90 | if return_atten: 91 | timer = utils.Timer() 92 | timer.tic() 93 | sp_outputs, sp_output_lengths = self.splayer(batch_wave, lengths) 94 | logging.debug("splayer time: {}s".format(timer.toc())) 95 | timer.tic() 96 | encoder_outputs, encoder_output_lengths, enc_self_atten_list = self.encoder(sp_outputs, sp_output_lengths, return_atten=True) 97 | logging.debug("encoder time: {}s".format(timer.toc())) 98 | timer.tic() 99 | outputs, decoder_atten_tuple_list = self.decoder(encoder_outputs, encoder_output_lengths, target_ids, target_lengths, return_atten=True) 100 | logging.debug("decoder time: {}s".format(timer.toc())) 101 | timer.tic() 102 | return outputs, (encoder_outputs, encoder_output_lengths, enc_self_atten_list, target_lengths, decoder_atten_tuple_list, sp_outputs, sp_output_lengths) 103 | else: 104 | timer = 
utils.Timer() 105 | timer.tic() 106 | encoder_outputs, encoder_output_lengths = self.splayer(batch_wave, lengths) 107 | logging.debug("splayer time: {}s".format(timer.toc())) 108 | timer.tic() 109 | encoder_outputs, encoder_output_lengths = self.encoder(encoder_outputs, encoder_output_lengths, return_atten=False) 110 | logging.debug("encoder time: {}s".format(timer.toc())) 111 | timer.tic() 112 | outputs = self.decoder(encoder_outputs, encoder_output_lengths, target_ids, target_lengths, return_atten=False) 113 | logging.debug("decoder time: {}s".format(timer.toc())) 114 | timer.tic() 115 | return outputs 116 | 117 | 118 | def decode(self, batch_wave, lengths, nbest_keep, sosid=1, eosid=2, maxlen=100): 119 | if type(nbest_keep) != int: 120 | raise ValueError("nbest_keep must be a int.") 121 | encoder_outputs, encoder_output_lengths = self._get_acoustic_representations( 122 | batch_wave, lengths) 123 | target_ids, scores = self._beam_search(encoder_outputs, encoder_output_lengths, nbest_keep, sosid, eosid, maxlen) 124 | return target_ids, scores 125 | 126 | 127 | def _get_acoustic_representations(self, batch_wave, lengths): 128 | encoder_outputs, encoder_output_lengths = self.splayer(batch_wave, lengths) 129 | encoder_outputs, encoder_output_lengths = self.encoder(encoder_outputs, encoder_output_lengths, return_atten=False) 130 | return encoder_outputs, encoder_output_lengths 131 | 132 | 133 | def _beam_search(self, encoder_outputs, encoder_output_lengths, nbest_keep, sosid, eosid, maxlen): 134 | 135 | B = encoder_outputs.shape[0] 136 | # init 137 | init_target_ids = torch.ones(B, 1).to(encoder_outputs.device).long()*sosid 138 | init_target_lengths = torch.ones(B).to(encoder_outputs.device).long() 139 | outputs = (self.decoder(encoder_outputs, encoder_output_lengths, init_target_ids, init_target_lengths)[:, -1, :]) 140 | vocab_size = outputs.size(-1) 141 | outputs = outputs.view(B, vocab_size) 142 | log_probs = F.log_softmax(outputs, dim=-1) 143 | topk_res = torch.topk(log_probs, k=nbest_keep, dim=-1) 144 | nbest_ids = topk_res[1].view(-1) #[batch_size*nbest_keep, 1] 145 | nbest_logprobs = topk_res[0].view(-1) 146 | 147 | target_ids = torch.ones(B*nbest_keep, 1).to(encoder_outputs.device).long()*sosid 148 | target_lengths = torch.ones(B*nbest_keep).to(encoder_outputs.device).long() 149 | 150 | target_ids = torch.cat([target_ids, nbest_ids.view(B*nbest_keep, 1)], dim=-1) 151 | target_lengths += 1 152 | 153 | finished_sel = None 154 | ended = [] 155 | ended_scores = [] 156 | ended_batch_idx = [] 157 | for step in range(1, maxlen): 158 | (nbest_ids, nbest_logprobs, beam_from) = self._decode_single_step( 159 | encoder_outputs, encoder_output_lengths, target_ids, target_lengths, nbest_logprobs, finished_sel) 160 | batch_idx = (torch.arange(B)*nbest_keep).view(B, -1).repeat(1, nbest_keep).contiguous().to(beam_from.device) 161 | batch_beam_from = (batch_idx + beam_from.view(-1, nbest_keep)).view(-1) 162 | nbest_logprobs = nbest_logprobs.view(-1) 163 | finished_sel = (nbest_ids.view(-1) == eosid) 164 | target_ids = target_ids[batch_beam_from] 165 | target_ids = torch.cat([target_ids, nbest_ids.view(B*nbest_keep, 1)], dim=-1) 166 | target_lengths += 1 167 | 168 | for i in range(finished_sel.shape[0]): 169 | if finished_sel[i]: 170 | ended.append(target_ids[i]) 171 | ended_scores.append(nbest_logprobs[i]) 172 | ended_batch_idx.append(i // nbest_keep) 173 | target_ids = target_ids * (1 - finished_sel[:, None].long()) # mask out finished 174 | 175 | for i in range(target_ids.shape[0]): 176 | 
ended.append(target_ids[i]) 177 | ended_scores.append(nbest_logprobs[i]) 178 | ended_batch_idx.append(i // nbest_keep) 179 | 180 | formated = {} 181 | for i in range(B): 182 | formated[i] = [] 183 | for i in range(len(ended)): 184 | if ended[i][0] == sosid: 185 | formated[ended_batch_idx[i]].append((ended[i], ended_scores[i])) 186 | for i in range(B): 187 | formated[i] = sorted(formated[i], key=lambda x:x[1], reverse=True)[:nbest_keep] 188 | 189 | target_ids = torch.zeros(B, nbest_keep, maxlen+1).to(encoder_outputs.device).long() 190 | scores = torch.zeros(B, nbest_keep).to(encoder_outputs.device) 191 | for i in range(B): 192 | for j in range(nbest_keep): 193 | item = formated[i][j] 194 | l = min(item[0].shape[0], target_ids[i, j].shape[0]) 195 | target_ids[i, j, :l] = item[0][:l] 196 | scores[i, j] = item[1] 197 | return target_ids, scores 198 | 199 | 200 | def _decode_single_step(self, encoder_outputs, encoder_output_lengths, target_ids, target_lengths, accumu_scores, finished_sel=None): 201 | """ 202 | encoder_outputs: [B, T_e, D_e] 203 | encoder_output_lengths: [B] 204 | target_ids: [B*nbest_keep, T_d] 205 | target_lengths: [B*nbest_keep] 206 | accumu_scores: [B*nbest_keep] 207 | """ 208 | 209 | B, T_e, D_e = encoder_outputs.shape 210 | B_d, T_d = target_ids.shape 211 | if B_d % B != 0: 212 | raise ValueError("The dim of target_ids does not match the encoder_outputs.") 213 | nbest_keep = B_d // B 214 | encoder_outputs = (encoder_outputs.view(B, 1, T_e, D_e) 215 | .repeat(1, nbest_keep, 1, 1).view(B*nbest_keep, T_e, D_e)) 216 | encoder_output_lengths = (encoder_output_lengths.view(B, 1) 217 | .repeat(1, nbest_keep).view(-1)) 218 | 219 | # outputs: [B, nbest_keep, vocab_size] 220 | outputs = (self.decoder(encoder_outputs, encoder_output_lengths, target_ids, target_lengths)[:, -1, :]) 221 | vocab_size = outputs.size(-1) 222 | outputs = outputs.view(B, nbest_keep, vocab_size) 223 | log_probs = F.log_softmax(outputs, dim=-1) # [B, nbest_keep, vocab_size] 224 | if finished_sel is not None: 225 | log_probs = log_probs.view(B*nbest_keep, -1) - finished_sel.view(B*nbest_keep, -1).float()*9e9 226 | log_probs = log_probs.view(B, nbest_keep, vocab_size) 227 | this_accumu_scores = accumu_scores.view(B, nbest_keep, 1) + log_probs 228 | topk_res = torch.topk(this_accumu_scores.view(B, nbest_keep*vocab_size), k=nbest_keep, dim=-1) 229 | 230 | nbest_logprobs = topk_res[0] # [B, nbest_keep] 231 | nbest_ids = topk_res[1] % vocab_size # [B, nbest_keep] 232 | beam_from = (topk_res[1] / vocab_size).long() 233 | return nbest_ids, nbest_logprobs, beam_from 234 | 235 | 236 | def package(self): 237 | pkg = { 238 | "splayer_config": self.splayer.config, 239 | "splayer_state": self.splayer.state_dict(), 240 | "encoder_config": self.encoder.config, 241 | "encoder_state": self.encoder.state_dict(), 242 | "decoder_config": self.decoder.config, 243 | "decoder_state": self.decoder.state_dict(), 244 | } 245 | return pkg 246 | 247 | 248 | def restore(self, pkg): 249 | # check config 250 | logging.info("Restore model states...") 251 | for key in self.splayer.config.keys(): 252 | if key == "spec_aug": 253 | continue 254 | if self.splayer.config[key] != pkg["splayer_config"][key]: 255 | raise ValueError("splayer_config mismatch.") 256 | for key in self.encoder.config.keys(): 257 | if (key != "dropout_rate" and 258 | self.encoder.config[key] != pkg["encoder_config"][key]): 259 | raise ValueError("encoder_config mismatch.") 260 | for key in self.decoder.config.keys(): 261 | if (key != "dropout_rate" and 262 | 
self.decoder.config[key] != pkg["decoder_config"][key]): 263 | raise ValueError("decoder_config mismatch.") 264 | 265 | self.splayer.load_state_dict(pkg["splayer_state"]) 266 | self.encoder.load_state_dict(pkg["encoder_state"]) 267 | self.decoder.load_state_dict(pkg["decoder_state"]) 268 | 269 | 270 | def _reset_parameters(self): 271 | for p in self.parameters(): 272 | if p.dim() > 1: 273 | xavier_uniform_(p) 274 | 275 | 276 | class LM(torch.nn.Module): 277 | def __init__(self, lmlayer): 278 | super(LM, self).__init__() 279 | self.lm_layer = lmlayer 280 | #self._reset_parameters() 281 | 282 | 283 | def forward(self, ids, labels, paddings, label_smooth=0.): 284 | lengths = torch.sum(1-paddings, dim=1).long() 285 | logits = self.get_logits(ids, lengths) 286 | ntoken = torch.sum(1-paddings) 287 | tot_loss = torch.sum(self._compute_ce_loss(logits, labels, paddings)) 288 | if label_smooth > 0: 289 | tot_loss = tot_loss*(1-label_smooth) + self._uniform_label_smooth(logits, paddings)*label_smooth 290 | tot_ncorrect = self._compute_ncorrect(logits, labels, paddings) 291 | return tot_loss, tot_ncorrect 292 | 293 | 294 | def fetch_vis_info(self, ids, labels, paddings): 295 | lengths = torch.sum(1-paddings, dim=1).long() 296 | atten = None 297 | if isinstance(self.lm_layer, lm_layers.TransformerLM): 298 | logits, atten = self.lm_layer(ids, lengths, return_atten=True) 299 | elif isinstance(self.lm_layer, lm_layers.LSTM): 300 | # The LSTM language model exposes no attention weights, 301 | # so there is nothing to visualize for it. 302 | logits = self.lm_layer(ids, lengths) 303 | atten = None 304 | else: 305 | raise ValueError('Unknown lm layer') 306 | return atten 307 | 308 | 309 | def get_probs(self, ids, lengths, T=1.0): 310 | logits = self.get_logits(ids, lengths) 311 | probs = F.softmax(logits/T, dim=-1) 312 | return probs 313 | 314 | 315 | def get_logprobs(self, ids, lengths, T=1.0): 316 | logits = self.get_logits(ids, lengths) 317 | logprobs = F.log_softmax(logits/T, dim=-1) 318 | return logprobs 319 | 320 | 321 | def get_logits(self, ids, lengths=None): 322 | if len(ids.shape) == 1: 323 | B = ids.shape[0] 324 | ids = ids.view(B, 1).contiguous() 325 | logits = self.lm_layer(ids, lengths) 326 | return logits 327 | 328 | 329 | def _compute_ce_loss(self, logits, labels, paddings): 330 | D = logits.size(-1) 331 | losses = F.cross_entropy(logits.view(-1, D).contiguous(), labels.view(-1), reduction='none') 332 | return losses * (1-paddings).view(-1).float() 333 | 334 | 335 | def _uniform_label_smooth(self, logits, paddings): 336 | log_probs = F.log_softmax(logits, dim=-1) 337 | nlabel = log_probs.shape[-1] 338 | ent_uniform = -torch.sum(log_probs, dim=-1)/nlabel 339 | return torch.sum(ent_uniform*(1-paddings).float()) 340 | 341 | 342 | def _compute_ncorrect(self, logits, labels, paddings): 343 | D = logits.size(-1) 344 | logprobs = F.log_softmax(logits, dim=-1) 345 | pred = torch.argmax(logprobs.view(-1, D), dim=-1) 346 | n_correct = torch.sum((pred == labels.view(-1)).float() * (1-paddings).view(-1).float()) 347 | return n_correct 348 | 349 | 350 | def package(self): 351 | pkg = { 352 | "lm_config": self.lm_layer.config, 353 | "lm_state": self.lm_layer.state_dict(), 354 | } 355 | return pkg 356 | 357 | 358 | def restore(self, pkg): 359 | # check config 360 | logging.info("Restore model states...") 361 | for key in self.lm_layer.config.keys(): 362 | if (key != "dropout_rate" and 363 | self.lm_layer.config[key] != pkg["lm_config"][key]): 364 |
raise ValueError("lm_config mismatch.") 365 | 366 | self.lm_layer.load_state_dict(pkg["lm_state"]) 367 | 368 | 369 | def _reset_parameters(self): 370 | self.lm_layer.reset_parameters() 371 | -------------------------------------------------------------------------------- /src/modules.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | 5 | class PositionalEncoding(nn.Module): 6 | """Implement the positional encoding (PE) function. 7 | 8 | PE(pos, 2i) = sin(pos/(10000^(2i/dmodel))) 9 | PE(pos, 2i+1) = cos(pos/(10000^(2i/dmodel))) 10 | """ 11 | 12 | def __init__(self, d_model, max_len=5000): 13 | super(PositionalEncoding, self).__init__() 14 | # Compute the positional encodings once in log space. 15 | self.scale = d_model**0.5 16 | pe = torch.zeros(max_len, d_model, requires_grad=False) 17 | position = torch.arange(0, max_len).unsqueeze(1).float() 18 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * 19 | -(math.log(10000.0) / d_model)) 20 | pe[:, 0::2] = torch.sin(position * div_term) 21 | pe[:, 1::2] = torch.cos(position * div_term) 22 | pe = pe.unsqueeze(0) 23 | self.register_buffer('pe', pe) 24 | 25 | def forward(self, input): 26 | """ 27 | Args: 28 | input: N x T x D 29 | """ 30 | length = input.size(1) 31 | return input*(self.scale)+self.pe[:, :length] 32 | 33 | 34 | -------------------------------------------------------------------------------- /src/prepare_data.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2020 Ye Bai by1993@qq.com 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | import logging 17 | import argparse 18 | import json 19 | import os 20 | import utils 21 | 22 | logging.basicConfig( 23 | level=logging.INFO, 24 | format='%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s') 25 | 26 | 27 | def get_args(): 28 | parser = argparse.ArgumentParser(description=""" 29 | Usage: prepare_data.py """) 30 | parser.add_argument("data_dir", help="data directory") 31 | parser.add_argument("dest_path", help="path to dest") 32 | parser.add_argument("--tag", type=str, default="file", 33 | help="tag of path. It should be file, pipe, or ark.") 34 | parser.add_argument("--maxdur", type=float, default=9e9, 35 | help="if the duration is longer than maxdur, drop it.") 36 | args = parser.parse_args() 37 | return args 38 | 39 | 40 | def get_dur(wav_dic): 41 | durdic = {} 42 | for key, path in wav_dic.items(): 43 | sample_rate, data = utils.load_wave(path) 44 | dur = data.shape[0]/float(sample_rate) 45 | durdic[key] = dur 46 | return durdic 47 | 48 | 49 | if __name__ == "__main__": 50 | args = get_args() 51 | datadir = args.data_dir 52 | fw = args.dest_path 53 | logging.info("Preparing data for {}...".format(datadir)) 54 | if os.path.exists(os.path.join(datadir, "wav.scp")): 55 | logging.info("wav.scp exists. 
Use it.") 56 | wav_dic = utils.parse_scp(os.path.join(datadir, "wav.scp")) 57 | elif os.path.exists(os.path.join(datadir, "feats.scp")): 58 | logging.info("wav.scp does not exists. Use feats.scp.") 59 | wav_dic = utils.parse_scp(os.path.join(datadir, "feats.scp")) 60 | else: 61 | raise ValueError("No speech scp.") 62 | trans_dic = utils.parse_scp(os.path.join(datadir, "text")) 63 | utts = wav_dic.keys() 64 | for utt in utts: 65 | wav_dic[utt] = "{}:{}".format(args.tag, wav_dic[utt]) 66 | if os.path.exists(os.path.join(datadir, "utt2dur")): 67 | dur_dic = utils.parse_scp(os.path.join(datadir, "utt2dur")) 68 | else: 69 | logging.info("No utt2dur file, generate it.") 70 | dur_dic = get_dur(wav_dic) 71 | 72 | n_tot = 0 73 | n_success = 0 74 | n_durskip = 0 75 | towrite = [] 76 | for utt in utts: 77 | n_tot += 1 78 | if utt not in trans_dic: 79 | logging.warn("No trans for {}, skip it.".format(utt)) 80 | continue 81 | elif utt not in dur_dic: 82 | logging.warn("No dur for {}, skip it.".format(utt)) 83 | continue 84 | 85 | if float(dur_dic[utt]) > args.maxdur: 86 | logging.warn("{} is longer than {}s, skip it.".format(utt, dur_dic[utt])) 87 | n_durskip += 1 88 | continue 89 | else: 90 | towrite.append({ 91 | "utt": utt, 92 | "path": wav_dic[utt], 93 | "transcript": trans_dic[utt], 94 | "duration": float(dur_dic[utt]), 95 | }) 96 | n_success += 1 97 | with open(fw, 'w', encoding="utf8") as f: 98 | json.dump(towrite, f, ensure_ascii=False, indent=2) 99 | logging.info("\nProcessed {} utterances successfully. " 100 | "The total number is {}. ({:2%}) {} utterances are too long.".format(n_success, n_tot, 1.*n_success/n_tot, n_durskip)) 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | -------------------------------------------------------------------------------- /src/schedule.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2020 Ye Bai by1993@qq.com 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | import logging 17 | import torch 18 | import math 19 | 20 | 21 | def get_scheduler(config): 22 | if config["type"] == "linear": 23 | return LinearLearningRateSchedule(config) 24 | elif config["type"] == "warmup_linear": 25 | return WarmupLinearLearningRateSchedule(config) 26 | elif config["type"] == "bob": 27 | return BobLearningRateSchedule(config) 28 | elif config["type"] == "warmup_transformer": 29 | return WarmupTransformerLearningRateSchedule(config) 30 | else: 31 | raise ValueError("Unknown scheduler.") 32 | 33 | 34 | class BaseLearningRateSchedule(object): 35 | def __init__(self): 36 | self.step_num = 0 37 | self.decay_rate = 1. 
38 |         self.config = None
39 |         self.misc_state = -1
40 |         self.update_only_with_step = True
41 | 
42 |     def set_lr(self, optimizer, init_lr):
43 |         for param_group in optimizer.param_groups:
44 |             param_group['lr'] = init_lr * self.decay_rate
45 | 
46 |     def step(self):
47 |         self.step_num += 1
48 |         if self.update_only_with_step:
49 |             self.update_decay_rate()
50 | 
51 |     def pack_state(self):
52 |         pkg = {
53 |             "step": self.step_num,
54 |             "decay_rate": self.decay_rate,
55 |             "misc_state": self.misc_state
56 |         }
57 |         return pkg
58 | 
59 |     def restore_state(self, pkg):
60 |         self.step_num = pkg['step']
61 |         self.decay_rate = pkg['decay_rate']
62 |         self.misc_state = pkg['misc_state']
63 |         self.check_misc_state()
64 | 
65 |     def check_misc_state(self):
66 |         raise NotImplementedError()
67 | 
68 |     def update_decay_rate(self):
69 |         raise NotImplementedError()
70 | 
71 | 
72 | def compute_polynomial_intep(x, x0, y0, x1, y1, power):
73 |     if x < x0:
74 |         return y0
75 |     elif x > x1:
76 |         return y1
77 |     else:
78 |         if power != 1.0:
79 |             f = ((1.0 * x - x0) / (x1 - x0)) ** power
80 |         else:
81 |             f = ((1.0 * x - x0) / (x1 - x0))
82 |         y = y0 + f * (y1 - y0)
83 |         return y
84 | 
85 | 
86 | def compute_linear_intep(x, x0, y0, x1, y1):
87 |     return compute_polynomial_intep(x, x0, y0, x1, y1, 1.0)
88 | 
89 | 
90 | class LinearLearningRateSchedule(BaseLearningRateSchedule):
91 |     def __init__(self, conf):
92 |         super(LinearLearningRateSchedule, self).__init__()
93 |         self.config = {
94 |             "x0": conf["x0"],
95 |             "y0": conf["y0"],
96 |             "x1": conf["x1"],
97 |             "y1": conf["y1"],
98 |         }
99 | 
100 |     def check_misc_state(self):
101 |         pass
102 | 
103 |     def update_decay_rate(self):
104 |         self.decay_rate = compute_linear_intep(self.step_num, self.config["x0"],
105 |             self.config["y0"], self.config["x1"], self.config["y1"])
106 | 
107 | 
108 | class WarmupLinearLearningRateSchedule(LinearLearningRateSchedule):
109 |     def __init__(self, conf):
110 |         super(WarmupLinearLearningRateSchedule, self).__init__(conf)
111 |         self.config["warmup_step"] = conf["warmup_step"]
112 | 
113 |     def update_decay_rate(self):
114 |         dc0 = compute_linear_intep(self.step_num, 0,
115 |             0, self.config["warmup_step"], self.config["y0"])
116 |         dc1 = compute_linear_intep(self.step_num, self.config["x0"],
117 |             self.config["y0"], self.config["x1"], self.config["y1"])
118 |         self.decay_rate = min(dc0, dc1)
119 | 
120 | 
121 | class WarmupTransformerLearningRateSchedule(BaseLearningRateSchedule):
122 |     def __init__(self, conf):
123 |         super(WarmupTransformerLearningRateSchedule, self).__init__()
124 |         self.config = {}
125 |         self.config["warmup_step"] = conf["warmup_step"]
126 |         self.config["d_model"] = conf["d_model"]
127 | 
128 |     def update_decay_rate(self):
129 |         d0 = self.step_num**(-0.5)
130 |         d1 = self.step_num*(self.config["warmup_step"]**(-1.5))
131 |         self.decay_rate = (self.config["d_model"]**(-0.5))*min(d0, d1)
132 | 
133 |     def check_misc_state(self):
134 |         pass
135 | 
136 | 
137 | class BobLearningRateSchedule(BaseLearningRateSchedule):
138 |     def __init__(self, conf):
139 |         super(BobLearningRateSchedule, self).__init__()
140 |         self.update_only_with_step = False
141 |         self.config = {
142 |             "decay_coef": conf["decay_coef"],
143 |             "tolerate": conf["tolerate"],
144 |         }
145 |         self.misc_state = {
146 |             "last_loss": -1,
147 |             "last_decay_rate": 1,
148 |         }
149 | 
150 |     def update_decay_rate(self, this_loss):
151 |         improvement = (self.misc_state["last_loss"] - this_loss)/self.misc_state["last_loss"]
152 |         if improvement < self.config["tolerate"]:
153 |             logging.info(("Improvement {:.4f} is below the tolerance {:.4f},"
154 |                 " decaying LR.").format(improvement, self.config["tolerate"]))
155 |             new_decay_rate = self.misc_state["last_decay_rate"] * self.config["decay_coef"]
156 |             self.decay_rate = new_decay_rate
157 |             self.misc_state["last_decay_rate"] = new_decay_rate
158 |         self.misc_state["last_loss"] = this_loss
159 | 
160 |     def check_misc_state(self):
161 |         if ("last_loss" not in self.misc_state or
162 |                 "last_decay_rate" not in self.misc_state):
163 |             raise ValueError("The misc states do not match. Maybe the checkpoint was not trained with the Bob LR schedule.")
--------------------------------------------------------------------------------
/src/sp_layers.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2020 Ye Bai by1993@qq.com
3 | 
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 | 
8 |     http://www.apache.org/licenses/LICENSE-2.0
9 | 
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 | 
17 | import logging
18 | import numpy as np
19 | import torch
20 | import torch.nn as nn
21 | import torch.nn.functional as F
22 | 
23 | from third_party import kaldi_signal as ksp
24 | import utils
25 | 
26 | class SPLayer(nn.Module):
27 | 
28 |     def __init__(self, config):
29 |         super(SPLayer, self).__init__()
30 |         self.config = config
31 |         self.feature_type = config["feature_type"]
32 |         self.sample_rate = float(config["sample_rate"])
33 |         self.num_mel_bins = int(config["num_mel_bins"])
34 |         self.use_energy = config["use_energy"]
35 |         self.spec_aug_conf = None
36 |         if "spec_aug" in config:
37 |             self.spec_aug_conf = {
38 |                 "freq_mask_num": config["spec_aug"]["freq_mask_num"],
39 |                 "freq_mask_width": config["spec_aug"]["freq_mask_width"],
40 |                 "time_mask_num": config["spec_aug"]["time_mask_num"],
41 |                 "time_mask_width": config["spec_aug"]["time_mask_width"],
42 |             }
43 | 
44 |         if self.feature_type == "mfcc":
45 |             self.num_ceps = config["num_ceps"]
46 |         else:
47 |             self.num_ceps = None
48 |         if self.feature_type == "offline":
49 |             feature_func = None
50 |             logging.warning("Using offline features. It is your responsibility to keep them consistent.")
51 |         elif self.feature_type == "fbank":
52 |             def feature_func(waveform):
53 |                 return ksp.fbank(
54 |                     waveform,
55 |                     sample_frequency=self.sample_rate,
56 |                     use_energy=self.use_energy,
57 |                     num_mel_bins=self.num_mel_bins
58 |                 )
59 |         elif self.feature_type == "mfcc":
60 |             def feature_func(waveform):
61 |                 return ksp.mfcc(
62 |                     waveform,
63 |                     sample_frequency=self.sample_rate,
64 |                     use_energy=self.use_energy,
65 |                     num_mel_bins=self.num_mel_bins
66 |                 )
67 |         else:
68 |             raise ValueError("Unknown feature type.")
69 |         self.func = feature_func
70 | 
71 |     def spec_aug(self, padded_features, feature_lengths):
72 |         freq_means = torch.mean(padded_features, dim=-1)
73 |         time_means = (torch.sum(padded_features, dim=1)
74 |             /feature_lengths[:, None].float())  # Note that features are padded with zeros.
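        # SpecAugment-style masking (cf. Park et al., 2019): mask widths are
        # drawn uniformly per utterance below, and the masked frequency bands
        # and time steps are overwritten in place with the utterance means
        # computed above, rather than with zeros, which keeps the feature
        # statistics closer to those of the unmasked input.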
75 | 76 | B, T, V = padded_features.shape 77 | # mask freq 78 | for _ in range(self.spec_aug_conf["freq_mask_num"]): 79 | fs = (self.spec_aug_conf["freq_mask_width"]*torch.rand(size=[B], 80 | device=padded_features.device, requires_grad=False)).long() 81 | f0s = ((V-fs).float()*torch.rand(size=[B], 82 | device=padded_features.device, requires_grad=False)).long() 83 | for b in range(B): 84 | padded_features[b, :, f0s[b]:f0s[b]+fs[b]] = freq_means[b][:, None] 85 | 86 | # mask time 87 | for _ in range(self.spec_aug_conf["time_mask_num"]): 88 | ts = (self.spec_aug_conf["time_mask_width"]*torch.rand(size=[B], 89 | device=padded_features.device, requires_grad=False)).long() 90 | t0s = ((feature_lengths-ts).float()*torch.rand(size=[B], 91 | device=padded_features.device, requires_grad=False)).long() 92 | for b in range(B): 93 | padded_features[b, t0s[b]:t0s[b]+ts[b], :] = time_means[b][None, :] 94 | return padded_features, feature_lengths 95 | 96 | def forward(self, wav_batch, lengths): 97 | batch_size, batch_length = wav_batch.shape[0], wav_batch.shape[1] 98 | if self.func is not None: 99 | features = [] 100 | feature_lengths = [] 101 | for i in range(batch_size): 102 | feature = self.func(wav_batch[i, :lengths[i]].view(1, -1)) 103 | features.append(feature) 104 | feature_lengths.append(feature.shape[0]) 105 | 106 | # pad to max_length 107 | max_length = max(feature_lengths) 108 | padded_features = torch.zeros(batch_size, max_length, feature.shape[-1]).to(feature.device) 109 | for i in range(batch_size): 110 | l = feature_lengths[i] 111 | padded_features[i, :l, :] += features[i] 112 | else: 113 | padded_features = torch.tensor(wav_batch) 114 | feature_lengths = lengths 115 | 116 | feature_lengths = torch.tensor(feature_lengths).long().to(padded_features.device) 117 | 118 | if self.training and self.spec_aug_conf is not None: 119 | padded_features, feature_lengths = self.spec_aug(padded_features, feature_lengths) 120 | 121 | return padded_features, feature_lengths 122 | -------------------------------------------------------------------------------- /src/sp_layers_test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | import utils 7 | from sp_layers import SPLayer 8 | 9 | def fbank_test(): 10 | conf = { 11 | "feature_type": "fbank", 12 | "sample_rate": 16000, 13 | "num_mel_bins": 40, 14 | "use_energy": False 15 | } 16 | fn = "file:testdata/100-121669-0000.wav" 17 | pipe = "pipe:flac -c -d -s testdata/103-1240-0005.flac |" 18 | sample_rate, waveform1 = utils.load_wave(fn) 19 | sample_rate, waveform2 = utils.load_wave(pipe) 20 | waveform1 = torch.from_numpy(waveform1) 21 | waveform2 = torch.from_numpy(waveform2) 22 | lengths = [waveform1.shape[0], waveform2.shape[0]] 23 | max_length = max(lengths) 24 | padded_waveforms = torch.zeros(2, max_length) 25 | padded_waveforms[0, :lengths[0]] += waveform1 26 | padded_waveforms[1, :lengths[1]] += waveform2 27 | layer = SPLayer(conf) 28 | 29 | features, feature_lengths = layer(padded_waveforms, lengths) 30 | print(features) 31 | print(feature_lengths) 32 | 33 | def specaug_fbank_test(): 34 | conf = { 35 | "feature_type": "fbank", 36 | "sample_rate": 16000, 37 | "num_mel_bins": 80, 38 | "use_energy": False, 39 | "spec_aug": { 40 | "freq_mask_num": 2, 41 | "freq_mask_width": 27, 42 | "time_mask_num": 2, 43 | "time_mask_width": 100, 44 | } 45 | } 46 | fn = "file:testdata/100-121669-0000.wav" 47 | pipe = "pipe:flac -c 
-d -s testdata/103-1240-0005.flac |" 48 | sample_rate, waveform1 = utils.load_wave(fn) 49 | sample_rate, waveform2 = utils.load_wave(pipe) 50 | waveform1 = torch.from_numpy(waveform1) 51 | waveform2 = torch.from_numpy(waveform2) 52 | lengths = [waveform1.shape[0], waveform2.shape[0]] 53 | max_length = max(lengths) 54 | print(lengths) 55 | padded_waveforms = torch.zeros(2, max_length) 56 | padded_waveforms[0, :lengths[0]] += waveform1 57 | padded_waveforms[1, :lengths[1]] += waveform2 58 | layer = SPLayer(conf) 59 | 60 | features, feature_lengths = layer(padded_waveforms, lengths) 61 | 62 | import matplotlib as mpl 63 | mpl.use('Agg') 64 | import matplotlib.pyplot as plt 65 | plt.imshow(features[1].numpy()) 66 | plt.savefig("test.png") 67 | 68 | #print(features) 69 | #print(feature_lengths) 70 | 71 | 72 | def specaug_test(): 73 | featconf = { 74 | "feature_type": "fbank", 75 | "sample_rate": 16000, 76 | "num_mel_bins": 40, 77 | "use_energy": False 78 | } 79 | augconf = { 80 | "feature_type": "fbank", 81 | "sample_rate": 16000, 82 | "num_mel_bins": 40, 83 | "use_energy": False, 84 | "spec_aug": { 85 | "freq_mask_width": 10, 86 | "freq_mask_num": 2, 87 | "time_mask_width": 100, 88 | "time_mask_num": 2} 89 | } 90 | fn = "file:testdata/100-121669-0000.wav" 91 | pipe = "pipe:flac -c -d -s testdata/103-1240-0005.flac |" 92 | sample_rate, waveform1 = utils.load_wave(fn) 93 | sample_rate, waveform2 = utils.load_wave(pipe) 94 | waveform1 = torch.from_numpy(waveform1) 95 | waveform2 = torch.from_numpy(waveform2) 96 | lengths = [waveform1.shape[0], waveform2.shape[0]] 97 | max_length = max(lengths) 98 | padded_waveforms = torch.zeros(2, max_length) 99 | padded_waveforms[0, :lengths[0]] += waveform1 100 | padded_waveforms[1, :lengths[1]] += waveform2 101 | splayer = SPLayer(featconf) 102 | auglayer = SPLayer(augconf) 103 | features, feature_lengths = splayer(padded_waveforms, lengths) 104 | features2, feature_lengths2 = auglayer(padded_waveforms, lengths) 105 | print("Before augmentation") 106 | print(features) 107 | print("After augmentation") 108 | print(features2) 109 | 110 | if __name__ == "__main__": 111 | fbank_test() 112 | specaug_test() 113 | specaug_fbank_test() 114 | 115 | 116 | 117 | 118 | -------------------------------------------------------------------------------- /src/stat_grapheme.py: -------------------------------------------------------------------------------- 1 | from collections import Counter 2 | import argparse 3 | 4 | 5 | def get_args(): 6 | parser = argparse.ArgumentParser(description=""" 7 | Usage: stat_grapheme.py """) 8 | parser.add_argument("text", help="path to text.") 9 | parser.add_argument("vocab", help="path to store vocab.") 10 | parser.add_argument("--vocab-size", type=int, default=100000, help="vocabulary size.") 11 | args = parser.parse_args() 12 | return args 13 | 14 | if __name__ == "__main__": 15 | args = get_args() 16 | fn = args.text 17 | fnw = args.vocab 18 | vocabsize = args.vocab_size 19 | txt = "" 20 | with open(fn, 'r', encoding="utf8") as f: 21 | for line in f: 22 | items = line.strip().split(' ', 1) 23 | if len(items) == 1: 24 | continue 25 | txt += items[1] 26 | 27 | txtlist = list(txt) 28 | 29 | cnter = Counter(txtlist) 30 | 31 | most = cnter.most_common(None) 32 | 33 | with open(fnw, 'w', encoding="utf8") as f: 34 | t = [m[0] for m in most] 35 | f.write("\n".join(t[:vocabsize])) 36 | 37 | 38 | 39 | -------------------------------------------------------------------------------- /src/stat_length.py: 
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2020 Ye Bai by1993@qq.com
3 | 
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 | 
8 |     http://www.apache.org/licenses/LICENSE-2.0
9 | 
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 | import sys
17 | import os
18 | import argparse
19 | import logging
20 | import yaml
21 | import numpy as np
22 | import torch
23 | 
24 | if "LAS_LOG_LEVEL" in os.environ:
25 |     LOG_LEVEL = os.environ["LAS_LOG_LEVEL"]
26 | else:
27 |     LOG_LEVEL = "INFO"
28 | if LOG_LEVEL == "DEBUG":
29 |     logging.basicConfig(
30 |         level=logging.DEBUG,
31 |         format='%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s')
32 | else:
33 |     logging.basicConfig(
34 |         level=logging.INFO,
35 |         format='%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s')
36 | 
37 | import utils
38 | import data
39 | 
40 | 
41 | def get_args():
42 |     parser = argparse.ArgumentParser(description="""
43 |      Usage: stat_length.py <data> <vocab_file>""")
44 |     parser.add_argument("data", help="path to data.")
45 |     parser.add_argument("vocab_file", help="path to vocabulary file.")
46 |     args = parser.parse_args()
47 |     return args
48 | 
49 | 
50 | if __name__ == "__main__":
51 |     args = get_args()
52 |     vocab_path = args.vocab_file
53 |     data_path = args.data
54 |     training_set = data.SpeechDataset(data_path)
55 |     if vocab_path.endswith(".model"):
56 |         tokenizer = data.WpmTokenizer(vocab_path)
57 |     else:
58 |         tokenizer = data.CharTokenizer(vocab_path)
59 | 
60 |     durs = []
61 |     id_lengths = []
62 |     for d in iter(training_set):
63 |         durs.append(d["duration"])
64 |         ids = tokenizer.encode(d["transcript"])
65 |         #print(ids)
66 |         id_lengths.append(len(ids))
67 |     durs = np.array(durs)
68 |     id_lengths = np.array(id_lengths).astype(float)
69 |     dur_percentile = np.percentile(durs, [10, 50, 90])
70 |     dur_max = np.max(durs)
71 |     dur_min = np.min(durs)
72 |     dur_mean = np.mean(durs)
73 |     msg = ("duration statistics:\n" +
74 |            "max: {:.4f}s | min {:.4f}s | mean {:.4f}s\n".format(dur_max, dur_min, dur_mean) +
75 |            "percentile at (10, 50, 90): {}s {}s {}s\n".format(dur_percentile[0], dur_percentile[1], dur_percentile[2]))
76 | 
77 |     id_len_percentile = np.percentile(id_lengths, [10, 50, 90])
78 |     id_len_max = np.max(id_lengths)
79 |     id_len_min = np.min(id_lengths)
80 |     id_len_mean = np.mean(id_lengths)
81 |     msg += ("ids length statistics:\n" +
82 |             "max: {:.4f} | min {:.4f} | mean {:.4f}\n".format(id_len_max, id_len_min, id_len_mean) +
83 |             "percentile at (10, 50, 90): {} {} {}\n".format(id_len_percentile[0], id_len_percentile[1], id_len_percentile[2]))
84 |     logging.info("\n"+msg)
85 | 
--------------------------------------------------------------------------------
/src/testdata/100-121669-0000.flac:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/by2101/OpenASR/c5213d68304a270a0448b2d53adc72b57f4efdb3/src/testdata/100-121669-0000.flac
--------------------------------------------------------------------------------
/src/testdata/100-121669-0000.wav:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/by2101/OpenASR/c5213d68304a270a0448b2d53adc72b57f4efdb3/src/testdata/100-121669-0000.wav -------------------------------------------------------------------------------- /src/testdata/103-1240-0005.flac: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/by2101/OpenASR/c5213d68304a270a0448b2d53adc72b57f4efdb3/src/testdata/103-1240-0005.flac -------------------------------------------------------------------------------- /src/testdata/BAC009S0764W0121.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/by2101/OpenASR/c5213d68304a270a0448b2d53adc72b57f4efdb3/src/testdata/BAC009S0764W0121.wav -------------------------------------------------------------------------------- /src/testdata/test.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "utt": "BAC009S0764W0121", 4 | "path": "file:testdata/BAC009S0764W0121.wav", 5 | "transcript": "甚至出现交易几乎停滞的情况", 6 | "duration": 4.203938 7 | } 8 | ] -------------------------------------------------------------------------------- /src/testdata/tokens.txt: -------------------------------------------------------------------------------- 1 | 0 2 | 1 3 | SIL 2 4 | SPN 3 5 | AA 4 6 | AA0 5 7 | AA1 6 8 | AA2 7 9 | AE 8 10 | AE0 9 11 | AE1 10 12 | AE2 11 13 | AH 12 14 | AH0 13 15 | AH1 14 16 | AH2 15 17 | AO 16 18 | AO0 17 19 | AO1 18 20 | AO2 19 21 | AW 20 22 | AW0 21 23 | AW1 22 24 | AW2 23 25 | AY 24 26 | AY0 25 27 | AY1 26 28 | AY2 27 29 | B 28 30 | CH 29 31 | D 30 32 | DH 31 33 | EH 32 34 | EH0 33 35 | EH1 34 36 | EH2 35 37 | ER 36 38 | ER0 37 39 | ER1 38 40 | ER2 39 41 | EY 40 42 | EY0 41 43 | EY1 42 44 | EY2 43 45 | F 44 46 | G 45 47 | HH 46 48 | IH 47 49 | IH0 48 50 | IH1 49 51 | IH2 50 52 | IY 51 53 | IY0 52 54 | IY1 53 55 | IY2 54 56 | JH 55 57 | K 56 58 | L 57 59 | M 58 60 | N 59 61 | NG 60 62 | OW 61 63 | OW0 62 64 | OW1 63 65 | OW2 64 66 | OY 65 67 | OY0 66 68 | OY1 67 69 | OY2 68 70 | P 69 71 | R 70 72 | S 71 73 | SH 72 74 | T 73 75 | TH 74 76 | UH 75 77 | UH0 76 78 | UH1 77 79 | UH2 78 80 | UW 79 81 | UW0 80 82 | UW1 81 83 | UW2 82 84 | V 83 85 | W 84 86 | Y 85 87 | Z 86 88 | ZH 87 89 | #0 88 90 | #1 89 91 | #2 90 92 | #3 91 93 | #4 92 94 | #5 93 95 | #6 94 96 | #7 95 97 | #8 96 98 | #9 97 99 | #10 98 100 | #11 99 101 | #12 100 102 | #13 101 103 | #14 102 104 | #15 103 105 | #16 104 106 | -------------------------------------------------------------------------------- /src/third_party/transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import copy 3 | import torch.nn.functional as F 4 | from torch.nn.modules import Module 5 | from torch.nn.modules.activation import MultiheadAttention 6 | from torch.nn.modules.container import ModuleList 7 | from torch.nn.init import xavier_uniform_ 8 | from torch.nn.modules.dropout import Dropout 9 | from torch.nn.modules.linear import Linear 10 | from torch.nn.modules.normalization import LayerNorm 11 | 12 | class Transformer(Module): 13 | r"""A transformer model. User is able to modify the attributes as needed. The architecture 14 | is based on the paper "Attention Is All You Need". Ashish Vaswani, Noam Shazeer, 15 | Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Lukasz Kaiser, and 16 | Illia Polosukhin. 2017. Attention is all you need. 
In Advances in Neural Information 17 | Processing Systems, pages 6000-6010. Users can build the BERT(https://arxiv.org/abs/1810.04805) 18 | model with corresponding parameters. 19 | 20 | Args: 21 | d_model: the number of expected features in the encoder/decoder inputs (default=512). 22 | nhead: the number of heads in the multiheadattention models (default=8). 23 | num_encoder_layers: the number of sub-encoder-layers in the encoder (default=6). 24 | num_decoder_layers: the number of sub-decoder-layers in the decoder (default=6). 25 | dim_feedforward: the dimension of the feedforward network model (default=2048). 26 | dropout: the dropout value (default=0.1). 27 | activation: the activation function of encoder/decoder intermediate layer, relu or gelu (default=relu). 28 | custom_encoder: custom encoder (default=None). 29 | custom_decoder: custom decoder (default=None). 30 | 31 | Examples:: 32 | >>> transformer_model = nn.Transformer(nhead=16, num_encoder_layers=12) 33 | >>> src = torch.rand((10, 32, 512)) 34 | >>> tgt = torch.rand((20, 32, 512)) 35 | >>> out = transformer_model(src, tgt) 36 | 37 | Note: A full example to apply nn.Transformer module for the word language model is available in 38 | https://github.com/pytorch/examples/tree/master/word_language_model 39 | """ 40 | 41 | def __init__(self, d_model=512, nhead=8, num_encoder_layers=6, 42 | num_decoder_layers=6, dim_feedforward=2048, dropout=0.1, 43 | activation="relu", custom_encoder=None, custom_decoder=None): 44 | super(Transformer, self).__init__() 45 | 46 | if custom_encoder is not None: 47 | self.encoder = custom_encoder 48 | else: 49 | encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout, activation) 50 | encoder_norm = LayerNorm(d_model) 51 | self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm) 52 | 53 | if custom_decoder is not None: 54 | self.decoder = custom_decoder 55 | else: 56 | decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout, activation) 57 | decoder_norm = LayerNorm(d_model) 58 | self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm) 59 | 60 | self._reset_parameters() 61 | 62 | self.d_model = d_model 63 | self.nhead = nhead 64 | 65 | def forward(self, src, tgt, src_mask=None, tgt_mask=None, 66 | memory_mask=None, src_key_padding_mask=None, 67 | tgt_key_padding_mask=None, memory_key_padding_mask=None): 68 | r"""Take in and process masked source/target sequences. 69 | 70 | Args: 71 | src: the sequence to the encoder (required). 72 | tgt: the sequence to the decoder (required). 73 | src_mask: the additive mask for the src sequence (optional). 74 | tgt_mask: the additive mask for the tgt sequence (optional). 75 | memory_mask: the additive mask for the encoder output (optional). 76 | src_key_padding_mask: the ByteTensor mask for src keys per batch (optional). 77 | tgt_key_padding_mask: the ByteTensor mask for tgt keys per batch (optional). 78 | memory_key_padding_mask: the ByteTensor mask for memory keys per batch (optional). 79 | 80 | Shape: 81 | - src: :math:`(S, N, E)`. 82 | - tgt: :math:`(T, N, E)`. 83 | - src_mask: :math:`(S, S)`. 84 | - tgt_mask: :math:`(T, T)`. 85 | - memory_mask: :math:`(T, S)`. 86 | - src_key_padding_mask: :math:`(N, S)`. 87 | - tgt_key_padding_mask: :math:`(N, T)`. 88 | - memory_key_padding_mask: :math:`(N, S)`. 89 | 90 | Note: [src/tgt/memory]_mask should be filled with 91 | float('-inf') for the masked positions and float(0.0) else. 
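            (As a sketch: for sz=3, generate_square_subsequent_mask returns
            [[0., -inf, -inf], [0., 0., -inf], [0., 0., 0.]], i.e. each
            position may attend only to itself and earlier positions.)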
These masks 92 | ensure that predictions for position i depend only on the unmasked positions 93 | j and are applied identically for each sequence in a batch. 94 | [src/tgt/memory]_key_padding_mask should be a ByteTensor where True values are positions 95 | that should be masked with float('-inf') and False values will be unchanged. 96 | This mask ensures that no information will be taken from position i if 97 | it is masked, and has a separate mask for each sequence in a batch. 98 | 99 | - output: :math:`(T, N, E)`. 100 | 101 | Note: Due to the multi-head attention architecture in the transformer model, 102 | the output sequence length of a transformer is same as the input sequence 103 | (i.e. target) length of the decode. 104 | 105 | where S is the source sequence length, T is the target sequence length, N is the 106 | batch size, E is the feature number 107 | 108 | Examples: 109 | >>> output = transformer_model(src, tgt, src_mask=src_mask, tgt_mask=tgt_mask) 110 | """ 111 | 112 | if src.size(1) != tgt.size(1): 113 | raise RuntimeError("the batch number of src and tgt must be equal") 114 | 115 | if src.size(2) != self.d_model or tgt.size(2) != self.d_model: 116 | raise RuntimeError("the feature number of src and tgt must be equal to d_model") 117 | 118 | memory = self.encoder(src, mask=src_mask, src_key_padding_mask=src_key_padding_mask) 119 | output = self.decoder(tgt, memory, tgt_mask=tgt_mask, memory_mask=memory_mask, 120 | tgt_key_padding_mask=tgt_key_padding_mask, 121 | memory_key_padding_mask=memory_key_padding_mask) 122 | return output 123 | 124 | def generate_square_subsequent_mask(self, sz): 125 | r"""Generate a square mask for the sequence. The masked positions are filled with float('-inf'). 126 | Unmasked positions are filled with float(0.0). 127 | """ 128 | mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1) 129 | mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0)) 130 | return mask 131 | 132 | def _reset_parameters(self): 133 | r"""Initiate parameters in the transformer model.""" 134 | 135 | for p in self.parameters(): 136 | if p.dim() > 1: 137 | xavier_uniform_(p) 138 | 139 | 140 | class TransformerEncoder(Module): 141 | r"""TransformerEncoder is a stack of N encoder layers 142 | 143 | Args: 144 | encoder_layer: an instance of the TransformerEncoderLayer() class (required). 145 | num_layers: the number of sub-encoder-layers in the encoder (required). 146 | norm: the layer normalization component (optional). 147 | 148 | Examples:: 149 | >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8) 150 | >>> transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6) 151 | >>> src = torch.rand(10, 32, 512) 152 | >>> out = transformer_encoder(src) 153 | """ 154 | 155 | def __init__(self, encoder_layer, num_layers, norm=None): 156 | super(TransformerEncoder, self).__init__() 157 | self.layers = _get_clones(encoder_layer, num_layers) 158 | self.num_layers = num_layers 159 | self.norm = norm 160 | 161 | def forward(self, src, mask=None, src_key_padding_mask=None, return_atten=False): 162 | r"""Pass the input through the encoder layers in turn. 163 | 164 | Args: 165 | src: the sequnce to the encoder (required). 166 | mask: the mask for the src sequence (optional). 167 | src_key_padding_mask: the mask for the src keys per batch (optional). 168 | 169 | Shape: 170 | see the docs in Transformer class. 
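            Note (this fork's addition): with return_atten=True, the encoder
            additionally returns a list holding each layer's self-attention
            probabilities.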
171 | """ 172 | output = src 173 | atten_probs_list = [] 174 | for i in range(self.num_layers): 175 | if return_atten: 176 | output, self_atten_probs = self.layers[i](output, src_mask=mask, 177 | src_key_padding_mask=src_key_padding_mask, 178 | return_atten=True) 179 | atten_probs_list.append(self_atten_probs) 180 | else: 181 | output = self.layers[i](output, src_mask=mask, 182 | src_key_padding_mask=src_key_padding_mask, 183 | return_atten=False) 184 | 185 | if self.norm: 186 | output = self.norm(output) 187 | if return_atten: 188 | return output, atten_probs_list 189 | return output 190 | 191 | 192 | class TransformerDecoder(Module): 193 | r"""TransformerDecoder is a stack of N decoder layers 194 | 195 | Args: 196 | decoder_layer: an instance of the TransformerDecoderLayer() class (required). 197 | num_layers: the number of sub-decoder-layers in the decoder (required). 198 | norm: the layer normalization component (optional). 199 | 200 | Examples:: 201 | >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8) 202 | >>> transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=6) 203 | >>> memory = torch.rand(10, 32, 512) 204 | >>> tgt = torch.rand(20, 32, 512) 205 | >>> out = transformer_decoder(tgt, memory) 206 | """ 207 | 208 | def __init__(self, decoder_layer, num_layers, norm=None): 209 | super(TransformerDecoder, self).__init__() 210 | self.layers = _get_clones(decoder_layer, num_layers) 211 | self.num_layers = num_layers 212 | self.norm = norm 213 | 214 | def forward(self, tgt, memory, tgt_mask=None, 215 | memory_mask=None, tgt_key_padding_mask=None, 216 | memory_key_padding_mask=None, return_atten=False): 217 | r"""Pass the inputs (and mask) through the decoder layer in turn. 218 | 219 | Args: 220 | tgt: the sequence to the decoder (required). 221 | memory: the sequnce from the last layer of the encoder (required). 222 | tgt_mask: the mask for the tgt sequence (optional). 223 | memory_mask: the mask for the memory sequence (optional). 224 | tgt_key_padding_mask: the mask for the tgt keys per batch (optional). 225 | memory_key_padding_mask: the mask for the memory keys per batch (optional). 226 | 227 | Shape: 228 | see the docs in Transformer class. 229 | """ 230 | output = tgt 231 | atten_probs_list = [] 232 | for i in range(self.num_layers): 233 | if return_atten: 234 | output, atten_probs_tuple = self.layers[i](output, memory, tgt_mask=tgt_mask, 235 | memory_mask=memory_mask, 236 | tgt_key_padding_mask=tgt_key_padding_mask, 237 | memory_key_padding_mask=memory_key_padding_mask, 238 | return_atten=True) 239 | atten_probs_list.append(atten_probs_tuple) 240 | else: 241 | output = self.layers[i](output, memory, tgt_mask=tgt_mask, 242 | memory_mask=memory_mask, 243 | tgt_key_padding_mask=tgt_key_padding_mask, 244 | memory_key_padding_mask=memory_key_padding_mask, 245 | return_atten=False) 246 | 247 | if self.norm: 248 | output = self.norm(output) 249 | 250 | if return_atten: 251 | return output, atten_probs_list 252 | return output 253 | 254 | class TransformerEncoderLayer(Module): 255 | r"""TransformerEncoderLayer is made up of self-attn and feedforward network. 256 | This standard encoder layer is based on the paper "Attention Is All You Need". 257 | Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, 258 | Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in 259 | Neural Information Processing Systems, pages 6000-6010. 
Users may modify or implement 260 | in a different way during application. 261 | 262 | Args: 263 | d_model: the number of expected features in the input (required). 264 | nhead: the number of heads in the multiheadattention models (required). 265 | dim_feedforward: the dimension of the feedforward network model (default=2048). 266 | dropout: the dropout value (default=0.1). 267 | activation: the activation function of intermediate layer, relu or gelu (default=relu). 268 | 269 | Examples:: 270 | >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8) 271 | >>> src = torch.rand(10, 32, 512) 272 | >>> out = encoder_layer(src) 273 | """ 274 | 275 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu"): 276 | super(TransformerEncoderLayer, self).__init__() 277 | self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout) 278 | # Implementation of Feedforward model 279 | if activation == "glu": 280 | self.linear1 = Linear(d_model, 2*dim_feedforward) 281 | else: 282 | self.linear1 = Linear(d_model, dim_feedforward) 283 | self.dropout = Dropout(dropout) 284 | self.linear2 = Linear(dim_feedforward, d_model) 285 | 286 | self.norm1 = LayerNorm(d_model) 287 | self.norm2 = LayerNorm(d_model) 288 | self.dropout1 = Dropout(dropout) 289 | self.dropout2 = Dropout(dropout) 290 | 291 | self.activation = _get_activation_fn(activation) 292 | 293 | def forward(self, src, src_mask=None, src_key_padding_mask=None, return_atten=False): 294 | r"""Pass the input through the encoder layer. 295 | 296 | Args: 297 | src: the sequnce to the encoder layer (required). 298 | src_mask: the mask for the src sequence (optional). 299 | src_key_padding_mask: the mask for the src keys per batch (optional). 300 | 301 | Shape: 302 | see the docs in Transformer class. 303 | """ 304 | src2, self_atten_probs = self.self_attn(src, src, src, attn_mask=src_mask, 305 | key_padding_mask=src_key_padding_mask) 306 | src = src + self.dropout1(src2) 307 | src = self.norm1(src) 308 | if hasattr(self, "activation"): 309 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) 310 | else: # for backward compatibility 311 | src2 = self.linear2(self.dropout(F.relu(self.linear1(src)))) 312 | src = src + self.dropout2(src2) 313 | src = self.norm2(src) 314 | if return_atten: 315 | return src, self_atten_probs 316 | return src 317 | 318 | 319 | class TransformerDecoderLayer(Module): 320 | r"""TransformerDecoderLayer is made up of self-attn, multi-head-attn and feedforward network. 321 | This standard decoder layer is based on the paper "Attention Is All You Need". 322 | Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, 323 | Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in 324 | Neural Information Processing Systems, pages 6000-6010. Users may modify or implement 325 | in a different way during application. 326 | 327 | Args: 328 | d_model: the number of expected features in the input (required). 329 | nhead: the number of heads in the multiheadattention models (required). 330 | dim_feedforward: the dimension of the feedforward network model (default=2048). 331 | dropout: the dropout value (default=0.1). 332 | activation: the activation function of intermediate layer, relu or gelu (default=relu). 
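                This fork also accepts "glu": linear1 then projects to
                2*dim_feedforward and F.glu gates the result back down to
                dim_feedforward before linear2.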
333 | 334 | Examples:: 335 | >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8) 336 | >>> memory = torch.rand(10, 32, 512) 337 | >>> tgt = torch.rand(20, 32, 512) 338 | >>> out = decoder_layer(tgt, memory) 339 | """ 340 | 341 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu"): 342 | super(TransformerDecoderLayer, self).__init__() 343 | self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout) 344 | self.multihead_attn = MultiheadAttention(d_model, nhead, dropout=dropout) 345 | # Implementation of Feedforward model 346 | if activation == "glu": 347 | self.linear1 = Linear(d_model, 2*dim_feedforward) 348 | else: 349 | self.linear1 = Linear(d_model, dim_feedforward) 350 | self.dropout = Dropout(dropout) 351 | self.linear2 = Linear(dim_feedforward, d_model) 352 | 353 | self.norm1 = LayerNorm(d_model) 354 | self.norm2 = LayerNorm(d_model) 355 | self.norm3 = LayerNorm(d_model) 356 | self.dropout1 = Dropout(dropout) 357 | self.dropout2 = Dropout(dropout) 358 | self.dropout3 = Dropout(dropout) 359 | 360 | self.activation = _get_activation_fn(activation) 361 | 362 | def forward(self, tgt, memory, tgt_mask=None, memory_mask=None, 363 | tgt_key_padding_mask=None, memory_key_padding_mask=None, return_atten=False): 364 | r"""Pass the inputs (and mask) through the decoder layer. 365 | 366 | Args: 367 | tgt: the sequence to the decoder layer (required). 368 | memory: the sequnce from the last layer of the encoder (required). 369 | tgt_mask: the mask for the tgt sequence (optional). 370 | memory_mask: the mask for the memory sequence (optional). 371 | tgt_key_padding_mask: the mask for the tgt keys per batch (optional). 372 | memory_key_padding_mask: the mask for the memory keys per batch (optional). 373 | 374 | Shape: 375 | see the docs in Transformer class. 376 | """ 377 | tgt2, self_atten_probs = self.self_attn(tgt, tgt, tgt, attn_mask=tgt_mask, 378 | key_padding_mask=tgt_key_padding_mask) 379 | tgt = tgt + self.dropout1(tgt2) 380 | tgt = self.norm1(tgt) 381 | tgt2, enc_dec_atten_probs = self.multihead_attn(tgt, memory, memory, attn_mask=memory_mask, 382 | key_padding_mask=memory_key_padding_mask) 383 | tgt = tgt + self.dropout2(tgt2) 384 | tgt = self.norm2(tgt) 385 | if hasattr(self, "activation"): 386 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) 387 | else: # for backward compatibility 388 | tgt2 = self.linear2(self.dropout(F.relu(self.linear1(tgt)))) 389 | tgt = tgt + self.dropout3(tgt2) 390 | tgt = self.norm3(tgt) 391 | if return_atten: 392 | return tgt, (self_atten_probs, enc_dec_atten_probs) 393 | return tgt 394 | 395 | 396 | def _get_clones(module, N): 397 | return ModuleList([copy.deepcopy(module) for i in range(N)]) 398 | 399 | 400 | def _get_activation_fn(activation): 401 | if activation == "relu": 402 | return F.relu 403 | elif activation == "gelu": 404 | return F.gelu 405 | elif activation == "glu": 406 | return F.glu 407 | else: 408 | raise RuntimeError("activation should be relu/gelu, not %s." % activation) 409 | -------------------------------------------------------------------------------- /src/third_party/wavfile.py: -------------------------------------------------------------------------------- 1 | """ 2 | Module to read / write wav files using numpy arrays 3 | 4 | Functions 5 | --------- 6 | `read`: Return the sample rate (in samples/sec) and data from a WAV file. 7 | 8 | `write`: Write a numpy array as a WAV file. 
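
Note: this looks like an adapted copy of scipy.io.wavfile; the key difference
is that `read` accepts an extra `offset` argument so a waveform can be pulled
out of the middle of a Kaldi ark file (see load_wave in utils.py).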
9 | 10 | """ 11 | from __future__ import division, print_function, absolute_import 12 | 13 | import sys 14 | import numpy 15 | import struct 16 | import warnings 17 | 18 | 19 | __all__ = [ 20 | 'WavFileWarning', 21 | 'read', 22 | 'write' 23 | ] 24 | 25 | 26 | class WavFileWarning(UserWarning): 27 | pass 28 | 29 | 30 | WAVE_FORMAT_PCM = 0x0001 31 | WAVE_FORMAT_IEEE_FLOAT = 0x0003 32 | WAVE_FORMAT_EXTENSIBLE = 0xfffe 33 | KNOWN_WAVE_FORMATS = (WAVE_FORMAT_PCM, WAVE_FORMAT_IEEE_FLOAT) 34 | 35 | # assumes file pointer is immediately 36 | # after the 'fmt ' id 37 | 38 | 39 | def _read_fmt_chunk(fid, is_big_endian): 40 | """ 41 | Returns 42 | ------- 43 | size : int 44 | size of format subchunk in bytes (minus 8 for "fmt " and itself) 45 | format_tag : int 46 | PCM, float, or compressed format 47 | channels : int 48 | number of channels 49 | fs : int 50 | sampling frequency in samples per second 51 | bytes_per_second : int 52 | overall byte rate for the file 53 | block_align : int 54 | bytes per sample, including all channels 55 | bit_depth : int 56 | bits per sample 57 | """ 58 | if is_big_endian: 59 | fmt = '>' 60 | else: 61 | fmt = '<' 62 | 63 | size = res = struct.unpack(fmt+'I', fid.read(4))[0] 64 | bytes_read = 0 65 | 66 | if size < 16: 67 | raise ValueError("Binary structure of wave file is not compliant") 68 | 69 | res = struct.unpack(fmt+'HHIIHH', fid.read(16)) 70 | bytes_read += 16 71 | 72 | format_tag, channels, fs, bytes_per_second, block_align, bit_depth = res 73 | 74 | if format_tag == WAVE_FORMAT_EXTENSIBLE and size >= (16+2): 75 | ext_chunk_size = struct.unpack(fmt+'H', fid.read(2))[0] 76 | bytes_read += 2 77 | if ext_chunk_size >= 22: 78 | extensible_chunk_data = fid.read(22) 79 | bytes_read += 22 80 | raw_guid = extensible_chunk_data[2+4:2+4+16] 81 | # GUID template {XXXXXXXX-0000-0010-8000-00AA00389B71} (RFC-2361) 82 | # MS GUID byte order: first three groups are native byte order, 83 | # rest is Big Endian 84 | if is_big_endian: 85 | tail = b'\x00\x00\x00\x10\x80\x00\x00\xAA\x00\x38\x9B\x71' 86 | else: 87 | tail = b'\x00\x00\x10\x00\x80\x00\x00\xAA\x00\x38\x9B\x71' 88 | if raw_guid.endswith(tail): 89 | format_tag = struct.unpack(fmt+'I', raw_guid[:4])[0] 90 | else: 91 | raise ValueError("Binary structure of wave file is not compliant") 92 | 93 | if format_tag not in KNOWN_WAVE_FORMATS: 94 | raise ValueError("Unknown wave file format") 95 | 96 | # move file pointer to next chunk 97 | if size > (bytes_read): 98 | fid.read(size - bytes_read) 99 | 100 | return (size, format_tag, channels, fs, bytes_per_second, block_align, 101 | bit_depth) 102 | 103 | 104 | # assumes file pointer is immediately after the 'data' id 105 | def _read_data_chunk(fid, format_tag, channels, bit_depth, is_big_endian, 106 | mmap=False): 107 | if is_big_endian: 108 | fmt = '>I' 109 | else: 110 | fmt = ' 1: 137 | data = data.reshape(-1, channels) 138 | return data 139 | 140 | 141 | def _skip_unknown_chunk(fid, is_big_endian): 142 | if is_big_endian: 143 | fmt = '>I' 144 | else: 145 | fmt = ' 0xFFFFFFFF: 374 | raise ValueError("Data exceeds wave file size limit") 375 | 376 | fid.write(header_data) 377 | 378 | # data chunk 379 | fid.write(b'data') 380 | fid.write(struct.pack('' or (data.dtype.byteorder == '=' and 382 | sys.byteorder == 'big'): 383 | data = data.byteswap() 384 | _array_tofile(fid, data) 385 | 386 | # Determine file size and place it in correct 387 | # position at start of the file. 
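    # (The RIFF container stores the overall chunk size in bytes 4-7 of the
    # header, so it can only be patched in after all data has been written;
    # the value written is the file size minus the 8-byte "RIFF"+size prefix.)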
388 | size = fid.tell() 389 | fid.seek(4) 390 | fid.write(struct.pack('= 3: 400 | def _array_tofile(fid, data): 401 | # ravel gives a c-contiguous buffer 402 | fid.write(data.ravel().view('b').data) 403 | else: 404 | def _array_tofile(fid, data): 405 | fid.write(data.tostring()) 406 | -------------------------------------------------------------------------------- /src/train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2020 Ye Bai by1993@qq.com 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | import sys 17 | import os 18 | import argparse 19 | import logging 20 | import yaml 21 | import torch 22 | 23 | if "LAS_LOG_LEVEL" in os.environ: 24 | LOG_LEVEL = os.environ["LAS_LOG_LEVEL"] 25 | else: 26 | LOG_LEVEL = "INFO" 27 | if LOG_LEVEL == "DEBUG": 28 | logging.basicConfig( 29 | level=logging.DEBUG, 30 | format='%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s') 31 | else: 32 | logging.basicConfig( 33 | level=logging.INFO, 34 | format='%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s') 35 | 36 | import utils 37 | import data 38 | import sp_layers 39 | import encoder_layers 40 | import decoder_layers 41 | import lm_layers 42 | import models 43 | 44 | from trainer import Trainer 45 | 46 | 47 | def get_args(): 48 | parser = argparse.ArgumentParser(description=""" 49 | Usage: train.py """) 50 | parser.add_argument("config", help="path to config file") 51 | parser.add_argument('--continue-training', type=utils.str2bool, default=False, 52 | help='Continue training from last_model.pt.') 53 | args = parser.parse_args() 54 | return args 55 | 56 | 57 | if __name__ == "__main__": 58 | timer = utils.Timer() 59 | x = torch.zeros(2) 60 | x.cuda() # for initialize gpu 61 | 62 | args = get_args() 63 | timer.tic() 64 | with open(args.config) as f: 65 | config = yaml.load(f, Loader=yaml.FullLoader) 66 | dataconfig = config["data"] 67 | trainingconfig = config["training"] 68 | modelconfig = config["model"] 69 | 70 | training_set = data.SpeechDataset(dataconfig["trainset"]) 71 | valid_set = data.SpeechDataset(dataconfig["devset"], reverse=True) 72 | if "vocab_path" in dataconfig: 73 | tokenizer = data.CharTokenizer(dataconfig["vocab_path"]) 74 | else: 75 | raise ValueError("Unknown tokenizer.") 76 | if modelconfig['signal']["feature_type"] == 'offline': 77 | collate = data.FeatureCollate(tokenizer, dataconfig["maxlen"]) 78 | else: 79 | collate = data.WaveCollate(tokenizer, dataconfig["maxlen"]) 80 | 81 | ngpu = 1 82 | if "multi_gpu" in trainingconfig and trainingconfig["multi_gpu"] == True: 83 | ngpu = torch.cuda.device_count() 84 | trainingsampler = data.TimeBasedSampler(training_set, trainingconfig["batch_time"]*ngpu, ngpu, shuffle=True) 85 | validsampler = data.TimeBasedSampler(valid_set, trainingconfig["batch_time"]*ngpu, ngpu, shuffle=False) # for plot longer utterance 86 | 87 | tr_loader = torch.utils.data.DataLoader(training_set, 88 | collate_fn=collate, 
batch_sampler=trainingsampler, shuffle=False, num_workers=dataconfig["fetchworker_num"]) 89 | cv_loader = torch.utils.data.DataLoader(valid_set, 90 | collate_fn=collate, batch_sampler=validsampler, shuffle=False, num_workers=dataconfig["fetchworker_num"]) 91 | 92 | splayer = sp_layers.SPLayer(modelconfig["signal"]) 93 | encoder = encoder_layers.Transformer(modelconfig["encoder"]) 94 | modelconfig["decoder"]["vocab_size"] = tokenizer.unit_num() 95 | decoder = decoder_layers.TransformerDecoder(modelconfig["decoder"]) 96 | 97 | lm = None 98 | if "lst" in trainingconfig: 99 | logging.info("Load language model package from {} for LST training.".format(trainingconfig["lst"]["lm_path"])) 100 | lmpkg = torch.load(trainingconfig["lst"]["lm_path"], map_location=lambda storage, loc: storage) 101 | lmconfig = lmpkg["model"]["lm_config"] 102 | if lmconfig["type"] == "lstm": 103 | lmlayer = lm_layers.LSTM(lmconfig) 104 | else: 105 | raise ValueError("Unknown model") 106 | 107 | lm = models.LM(lmlayer) 108 | logging.info("\nLM info:\n{}".format(lm)) 109 | lm.restore(lmpkg["model"]) 110 | 111 | model = models.Model(splayer, encoder, decoder, lm=lm) 112 | logging.info("\nModel info:\n{}".format(model)) 113 | 114 | if args.continue_training: 115 | logging.info("Load package from {}.".format(os.path.join(trainingconfig["exp_dir"], "last-ckpt.pt"))) 116 | pkg = torch.load(os.path.join(trainingconfig["exp_dir"], "last-ckpt.pt")) 117 | model.restore(pkg["model"]) 118 | 119 | if "multi_gpu" in trainingconfig and trainingconfig["multi_gpu"] == True: 120 | logging.info("Let's use {} GPUs!".format(torch.cuda.device_count())) 121 | model = torch.nn.DataParallel(model) 122 | 123 | model = model.cuda() 124 | 125 | trainer = Trainer(model, trainingconfig, tr_loader, cv_loader) 126 | 127 | if args.continue_training: 128 | logging.info("Restore trainer states...") 129 | trainer.restore(pkg) 130 | logging.info("Start training...") 131 | trainer.train() 132 | logging.info("Total time: {:.4f} secs".format(timer.toc())) 133 | 134 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import io 3 | import logging 4 | import subprocess 5 | import time 6 | import math 7 | import numpy as np 8 | import torch 9 | import torch.nn as nn 10 | from third_party import wavfile 11 | from third_party import kaldi_io as kio 12 | 13 | TENSORBOARD_LOGGING = 1 14 | 15 | def cleanup_ckpt(expdir, num_last_ckpt_keep): 16 | ckptlist = [t for t in os.listdir(expdir) if t.endswith('.pt') and t != 'last-ckpt.pt'] 17 | ckptlist = sorted(ckptlist) 18 | ckptlist_rm = ckptlist[:-num_last_ckpt_keep] 19 | logging.info("Clean up checkpoints. Remain the last {} checkpoints.".format(num_last_ckpt_keep)) 20 | for name in ckptlist_rm: 21 | os.remove(os.path.join(expdir, name)) 22 | 23 | 24 | def get_command_stdout(command, require_zero_status=True): 25 | """ Executes a command and returns its stdout output as a string. The 26 | command is executed with shell=True, so it may contain pipes and 27 | other shell constructs. 28 | 29 | If require_zero_stats is True, this function will raise an exception if 30 | the command has nonzero exit status. If False, it just prints a warning 31 | if the exit status is nonzero. 
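
    Example:
        >>> get_command_stdout("echo hello")
        b'hello\n'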
32 | 
33 |     See also: execute_command, background_command
34 |     """
35 |     p = subprocess.Popen(command, shell=True,
36 |                          stdout=subprocess.PIPE)
37 | 
38 |     stdout = p.communicate()[0]
39 |     if p.returncode != 0:
40 |         output = "Command exited with status {0}: {1}".format(
41 |             p.returncode, command)
42 |         if require_zero_status:
43 |             raise Exception(output)
44 |         else:
45 |             logging.warning(output)
46 |     return stdout
47 | 
48 | def load_wave(path):
49 |     """
50 |     path can be a wav filename or a pipeline
51 |     """
52 | 
53 |     # parse path
54 |     items = path.strip().split(":", 1)
55 |     if len(items) != 2:
56 |         raise ValueError("Unknown path format.")
57 |     tag = items[0]
58 |     path = items[1]
59 |     if tag == "file":
60 |         sample_rate, data = wavfile.read(path)
61 |     elif tag == "pipe":
62 |         path = path[:-1]
63 |         out = get_command_stdout(path, require_zero_status=True)
64 |         sample_rate, data = wavfile.read(io.BytesIO(out))
65 |     elif tag == "ark":
66 |         fn, offset = path.split(":", 1)
67 |         offset = int(offset)
68 |         with open(fn, 'rb') as f:
69 |             f.seek(offset)
70 |             sample_rate, data = wavfile.read(f, offset=offset)
71 |     else:
72 |         raise ValueError("Unknown file tag.")
73 |     data = data.astype(np.float32)
74 |     return sample_rate, data
75 | 
76 | 
77 | def load_feat(path):
78 |     items = path.strip().split(":", 1)
79 |     if len(items) != 2:
80 |         raise ValueError("Unknown path format.")
81 |     tag = items[0]
82 |     path = items[1]
83 |     if tag == "ark":
84 |         return kio.read_mat(path)
85 |     else:
86 |         raise ValueError("Unknown file tag.")
87 | 
88 | 
89 | def parse_scp(fn):
90 |     dic = {}
91 |     with open(fn, 'r') as f:
92 |         cnt = 0
93 |         for line in f:
94 |             cnt += 1
95 |             items = line.strip().split(' ', 1)
96 |             if len(items) != 2:
97 |                 logging.warning('Wrongly formatted line {} in scp {}, skip it.'.format(cnt, fn))
98 |                 continue
99 |             dic[items[0]] = items[1]
100 |     return dic
101 | 
102 | def str2bool(v):
103 |     if v.lower() in ('yes', 'true', 't', 'y', '1'):
104 |         return True
105 |     elif v.lower() in ('no', 'false', 'f', 'n', '0'):
106 |         return False
107 |     else:
108 |         raise ValueError('Unsupported value encountered.')  # argparse's type= machinery treats ValueError like ArgumentTypeError
109 | 
110 | class Timer(object):
111 |     def __init__(self):
112 |         self.start = 0.
113 | 
114 |     def tic(self):
115 |         self.start = time.time()
116 | 
117 |     def toc(self):
118 |         return time.time() - self.start
119 | 
120 | 
121 | 
122 | # ==========================================
123 | # auxiliary functions for sequence
124 | # ==========================================
125 | 
126 | def get_paddings(src, lengths):
127 |     paddings = torch.zeros_like(src).to(src.device)
128 |     for b in range(lengths.shape[0]):
129 |         paddings[b, lengths[b]:, :] = 1
130 |     return paddings
131 | 
132 | def get_paddings_by_shape(shape, lengths, device="cpu"):
133 |     paddings = torch.zeros(shape).to(device)
134 |     if shape[0] != lengths.shape[0]:
135 |         raise ValueError("shape[0] does not match lengths.shape[0]:"
136 |                          " {} vs. {}".format(shape[0], lengths.shape[0]))
137 |     T = shape[1]
138 |     for b in range(shape[0]):
139 |         if lengths[b] < T:
140 |             l = lengths[b]
141 |             paddings[b, l:] = 1
142 |     return paddings
143 | 
144 | def get_transformer_padding_byte_masks(B, T, lengths):
145 |     masks = get_paddings_by_shape([B, T], lengths).byte()
146 |     return masks
147 | 
148 | def get_transformer_casual_masks(T):
149 |     masks = -torch.triu(
150 |         torch.ones(T, T), diagonal=1)*9e20
151 |     return masks
152 | 
153 | 
154 | # ==========================================
155 | # visualization
156 | # ==========================================
157 | if TENSORBOARD_LOGGING == 1:
158 |     import logging
159 |     mpl_logger = logging.getLogger("matplotlib")
160 |     mpl_logger.setLevel(logging.WARNING)
161 | 
162 |     import matplotlib as mpl
163 |     mpl.use('Agg')
164 |     import matplotlib.pyplot as plt
165 |     from tensorboardX import SummaryWriter
166 | 
167 |     class Visualizer(object):
168 |         def __init__(self):
169 |             self.writer = None
170 |             self.fig_step = 0
171 | 
172 |         def set_writer(self, log_dir):
173 |             if self.writer is not None:
174 |                 raise ValueError("Don't set the writer twice.")
175 |             self.writer = SummaryWriter(log_dir)
176 | 
177 |         def add_scalar(self, tag, value, step):
178 |             self.writer.add_scalar(tag=tag,
179 |                 scalar_value=value, global_step=step)
180 | 
181 |         def add_graph(self, model):
182 |             self.writer.add_graph(model)
183 | 
184 |         def add_image(self, tag, img, data_formats):
185 |             self.writer.add_image(tag,
186 |                 img, 0, dataformats=data_formats)
187 | 
188 |         def add_img_figure(self, tag, img, step=None):
189 |             fig, axes = plt.subplots(1, 1)
190 |             axes.imshow(img)
191 |             self.writer.add_figure(tag, fig, global_step=step)
192 | 
193 |         def close(self):
194 |             self.writer.close()
195 | 
196 |     visualizer = Visualizer()
197 | 
198 | 
--------------------------------------------------------------------------------
/src/utils_test.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import subprocess
4 | import numpy as np
5 | import torch
6 | 
7 | import utils
8 | 
9 | os.chdir(os.path.abspath(os.path.dirname(__file__)))
10 | 
11 | def test_cleanup():
12 |     expdir = "testdata/cleanup"
13 |     os.makedirs(expdir)
14 |     for i in range(120):
15 |         with open(os.path.join(expdir, "ckpt-{:04d}.pt".format(i)), 'w') as f:
16 |             f.write("")
17 |     with open(os.path.join(expdir, "last-ckpt.pt"), 'w') as f:
18 |         f.write("")
19 |     utils.cleanup_ckpt(expdir, 3)
20 |     assert len(os.listdir(expdir)) == 4
21 | 
22 | 
23 | def test_read_wave_from_pipe():
24 |     command = "flac -c -d -s testdata/100-121669-0000.flac "
25 |     output = utils.get_command_stdout(command)
26 |     with open("testdata/100-121669-0000.wav", 'rb') as f:
27 |         wav_content = f.read()
28 |     assert output == wav_content
29 | 
30 | def test_load_wave():
31 |     pipe = "pipe:flac -c -d -s testdata/100-121669-0000.flac | "
32 |     fn = "file:testdata/100-121669-0000.wav"
33 |     ark = "ark:/data1/Corpora/LibriSpeech/ark/train_960.ark:16"
34 |     ark2 = "ark:/data1/Corpora/LibriSpeech/ark/train_960.ark:2591436"
35 |     timer = utils.Timer()
36 |     timer.tic()
37 |     s3, d3 = utils.load_wave(ark)
38 |     print("Load ark time: {}s".format(timer.toc()))
39 |     timer.tic()
40 |     s2, d2 = utils.load_wave(fn)
41 |     print("Load file time: {}s".format(timer.toc()))
42 |     timer.tic()
43 |     s1, d1 = utils.load_wave(pipe)
44 |     print("Load flac pipe time: {}s".format(timer.toc()))
45 |     print("Load ark2")
46 | 
47 |     s, d = utils.load_wave(ark2)
48 | 
49 |     assert s1 == s2
50 |     assert s3 == s2
51 |     assert np.sum(d1!=d2) == 
0 52 | assert np.sum(d3!=d2) == 0 53 | 54 | def test_get_transformer_casual_masks(): 55 | print('test_get_transformer_casual_masks') 56 | print(utils.get_transformer_casual_masks(5)) 57 | 58 | def test_get_transformer_padding_byte_masks(): 59 | B = 3 60 | T = 5 61 | lengths = torch.tensor([3, 4, 5]).long() 62 | masks = utils.get_transformer_padding_byte_masks(B, T, lengths) 63 | print('test_get_transformer_padding_byte_masks') 64 | print(masks) 65 | 66 | if __name__ == "__main__": 67 | test_cleanup() 68 | test_read_wave_from_pipe() 69 | test_load_wave() 70 | test_get_transformer_casual_masks() 71 | test_get_transformer_padding_byte_masks() 72 | 73 | 74 | 75 | -------------------------------------------------------------------------------- /tools/combine_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2012 Johns Hopkins University (Author: Daniel Povey). Apache 2.0. 3 | # 2014 David Snyder 4 | 5 | # This script combines the data from multiple source directories into 6 | # a single destination directory. 7 | 8 | # See http://kaldi-asr.org/doc/data_prep.html#data_prep_data for information 9 | # about what these directories contain. 10 | 11 | # Begin configuration section. 12 | extra_files= # specify additional files in 'src-data-dir' to merge, ex. "file1 file2 ..." 13 | skip_fix=false # skip the fix_data_dir.sh in the end 14 | # End configuration section. 15 | 16 | echo "$0 $@" # Print the command line for logging 17 | 18 | if [ -f path.sh ]; then . ./path.sh; fi 19 | . parse_options.sh || exit 1; 20 | 21 | if [ $# -lt 2 ]; then 22 | echo "Usage: combine_data.sh [--extra-files 'file1 file2'] ..." 23 | echo "Note, files that don't appear in all source dirs will not be combined," 24 | echo "with the exception of utt2uniq and segments, which are created where necessary." 25 | exit 1 26 | fi 27 | 28 | dest=$1; 29 | shift; 30 | 31 | first_src=$1; 32 | 33 | rm -r $dest 2>/dev/null 34 | mkdir -p $dest; 35 | 36 | export LC_ALL=C 37 | 38 | for dir in $*; do 39 | if [ ! -f $dir/utt2spk ]; then 40 | echo "$0: no such file $dir/utt2spk" 41 | exit 1; 42 | fi 43 | done 44 | 45 | # Check that frame_shift are compatible, where present together with features. 46 | dir_with_frame_shift= 47 | for dir in $*; do 48 | if [[ -f $dir/feats.scp && -f $dir/frame_shift ]]; then 49 | if [[ $dir_with_frame_shift ]] && 50 | ! cmp -s $dir_with_frame_shift/frame_shift $dir/frame_shift; then 51 | echo "$0:error: different frame_shift in directories $dir and " \ 52 | "$dir_with_frame_shift. Cannot combine features." 53 | exit 1; 54 | fi 55 | dir_with_frame_shift=$dir 56 | fi 57 | done 58 | 59 | # W.r.t. utt2uniq file the script has different behavior compared to other files 60 | # it is not compulsary for it to exist in src directories, but if it exists in 61 | # even one it should exist in all. We will create the files where necessary 62 | has_utt2uniq=false 63 | for in_dir in $*; do 64 | if [ -f $in_dir/utt2uniq ]; then 65 | has_utt2uniq=true 66 | break 67 | fi 68 | done 69 | 70 | if $has_utt2uniq; then 71 | # we are going to create an utt2uniq file in the destdir 72 | for in_dir in $*; do 73 | if [ ! 
-f $in_dir/utt2uniq ]; then 74 | # we assume that utt2uniq is a one-to-one mapping 75 | cat $in_dir/utt2spk | awk '{printf("%s %s\n", $1, $1);}' 76 | else 77 | cat $in_dir/utt2uniq 78 | fi 79 | done | sort -k1 > $dest/utt2uniq 80 | echo "$0: combined utt2uniq" 81 | else 82 | echo "$0 [info]: not combining utt2uniq as it does not exist" 83 | fi 84 | # some of the old scripts might provide utt2uniq as an extra file, so just remove it 85 | extra_files=$(echo "$extra_files"|sed -e "s/utt2uniq//g") 86 | 87 | # segments are treated similarly to utt2uniq. If it exists in some, but not all 88 | # src directories, then we generate segments where necessary. 89 | has_segments=false 90 | for in_dir in $*; do 91 | if [ -f $in_dir/segments ]; then 92 | has_segments=true 93 | break 94 | fi 95 | done 96 | 97 | if $has_segments; then 98 | for in_dir in $*; do 99 | if [ ! -f $in_dir/segments ]; then 100 | echo "$0 [info]: will generate missing segments for $in_dir" 1>&2 101 | utils/data/get_segments_for_data.sh $in_dir 102 | else 103 | cat $in_dir/segments 104 | fi 105 | done | sort -k1 > $dest/segments 106 | echo "$0: combined segments" 107 | else 108 | echo "$0 [info]: not combining segments as it does not exist" 109 | fi 110 | 111 | for file in utt2spk utt2lang utt2dur utt2num_frames reco2dur feats.scp text cmvn.scp vad.scp reco2file_and_channel wav.scp spk2gender $extra_files; do 112 | exists_somewhere=false 113 | absent_somewhere=false 114 | for d in $*; do 115 | if [ -f $d/$file ]; then 116 | exists_somewhere=true 117 | else 118 | absent_somewhere=true 119 | fi 120 | done 121 | 122 | if ! $absent_somewhere; then 123 | set -o pipefail 124 | ( for f in $*; do cat $f/$file; done ) | sort -k1 > $dest/$file || exit 1; 125 | set +o pipefail 126 | echo "$0: combined $file" 127 | else 128 | if ! $exists_somewhere; then 129 | echo "$0 [info]: not combining $file as it does not exist" 130 | else 131 | echo "$0 [info]: **not combining $file as it does not exist everywhere**" 132 | fi 133 | fi 134 | done 135 | 136 | $TOOLS_ROOT/utt2spk_to_spk2utt.pl <$dest/utt2spk >$dest/spk2utt 137 | if [[ $dir_with_frame_shift ]]; then 138 | cp $dir_with_frame_shift/frame_shift $dest 139 | fi 140 | 141 | 142 | exit 0 143 | -------------------------------------------------------------------------------- /tools/filter_scp.pl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env perl 2 | # Copyright 2010-2012 Microsoft Corporation 3 | # Johns Hopkins University (author: Daniel Povey) 4 | 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 12 | # KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED 13 | # WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, 14 | # MERCHANTABLITY OR NON-INFRINGEMENT. 15 | # See the Apache 2 License for the specific language governing permissions and 16 | # limitations under the License. 17 | 18 | 19 | # This script takes a list of utterance-ids or any file whose first field 20 | # of each line is an utterance-id, and filters an scp 21 | # file (or any file whose "n-th" field is an utterance id), printing 22 | # out only those lines whose "n-th" field is in id_list. 
The index of 22 | # the "n-th" field is 1, by default, but can be changed by using 23 | # the -f switch 24 | 25 | 26 | $exclude = 0; 27 | $field = 1; 28 | $shifted = 0; 29 | 30 | do { 31 | $shifted=0; 32 | if ($ARGV[0] eq "--exclude") { 33 | $exclude = 1; 34 | shift @ARGV; 35 | $shifted=1; 36 | } 37 | if ($ARGV[0] eq "-f") { 38 | $field = $ARGV[1]; 39 | shift @ARGV; shift @ARGV; 40 | $shifted=1 41 | } 42 | } while ($shifted); 43 | 44 | if(@ARGV < 1 || @ARGV > 2) { 45 | die "Usage: filter_scp.pl [--exclude] [-f <field-to-filter-on>] id_list [in.scp] > out.scp \n" . 46 | "Prints only the input lines whose f'th field (default: first) is in 'id_list'.\n" . 47 | "Note: only the first field of each line in id_list matters. With --exclude, prints\n" . 48 | "only the lines that were *not* in id_list.\n" . 49 | "Caution: previously, the -f option was interpreted as a zero-based field index.\n" . 50 | "If your older scripts (written before Oct 2014) stopped working and you used the\n" . 51 | "-f option, add 1 to the argument.\n" . 52 | "See also: utils/filter_scp.pl .\n"; 53 | } 54 | 55 | 56 | $idlist = shift @ARGV; 57 | open(F, "<$idlist") || die "Could not open id-list file $idlist"; 58 | while(<F>) { 59 | @A = split; 60 | @A>=1 || die "Invalid id-list file line $_"; 61 | $seen{$A[0]} = 1; 62 | } 63 | 64 | if ($field == 1) { # Treat this as special case, since it is common. 65 | while(<>) { 66 | $_ =~ m/\s*(\S+)\s*/ || die "Bad line $_, could not get first field."; 67 | # $1 is what we filter on. 68 | if ((!$exclude && $seen{$1}) || ($exclude && !defined $seen{$1})) { 69 | print $_; 70 | } 71 | } 72 | } else { 73 | while(<>) { 74 | @A = split; 75 | @A > 0 || die "Invalid scp file line $_"; 76 | @A >= $field || die "Invalid scp file line $_"; 77 | if ((!$exclude && $seen{$A[$field-1]}) || ($exclude && !defined $seen{$A[$field-1]})) { 78 | print $_; 79 | } 80 | } 81 | } 82 | 83 | # tests: 84 | # the following should print "foo 1" 85 | # ( echo foo 1; echo bar 2 ) | utils/filter_scp.pl <(echo foo) 86 | # the following should print "bar 2". 87 | # ( echo foo 1; echo bar 2 ) | utils/filter_scp.pl -f 2 <(echo 2) 88 | -------------------------------------------------------------------------------- /tools/int2sym.pl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env perl 2 | # Copyright 2010-2012 Microsoft Corporation Johns Hopkins University (Author: Daniel Povey) 3 | # Apache 2.0. 4 | 5 | undef $field_begin; 6 | undef $field_end; 7 | 8 | 9 | if ($ARGV[0] eq "-f") { 10 | shift @ARGV; 11 | $field_spec = shift @ARGV; 12 | if ($field_spec =~ m/^\d+$/) { 13 | $field_begin = $field_spec - 1; $field_end = $field_spec - 1; 14 | } 15 | if ($field_spec =~ m/^(\d*)[-:](\d*)/) { # accept e.g. 1:10 as a courtesy (properly, 1-10) 16 | if ($1 ne "") { 17 | $field_begin = $1 - 1; # Change to zero-based indexing. 18 | } 19 | if ($2 ne "") { 20 | $field_end = $2 - 1; # Change to zero-based indexing. 21 | } 22 | } 23 | if (!defined $field_begin && !defined $field_end) { 24 | die "Bad argument to -f option: $field_spec"; 25 | } 26 | } 27 | $symtab = shift @ARGV; 28 | if(!defined $symtab) { 29 | print STDERR "Usage: int2sym.pl [options] symtab [input] > output\n" . 30 | "options: [-f (<field>|<field_start>-<field_end>)]\n" . 
31 | "e.g.: -f 2, or -f 3-4\n"; 32 | exit(1); 33 | } 34 | 35 | open(F, "<$symtab") || die "Error opening symbol table file $symtab"; 36 | while() { 37 | @A = split(" ", $_); 38 | @A == 2 || die "bad line in symbol table file: $_"; 39 | $int2sym{$A[1]} = $A[0]; 40 | } 41 | 42 | sub int2sym { 43 | my $a = shift @_; 44 | my $pos = shift @_; 45 | if($a !~ m:^\d+$:) { # not all digits.. 46 | $pos1 = $pos+1; # make it one-based. 47 | die "int2sym.pl: found noninteger token $a [in position $pos1]\n"; 48 | } 49 | $s = $int2sym{$a}; 50 | if(!defined ($s)) { 51 | die "int2sym.pl: integer $a not in symbol table $symtab."; 52 | } 53 | return $s; 54 | } 55 | 56 | $error = 0; 57 | while (<>) { 58 | @A = split(" ", $_); 59 | for ($pos = 0; $pos <= $#A; $pos++) { 60 | $a = $A[$pos]; 61 | if ( (!defined $field_begin || $pos >= $field_begin) 62 | && (!defined $field_end || $pos <= $field_end)) { 63 | $a = int2sym($a, $pos); 64 | } 65 | print $a . " "; 66 | } 67 | print "\n"; 68 | } 69 | 70 | 71 | 72 | -------------------------------------------------------------------------------- /tools/parse_options.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Copyright 2012 Johns Hopkins University (Author: Daniel Povey); 4 | # Arnab Ghoshal, Karel Vesely 5 | 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 13 | # KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED 14 | # WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, 15 | # MERCHANTABLITY OR NON-INFRINGEMENT. 16 | # See the Apache 2 License for the specific language governing permissions and 17 | # limitations under the License. 18 | 19 | 20 | # Parse command-line options. 21 | # To be sourced by another script (as in ". parse_options.sh"). 22 | # Option format is: --option-name arg 23 | # and shell variable "option_name" gets set to value "arg." 24 | # The exception is --help, which takes no arguments, but prints the 25 | # $help_message variable (if defined). 26 | 27 | 28 | ### 29 | ### The --config file options have lower priority to command line 30 | ### options, so we need to import them first... 31 | ### 32 | 33 | # Now import all the configs specified by command-line, in left-to-right order 34 | for ((argpos=1; argpos<$#; argpos++)); do 35 | if [ "${!argpos}" == "--config" ]; then 36 | argpos_plus1=$((argpos+1)) 37 | config=${!argpos_plus1} 38 | [ ! -r $config ] && echo "$0: missing config '$config'" && exit 1 39 | . $config # source the config file. 40 | fi 41 | done 42 | 43 | 44 | ### 45 | ### Now we process the command line options 46 | ### 47 | while true; do 48 | [ -z "${1:-}" ] && break; # break if there are no arguments 49 | case "$1" in 50 | # If the enclosing script is called with --help option, print the help 51 | # message and exit. Scripts should put help messages in $help_message 52 | --help|-h) if [ -z "$help_message" ]; then echo "No help found." 1>&2; 53 | else printf "$help_message\n" 1>&2 ; fi; 54 | exit 0 ;; 55 | --*=*) echo "$0: options to scripts must be of the form --name value, got '$1'" 56 | exit 1 ;; 57 | # If the first command-line argument begins with "--" (e.g. --foo-bar), 58 | # then work out the variable name as $name, which will equal "foo_bar". 
59 | --*) name=`echo "$1" | sed s/^--// | sed s/-/_/g`; 60 | # Next we test whether the variable in question is undefined-- if so it's 61 | # an invalid option and we die. Note: $0 evaluates to the name of the 62 | # enclosing script. 63 | # The test [ -z ${foo_bar+xxx} ] will return true if the variable foo_bar 64 | # is undefined. We then have to wrap this test inside "eval" because 65 | # foo_bar is itself inside a variable ($name). 66 | eval '[ -z "${'$name'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1; 67 | 68 | oldval="`eval echo \\$$name`"; 69 | # Work out whether we seem to be expecting a Boolean argument. 70 | if [ "$oldval" == "true" ] || [ "$oldval" == "false" ]; then 71 | was_bool=true; 72 | else 73 | was_bool=false; 74 | fi 75 | 76 | # Set the variable to the right value-- the escaped quotes make it work if 77 | # the option had spaces, like --cmd "queue.pl -sync y" 78 | eval $name=\"$2\"; 79 | 80 | # Check that Boolean-valued arguments are really Boolean. 81 | if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then 82 | echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2 83 | exit 1; 84 | fi 85 | shift 2; 86 | ;; 87 | *) break; 88 | esac 89 | done 90 | 91 | 92 | # Check for an empty argument to the --cmd option, which can easily occur as a 93 | # result of scripting errors. 94 | [ ! -z "${cmd+xxx}" ] && [ -z "$cmd" ] && echo "$0: empty argument to --cmd option" 1>&2 && exit 1; 95 | 96 | 97 | true; # so this script returns exit code 0. 98 | -------------------------------------------------------------------------------- /tools/run.pl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env perl 2 | use warnings; #sed replacement for -w perl parameter 3 | 4 | # In general, doing 5 | # run.pl some.log a b c is like running the command a b c in 6 | # the bash shell, and putting the standard error and output into some.log. 7 | # To run parallel jobs (backgrounded on the host machine), you can do (e.g.) 8 | # run.pl JOB=1:4 some.JOB.log a b c JOB is like running the command a b c JOB 9 | # and putting it in some.JOB.log, for each one. [Note: JOB can be any identifier]. 10 | # If any of the jobs fails, this script will fail. 11 | 12 | # A typical example is: 13 | # run.pl some.log my-prog "--opt=foo bar" foo \| other-prog baz 14 | # and run.pl will run something like: 15 | # ( my-prog '--opt=foo bar' foo | other-prog baz ) >& some.log 16 | # 17 | # Basically it takes the command-line arguments, quotes them 18 | # as necessary to preserve spaces, and evaluates them with bash. 19 | # In addition it puts the command line at the top of the log, and 20 | # the start and end times of the command at the beginning and end. 21 | # The reason why this is useful is so that we can create a different 22 | # version of this program that uses a queueing system instead. 23 | 24 | # use Data::Dumper; 25 | 26 | @ARGV < 2 && die "usage: run.pl log-file command-line arguments..."; 27 | 28 | 29 | $max_jobs_run = -1; 30 | $jobstart = 1; 31 | $jobend = 1; 32 | $ignored_opts = ""; # These will be ignored. 33 | 34 | # First parse an option like JOB=1:4, and any 35 | # options that would normally be given to 36 | # queue.pl, which we will just discard. 37 | 38 | for (my $x = 1; $x <= 2; $x++) { # This for-loop is to 39 | # allow the JOB=1:n option to be interleaved with the 40 | # options to qsub. 41 | while (@ARGV >= 2 && $ARGV[0] =~ m:^-:) { 42 | # parse any options that would normally go to qsub, but which will be ignored here. 
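# (Illustrative example, not part of the original script:) for an invocation
# like "run.pl -q all.q --gpu 1 JOB=1:4 exp/log.JOB.log cmd", this loop folds
# "-q all.q" into $ignored_opts and "--gpu 1" sets $using_gpu below, leaving
# "JOB=1:4" to be parsed by the code that follows the loop.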
43 | my $switch = shift @ARGV; 44 | if ($switch eq "-V") { 45 | $ignored_opts .= "-V "; 46 | } elsif ($switch eq "--max-jobs-run" || $switch eq "-tc") { 47 | # we do support the option --max-jobs-run n, and its GridEngine form -tc n. 48 | $max_jobs_run = shift @ARGV; 49 | if (! ($max_jobs_run > 0)) { 50 | die "run.pl: invalid option --max-jobs-run $max_jobs_run"; 51 | } 52 | } else { 53 | my $argument = shift @ARGV; 54 | if ($argument =~ m/^--/) { 55 | print STDERR "run.pl: WARNING: suspicious argument '$argument' to $switch; starts with '-'\n"; 56 | } 57 | if ($switch eq "-sync" && $argument =~ m/^[yY]/) { 58 | $ignored_opts .= "-sync "; # Note: in the 59 | # corresponding code in queue.pl it says instead, just "$sync = 1;". 60 | } elsif ($switch eq "-pe") { # e.g. -pe smp 5 61 | my $argument2 = shift @ARGV; 62 | $ignored_opts .= "$switch $argument $argument2 "; 63 | } elsif ($switch eq "--gpu") { 64 | $using_gpu = $argument; 65 | } else { 66 | # Ignore option. 67 | $ignored_opts .= "$switch $argument "; 68 | } 69 | } 70 | } 71 | if ($ARGV[0] =~ m/^([\w_][\w\d_]*)+=(\d+):(\d+)$/) { # e.g. JOB=1:20 72 | $jobname = $1; 73 | $jobstart = $2; 74 | $jobend = $3; 75 | if ($jobstart > $jobend) { 76 | die "run.pl: invalid job range $ARGV[0]"; 77 | } 78 | if ($jobstart <= 0) { 79 | die "run.pl: invalid job range $ARGV[0], start must be strictly positive (this is required for GridEngine compatibility)."; 80 | } 81 | shift; 82 | } elsif ($ARGV[0] =~ m/^([\w_][\w\d_]*)+=(\d+)$/) { # e.g. JOB=1. 83 | $jobname = $1; 84 | $jobstart = $2; 85 | $jobend = $2; 86 | shift; 87 | } elsif ($ARGV[0] =~ m/.+\=.*\:.*$/) { 88 | print STDERR "run.pl: Warning: suspicious first argument to run.pl: $ARGV[0]\n"; 89 | } 90 | } 91 | 92 | # Users found this message confusing so we are removing it. 93 | # if ($ignored_opts ne "") { 94 | # print STDERR "run.pl: Warning: ignoring options \"$ignored_opts\"\n"; 95 | # } 96 | 97 | if ($max_jobs_run == -1) { # If --max-jobs-run option not set, 98 | # then work out the number of processors if possible, 99 | # and set it based on that. 100 | $max_jobs_run = 0; 101 | if ($using_gpu) { 102 | if (open(P, "nvidia-smi -L |")) { 103 | $max_jobs_run++ while (
<P>); 104 | close(P); 105 | } 106 | if ($max_jobs_run == 0) { 107 | $max_jobs_run = 1; 108 | print STDERR "run.pl: Warning: failed to detect number of GPUs from nvidia-smi, using ${max_jobs_run}\n"; 109 | } 110 | } elsif (open(P, "</proc/cpuinfo")) { # Linux 111 | while (<P>) { if (m/^processor/) { $max_jobs_run++; } } 112 | if ($max_jobs_run == 0) { 113 | print STDERR "run.pl: Warning: failed to detect any processors from /proc/cpuinfo\n"; 114 | $max_jobs_run = 10; # reasonable default. 115 | } 116 | close(P); 117 | } elsif (open(P, "sysctl -a |")) { # BSD/Darwin 118 | while (<P>
) { 119 | if (m/hw\.ncpu\s*[:=]\s*(\d+)/) { # hw.ncpu = 4, or hw.ncpu: 4 120 | $max_jobs_run = $1; 121 | last; 122 | } 123 | } 124 | close(P); 125 | if ($max_jobs_run == 0) { 126 | print STDERR "run.pl: Warning: failed to detect any processors from sysctl -a\n"; 127 | $max_jobs_run = 10; # reasonable default. 128 | } 129 | } else { 130 | # allow at most 32 jobs at once, on non-UNIX systems; change this code 131 | # if you need to change this default. 132 | $max_jobs_run = 32; 133 | } 134 | # The just-computed value of $max_jobs_run is just the number of processors 135 | # (or our best guess); and if it happens that the number of jobs we need to 136 | # run is just slightly above $max_jobs_run, it will make sense to increase 137 | # $max_jobs_run to equal the number of jobs, so we don't have a small number 138 | # of leftover jobs. 139 | $num_jobs = $jobend - $jobstart + 1; 140 | if (!$using_gpu && 141 | $num_jobs > $max_jobs_run && $num_jobs < 1.4 * $max_jobs_run) { 142 | $max_jobs_run = $num_jobs; 143 | } 144 | } 145 | 146 | $logfile = shift @ARGV; 147 | 148 | if (defined $jobname && $logfile !~ m/$jobname/ && 149 | $jobend > $jobstart) { 150 | print STDERR "run.pl: you are trying to run a parallel job but " 151 | . "you are putting the output into just one log file ($logfile)\n"; 152 | exit(1); 153 | } 154 | 155 | $cmd = ""; 156 | 157 | foreach $x (@ARGV) { 158 | if ($x =~ m/^\S+$/) { $cmd .= $x . " "; } 159 | elsif ($x =~ m:\":) { $cmd .= "'$x' "; } 160 | else { $cmd .= "\"$x\" "; } 161 | } 162 | 163 | #$Data::Dumper::Indent=0; 164 | $ret = 0; 165 | $numfail = 0; 166 | %active_pids=(); 167 | 168 | use POSIX ":sys_wait_h"; 169 | for ($jobid = $jobstart; $jobid <= $jobend; $jobid++) { 170 | if (scalar(keys %active_pids) >= $max_jobs_run) { 171 | 172 | # Lets wait for a change in any child's status 173 | # Then we have to work out which child finished 174 | $r = waitpid(-1, 0); 175 | $code = $?; 176 | if ($r < 0 ) { die "run.pl: Error waiting for child process"; } # should never happen. 177 | if ( defined $active_pids{$r} ) { 178 | $jid=$active_pids{$r}; 179 | $fail[$jid]=$code; 180 | if ($code !=0) { $numfail++;} 181 | delete $active_pids{$r}; 182 | # print STDERR "Finished: $r/$jid " . Dumper(\%active_pids) . "\n"; 183 | } else { 184 | die "run.pl: Cannot find the PID of the child process that just finished."; 185 | } 186 | 187 | # In theory we could do a non-blocking waitpid over all jobs running just 188 | # to find out if only one or more jobs finished during the previous waitpid() 189 | # However, we just omit this and will reap the next one in the next pass 190 | # through the for(;;) cycle 191 | } 192 | $childpid = fork(); 193 | if (!defined $childpid) { die "run.pl: Error forking in run.pl (writing to $logfile)"; } 194 | if ($childpid == 0) { # We're in the child... this branch 195 | # executes the job and returns (possibly with an error status). 196 | if (defined $jobname) { 197 | $cmd =~ s/$jobname/$jobid/g; 198 | $logfile =~ s/$jobname/$jobid/g; 199 | } 200 | system("mkdir -p `dirname $logfile` 2>/dev/null"); 201 | open(F, ">$logfile") || die "run.pl: Error opening log file $logfile"; 202 | print F "# " . $cmd . "\n"; 203 | print F "# Started at " . `date`; 204 | $starttime = `date +'%s'`; 205 | print F "#\n"; 206 | close(F); 207 | 208 | # Pipe into bash.. make sure we're not using any other shell. 209 | open(B, "|bash") || die "run.pl: Error opening shell command"; 210 | print B "( " . $cmd . 
") 2>>$logfile >> $logfile"; 211 | close(B); # If there was an error, exit status is in $? 212 | $ret = $?; 213 | 214 | $lowbits = $ret & 127; 215 | $highbits = $ret >> 8; 216 | if ($lowbits != 0) { $return_str = "code $highbits; signal $lowbits" } 217 | else { $return_str = "code $highbits"; } 218 | 219 | $endtime = `date +'%s'`; 220 | open(F, ">>$logfile") || die "run.pl: Error opening log file $logfile (again)"; 221 | $enddate = `date`; 222 | chop $enddate; 223 | print F "# Accounting: time=" . ($endtime - $starttime) . " threads=1\n"; 224 | print F "# Ended ($return_str) at " . $enddate . ", elapsed time " . ($endtime-$starttime) . " seconds\n"; 225 | close(F); 226 | exit($ret == 0 ? 0 : 1); 227 | } else { 228 | $pid[$jobid] = $childpid; 229 | $active_pids{$childpid} = $jobid; 230 | # print STDERR "Queued: " . Dumper(\%active_pids) . "\n"; 231 | } 232 | } 233 | 234 | # Now we have submitted all the jobs, lets wait until all the jobs finish 235 | foreach $child (keys %active_pids) { 236 | $jobid=$active_pids{$child}; 237 | $r = waitpid($pid[$jobid], 0); 238 | $code = $?; 239 | if ($r == -1) { die "run.pl: Error waiting for child process"; } # should never happen. 240 | if ($r != 0) { $fail[$jobid]=$code; $numfail++ if $code!=0; } # Completed successfully 241 | } 242 | 243 | # Some sanity checks: 244 | # The $fail array should not contain undefined codes 245 | # The number of non-zeros in that array should be equal to $numfail 246 | # We cannot do foreach() here, as the JOB ids do not start at zero 247 | $failed_jids=0; 248 | for ($jobid = $jobstart; $jobid <= $jobend; $jobid++) { 249 | $job_return = $fail[$jobid]; 250 | if (not defined $job_return ) { 251 | # print Dumper(\@fail); 252 | 253 | die "run.pl: Sanity check failed: we have indication that some jobs are running " . 254 | "even after we waited for all jobs to finish" ; 255 | } 256 | if ($job_return != 0 ){ $failed_jids++;} 257 | } 258 | if ($failed_jids != $numfail) { 259 | die "run.pl: Sanity check failed: cannot find out how many jobs failed ($failed_jids x $numfail)." 260 | } 261 | if ($numfail > 0) { $ret = 1; } 262 | 263 | if ($ret != 0) { 264 | $njobs = $jobend - $jobstart + 1; 265 | if ($njobs == 1) { 266 | if (defined $jobname) { 267 | $logfile =~ s/$jobname/$jobstart/; # only one numbered job, so replace name with 268 | # that job. 269 | } 270 | print STDERR "run.pl: job failed, log is in $logfile\n"; 271 | if ($logfile =~ m/JOB/) { 272 | print STDERR "run.pl: probably you forgot to put JOB=1:\$nj in your script."; 273 | } 274 | } 275 | else { 276 | $logfile =~ s/$jobname/*/g; 277 | print STDERR "run.pl: $numfail / $njobs failed, log is in $logfile\n"; 278 | } 279 | } 280 | 281 | 282 | exit ($ret); 283 | -------------------------------------------------------------------------------- /tools/sclite: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/by2101/OpenASR/c5213d68304a270a0448b2d53adc72b57f4efdb3/tools/sclite -------------------------------------------------------------------------------- /tools/spk2utt_to_utt2spk.pl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env perl 2 | # Copyright 2010-2011 Microsoft Corporation 3 | 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 11 | # KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED 12 | # WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, 13 | # MERCHANTABLITY OR NON-INFRINGEMENT. 14 | # See the Apache 2 License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | 18 | while(<>){ 19 | @A = split(" ", $_); 20 | @A > 1 || die "Invalid line in spk2utt file: $_"; 21 | $s = shift @A; 22 | foreach $u ( @A ) { 23 | print "$u $s\n"; 24 | } 25 | } 26 | 27 | 28 | -------------------------------------------------------------------------------- /tools/utt2spk_to_spk2utt.pl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env perl 2 | # Copyright 2010-2011 Microsoft Corporation 3 | 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 11 | # KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED 12 | # WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, 13 | # MERCHANTABLITY OR NON-INFRINGEMENT. 14 | # See the Apache 2 License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | # converts an utt2spk file to a spk2utt file. 18 | # Takes input from the stdin or from a file argument; 19 | # output goes to the standard out. 20 | 21 | if ( @ARGV > 1 ) { 22 | die "Usage: utt2spk_to_spk2utt.pl [ utt2spk ] > spk2utt"; 23 | } 24 | 25 | while(<>){ 26 | @A = split(" ", $_); 27 | @A == 2 || die "Invalid line in utt2spk file: $_"; 28 | ($u,$s) = @A; 29 | if(!$seen_spk{$s}) { 30 | $seen_spk{$s} = 1; 31 | push @spklist, $s; 32 | } 33 | push (@{$spk_hash{$s}}, "$u"); 34 | } 35 | foreach $s (@spklist) { 36 | $l = join(' ',@{$spk_hash{$s}}); 37 | print "$s $l\n"; 38 | } 39 | --------------------------------------------------------------------------------
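# (Illustrative example of the two mapping scripts above, with made-up IDs:)
# an utt2spk file containing the lines "BAC009S0764W0121 S0764" and
# "BAC009S0764W0122 S0764" is collapsed by utt2spk_to_spk2utt.pl into the single
# line "S0764 BAC009S0764W0121 BAC009S0764W0122"; spk2utt_to_utt2spk.pl performs
# the inverse mapping, expanding that line back into the original two.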