├── .gitignore ├── LICENSE ├── MANIFEST.in ├── Makefile ├── README.md ├── data ├── dw961.wav ├── fail1.wav ├── gsp1.wav ├── kaldi-asr.pc ├── lsen1.wav └── single.wav ├── examples ├── asr_client.py ├── asr_server.py ├── chain_incremental.py ├── chain_live.py ├── chain_online.py ├── chain_wavfile.py ├── chain_wavfile3.py └── gmm_incremental.py ├── kaldiasr ├── __init__.py ├── gmm.pyx ├── gmm_wrappers.cpp ├── gmm_wrappers.h ├── nnet3.pyx ├── nnet3_wrappers.cpp └── nnet3_wrappers.h └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | # use glob syntax. 2 | syntax: glob 3 | *.swp 4 | *.swo 5 | *.pyc 6 | tmp 7 | old 8 | build 9 | data/models 10 | *.log 11 | TODO 12 | *.so 13 | kaldiasr/nnet3.cpp 14 | kaldiasr/gmm.cpp 15 | data/atlas.pc 16 | dist 17 | py_kaldi_asr.egg-info 18 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 
22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. 
You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "{}" 182 | replaced with your own identifying information. 
(Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright {yyyy} {name of copyright owner} 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include kaldiasr/nnet3_wrappers.h 2 | include kaldiasr/gmm_wrappers.h 3 | include data/dw961.wav 4 | include data/gsp1.wav 5 | include data/lsen1.wav 6 | include README.md 7 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | CFLAGS = -Wall -pthread -std=c++11 -DKALDI_DOUBLEPRECISION=0 -Wno-sign-compare \ 2 | -Wno-unused-local-typedefs -Winit-self -DHAVE_EXECINFO_H=1 -DHAVE_CXXABI_H -DHAVE_ATLAS \ 3 | `pkg-config --cflags kaldi-asr` -g 4 | 5 | LDFLAGS = -rdynamic -lm -lpthread -ldl `pkg-config --libs kaldi-asr` 6 | 7 | .PHONY: all clean dist upload 8 | 9 | all: kaldiasr/nnet3.so 10 | 11 | kaldiasr/nnet3.so: kaldiasr/nnet3.pyx kaldiasr/nnet3_wrappers.cpp kaldiasr/nnet3_wrappers.h 12 | python setup.py
build_ext --inplace 13 | 14 | dist: 15 | python setup.py sdist 16 | # python setup.py bdist_wheel 17 | 18 | upload: 19 | twine upload dist/* 20 | 21 | clean: 22 | rm -f kaldiasr/nnet3.cpp kaldiasr/gmm.cpp kaldiasr/*.so kaldiasr/*.pyc MANIFEST 23 | rm -rf build dist kaldiasr.egg-info py_kaldi_asr.egg-info kaldiasr/__pycache__ 24 | 25 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # py-kaldi-asr 2 | 3 | Some simple wrappers around kaldi-asr intended to make using kaldi's online nnet3-chain 4 | decoders as convenient as possible. Kaldi's online GMM decoders are also supported. 5 | 6 | The target audience is developers who would like to use kaldi-asr as-is for speech 7 | recognition in their applications on GNU/Linux operating systems. 8 | 9 | Constructive comments, patches and pull requests are very welcome. 10 | 11 | Getting Started 12 | =============== 13 | 14 | We recommend using pre-trained models from the [zamia-speech](http://zamia-speech.org/) project 15 | to get started.
There you will also find a tutorial complete with links to pre-built binary packages 16 | to get you up and running with free and open source speech recognition in a matter of minutes: 17 | 18 | [Zamia Speech Tutorial](https://github.com/gooofy/zamia-speech#get-started-with-our-pre-trained-models) 19 | 20 | Example Code 21 | ------------ 22 | 23 | Simple wav file decoding: 24 | 25 | ```python 26 | from kaldiasr.nnet3 import KaldiNNet3OnlineModel, KaldiNNet3OnlineDecoder 27 | 28 | MODELDIR = 'data/models/kaldi-generic-en-tdnn_sp-latest' 29 | WAVFILE = 'data/dw961.wav' 30 | 31 | kaldi_model = KaldiNNet3OnlineModel (MODELDIR) 32 | decoder = KaldiNNet3OnlineDecoder (kaldi_model) 33 | 34 | if decoder.decode_wav_file(WAVFILE): 35 | 36 | s, l = decoder.get_decoded_string() 37 | 38 | print("") 39 | print(u"*****************************************************************") 40 | print(u"** %s" % WAVFILE) 41 | print(u"** %s" % s) 42 | print(u"** %s likelihood: %s" % (MODELDIR, l)) 43 | print(u"*****************************************************************") 44 | print("") 45 | 46 | else: 47 | 48 | print("***ERROR: decoding of %s failed." % WAVFILE) 49 | ``` 50 | 51 | Please check the examples directory for more example code. 52 | 53 | Requirements 54 | ============ 55 | 56 | * Python 2.7 or 3.5+ 57 | * NumPy 58 | * Cython 59 | * [kaldi-asr](http://kaldi-asr.org/ "kaldi-asr.org") 60 | 61 | Setup Notes 62 | =========== 63 | 64 | Source 65 | ------ 66 | 67 | At the time of this writing kaldi-asr does not seem to have an official way to 68 | install it on a system.
69 | 70 | So, for now we will rely on pkg-config to provide LIBS and CFLAGS for compilation: 71 | Create a file called `kaldi-asr.pc` somewhere in your `PKG_CONFIG_PATH` that provides 72 | this information - here is what such a file could look like (details depend on your OS environment): 73 | 74 | ```bash 75 | kaldi_root=/opt/kaldi 76 | 77 | Name: kaldi-asr 78 | Description: kaldi-asr speech recognition toolkit 79 | Version: 5.2 80 | Requires: atlas 81 | Libs: -L${kaldi_root}/tools/openfst/lib -L${kaldi_root}/src/lib -lkaldi-decoder -lkaldi-lat -lkaldi-fstext -lkaldi-hmm -lkaldi-feat -lkaldi-transform -lkaldi-gmm -lkaldi-tree -lkaldi-util -lkaldi-matrix -lkaldi-base -lkaldi-nnet3 -lkaldi-online2 -lkaldi-cudamatrix -lkaldi-ivector -lfst 82 | Cflags: -I${kaldi_root}/src -I${kaldi_root}/tools/openfst/include 83 | ``` 84 | 85 | make sure `kaldi_root` points to wherever your kaldi checkout lives in your filesystem. 86 | 87 | ATLAS 88 | ----- 89 | 90 | You may need to install ATLAS headers even if you didn't need them to compile Kaldi. 91 | 92 | ``` 93 | $ sudo apt install libatlas-dev 94 | ``` 95 | 96 | License 97 | ======= 98 | 99 | My own code is Apache licensed unless otherwise noted in the script's copyright 100 | headers. 101 | 102 | Some scripts and files are based on works of others, in those cases it is my 103 | intention to keep the original license intact. Please make sure to check the 104 | copyright headers inside for more information. 105 | 106 | Author 107 | ====== 108 | 109 | Guenter Bartsch
110 | Kaldi 5.1 adaptation contributed by mariasmo https://github.com/mariasmo 111 | Kaldi GMM model support contributed by David Zurow https://github.com/daanzu 112 | Python > 3.5 support contributed by Jakob Kruse https://github.com/jakob1111996 113 | -------------------------------------------------------------------------------- /data/dw961.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gooofy/py-kaldi-asr/331c725d45beff96a789b88a2d690f883379639d/data/dw961.wav -------------------------------------------------------------------------------- /data/fail1.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gooofy/py-kaldi-asr/331c725d45beff96a789b88a2d690f883379639d/data/fail1.wav -------------------------------------------------------------------------------- /data/gsp1.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gooofy/py-kaldi-asr/331c725d45beff96a789b88a2d690f883379639d/data/gsp1.wav -------------------------------------------------------------------------------- /data/kaldi-asr.pc: -------------------------------------------------------------------------------- 1 | kaldi_root=/apps/kaldi 2 | 3 | Name: kaldi-asr 4 | Description: kaldi-asr speech recognition toolkit 5 | Version: 5.1 6 | Requires: atlas 7 | Libs: -L${kaldi_root}/tools/openfst/lib -L${kaldi_root}/src/lib -lkaldi-decoder -lkaldi-lat -lkaldi-fstext -lkaldi-hmm -lkaldi-feat -lkaldi-transform -lkaldi-gmm -lkaldi-tree -lkaldi-util -lkaldi-matrix -lkaldi-base -lkaldi-nnet3 -lkaldi-online2 8 | Cflags: -I${kaldi_root}/src -I${kaldi_root}/tools/openfst/include 9 | -------------------------------------------------------------------------------- /data/lsen1.wav: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/gooofy/py-kaldi-asr/331c725d45beff96a789b88a2d690f883379639d/data/lsen1.wav -------------------------------------------------------------------------------- /data/single.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gooofy/py-kaldi-asr/331c725d45beff96a789b88a2d690f883379639d/data/single.wav -------------------------------------------------------------------------------- /examples/asr_client.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | # 5 | # Copyright 2017, 2018 Guenter Bartsch 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | # KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED 15 | # WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, 16 | # MERCHANTABLITY OR NON-INFRINGEMENT. 17 | # See the Apache 2 License for the specific language governing permissions and 18 | # limitations under the License. 
19 | # 20 | # 21 | # very basic example client for our example speech asr server 22 | # 23 | 24 | 25 | import os 26 | import sys 27 | import logging 28 | import traceback 29 | import json 30 | import wave 31 | import struct 32 | import requests 33 | 34 | from time import time 35 | from optparse import OptionParser 36 | 37 | DEFAULT_HOST = 'localhost' 38 | DEFAULT_PORT = 8301 39 | 40 | # 41 | # commandline 42 | # 43 | 44 | parser = OptionParser("usage: %prog [options] foo.wav") 45 | 46 | parser.add_option ("-v", "--verbose", action="store_true", dest="verbose", 47 | help="verbose output") 48 | 49 | parser.add_option ("-H", "--host", dest="host", type = "string", default=DEFAULT_HOST, 50 | help="host, default: %s" % DEFAULT_HOST) 51 | 52 | parser.add_option ("-p", "--port", dest="port", type = "int", default=DEFAULT_PORT, 53 | help="port, default: %d" % DEFAULT_PORT) 54 | 55 | 56 | (options, args) = parser.parse_args() 57 | 58 | if options.verbose: 59 | logging.basicConfig(level=logging.DEBUG) 60 | else: 61 | logging.basicConfig(level=logging.INFO) 62 | logging.getLogger("requests").setLevel(logging.WARNING) 63 | 64 | if len(args) != 1: 65 | parser.print_help() 66 | sys.exit(1) 67 | 68 | wavfn = args[0] 69 | 70 | url = 'http://%s:%d/decode' % (options.host, options.port) 71 | 72 | # 73 | # read samples from wave file, hand them over to asr server incrementally to simulate online decoding 74 | # 75 | 76 | time_start = time() 77 | 78 | wavf = wave.open(wavfn, 'rb') 79 | 80 | # check format 81 | assert wavf.getnchannels()==1 82 | assert wavf.getsampwidth()==2 83 | 84 | # process file in 250ms chunks 85 | 86 | chunk_frames = 250 * wavf.getframerate() / 1000 87 | tot_frames = wavf.getnframes() 88 | 89 | num_frames = 0 90 | while num_frames < tot_frames: 91 | 92 | finalize = False 93 | if (num_frames + chunk_frames) < tot_frames: 94 | nframes = chunk_frames 95 | else: 96 | nframes = tot_frames - num_frames 97 | finalize = True 98 | 99 | frames = wavf.readframes(nframes) 
100 | num_frames += nframes 101 | samples = struct.unpack_from('<%dh' % nframes, frames) 102 | 103 | data = {'audio' : samples, 104 | 'do_record' : False, 105 | 'do_asr' : True, 106 | 'do_finalize': finalize} 107 | 108 | response = requests.post(url, data=json.dumps(data)) 109 | 110 | logging.info("%6.3fs: %5d frames (%6.3fs) decoded, status=%d." % (time()-time_start, 111 | num_frames, 112 | float(num_frames) / float(wavf.getframerate()), 113 | response.status_code)) 114 | assert response.status_code == 200 115 | 116 | 117 | wavf.close() 118 | 119 | data = response.json() 120 | 121 | logging.debug("raw response data: %s" % repr(data)) 122 | 123 | logging.info ( "*****************************************************************") 124 | logging.info ( "** wavfn : %s" % wavfn) 125 | logging.info ( "** hstr : %s" % data['hstr']) 126 | logging.info ( "** confidence : %f" % data['confidence']) 127 | logging.info ( "** decoding time : %8.2fs" % ( time() - time_start )) 128 | logging.info ( "*****************************************************************") 129 | 130 | -------------------------------------------------------------------------------- /examples/asr_server.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | # 5 | # Copyright 2017 Guenter Bartsch 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | # KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED 15 | # WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, 16 | # MERCHANTABLITY OR NON-INFRINGEMENT. 
17 | # See the Apache 2 License for the specific language governing permissions and 18 | # limitations under the License. 19 | # 20 | # 21 | # simple speech recognition http api server 22 | # 23 | # WARNING: 24 | # right now, this supports a single client only - needs a lot more work 25 | # to become (at least somewhat) scalable 26 | # 27 | # Decode WAV Data 28 | # --------------- 29 | # 30 | # * POST `/decode` 31 | # * args (JSON encoded dict): 32 | # * "audio" : array of signed int16 samples 33 | # * "do_record" : boolean, if true record to wav file on disk 34 | # * "do_asr" : boolean, if true start/continue kaldi ASR 35 | # * "do_finalize" : boolean, if true finish kaldi ASR, return decoded string 36 | # 37 | # Returns: 38 | # 39 | # * 400 if request is invalid 40 | # * 200 OK with a JSON body; hstr/confidence are only filled in when do_asr and do_finalize are true, audiofn only when do_record is true: 41 | # {"hstr": "hello world", "confidence": 0.02, "audiofn": "data/recordings/anonymous-20170105-rec/wav/de5-005.wav"} 42 | # 43 | # Example: 44 | # 45 | # curl -i -H "Content-Type: application/json" -X POST \ 46 | # -d '{"audio": [1,2,3,4], "do_record": true, "do_asr": true, "do_finalize": true}' \ 47 | # http://localhost:8301/decode 48 | 49 | 50 | import os 51 | import sys 52 | import logging 53 | import traceback 54 | import json 55 | import datetime 56 | import wave 57 | import errno 58 | import struct 59 | 60 | from time import time 61 | from optparse import OptionParser 62 | from setproctitle import setproctitle 63 | from BaseHTTPServer import BaseHTTPRequestHandler,HTTPServer # Python 2 module (renamed http.server in Python 3) 64 | 65 | from kaldiasr.nnet3 import KaldiNNet3OnlineModel, KaldiNNet3OnlineDecoder 66 | import numpy as np 67 | 68 | DEFAULT_HOST = 'localhost' 69 | DEFAULT_PORT = 8301 70 | 71 | DEFAULT_MODEL_DIR = 'data/models/kaldi-nnet3-voxforge-de-latest' 72 | DEFAULT_MODEL = 'nnet_tdnn_a' 73 | 74 | DEFAULT_VF_LOGIN = 'anonymous' 75 | DEFAULT_REC_DIR = 'data/recordings' 76 | SAMPLE_RATE = 16000 77 | 78 | PROC_TITLE = 'asr_server' 79 | 80 | # 81 | # globals 82 | # 83 | # FIXME: get rid of these, implement proper
session management 84 | # 85 | 86 | audiofn = '' # path to current wav file being written 87 | wf = None # current wav file being written 88 | decoder = None # kaldi nnet3 online decoder 89 | 90 | def mkdirs(path): 91 | try: 92 | os.makedirs(path) 93 | except OSError as exception: 94 | if exception.errno != errno.EEXIST: 95 | raise 96 | 97 | class SpeechHandler(BaseHTTPRequestHandler): 98 | 99 | def do_GET(self): 100 | self.send_error(400, 'Invalid request') 101 | 102 | def do_HEAD(self): 103 | self._set_headers() 104 | 105 | def do_POST(self): 106 | 107 | global wf, decoder, vf_login, recordings_dir, audiofn 108 | 109 | logging.debug("POST %s" % self.path) 110 | 111 | if self.path=="/decode": 112 | 113 | data = json.loads(self.rfile.read(int(self.headers.getheader('content-length')))) 114 | 115 | # print data 116 | 117 | audio = data['audio'] 118 | do_record = data['do_record'] 119 | do_asr = data['do_asr'] 120 | do_finalize = data['do_finalize'] 121 | 122 | hstr = '' 123 | confidence = 0.0 124 | 125 | # FIXME: remove audio = map(lambda x: int(x), audios.split(',')) 126 | 127 | if do_record: 128 | 129 | # store recording in WAV format 130 | 131 | if not wf: 132 | 133 | ds = datetime.date.strftime(datetime.date.today(), '%Y%m%d') 134 | audiodirfn = '%s/%s-%s-rec/wav' % (recordings_dir, vf_login, ds) 135 | logging.debug('audiodirfn: %s' % audiodirfn) 136 | mkdirs(audiodirfn) 137 | 138 | cnt = 0 139 | while True: 140 | cnt += 1 141 | audiofn = '%s/de5-%03d.wav' % (audiodirfn, cnt) 142 | if not os.path.isfile(audiofn): 143 | break 144 | 145 | logging.debug('audiofn: %s' % audiofn) 146 | 147 | # create wav file 148 | 149 | wf = wave.open(audiofn, 'wb') 150 | wf.setnchannels(1) 151 | wf.setsampwidth(2) 152 | wf.setframerate(SAMPLE_RATE) 153 | 154 | packed_audio = struct.pack('%sh' % len(audio), *audio) 155 | wf.writeframes(packed_audio) 156 | 157 | if do_finalize: 158 | 159 | wf.close() 160 | wf = None 161 | 162 | else: 163 | audiofn = '' 164 | 165 | if do_asr: 166 | 
decoder.decode(SAMPLE_RATE, np.array(audio, dtype=np.float32), do_finalize) 167 | 168 | if do_finalize: 169 | 170 | hstr, confidence = decoder.get_decoded_string() 171 | 172 | logging.debug ( "*****************************************************************************") 173 | logging.debug ( "**") 174 | logging.debug ( "** %9.5f %s" % (confidence, hstr)) 175 | logging.debug ( "**") 176 | logging.debug ( "*****************************************************************************") 177 | 178 | self.send_response(200) 179 | self.send_header('Content-Type', 'application/json') 180 | self.end_headers() 181 | 182 | reply = {'hstr': hstr, 'confidence': confidence, 'audiofn': audiofn} 183 | 184 | self.wfile.write(json.dumps(reply)) 185 | return 186 | 187 | 188 | if __name__ == '__main__': 189 | 190 | setproctitle (PROC_TITLE) 191 | 192 | # 193 | # commandline 194 | # 195 | 196 | parser = OptionParser("usage: %prog [options] ") 197 | 198 | parser.add_option ("-v", "--verbose", action="store_true", dest="verbose", 199 | help="verbose output") 200 | 201 | parser.add_option ("-H", "--host", dest="host", type = "string", default=DEFAULT_HOST, 202 | help="host, default: %s" % DEFAULT_HOST) 203 | 204 | parser.add_option ("-p", "--port", dest="port", type = "int", default=DEFAULT_PORT, 205 | help="port, default: %d" % DEFAULT_PORT) 206 | 207 | parser.add_option ("-d", "--model-dir", dest="model_dir", type = "string", default=DEFAULT_MODEL_DIR, 208 | help="kaldi model directory, default: %s" % DEFAULT_MODEL_DIR) 209 | 210 | parser.add_option ("-m", "--model", dest="model", type = "string", default=DEFAULT_MODEL, 211 | help="kaldi model, default: %s" % DEFAULT_MODEL) 212 | 213 | parser.add_option ("-r", "--recordings-dir", dest="recordings_dir", type = "string", default=DEFAULT_REC_DIR, 214 | help="wav recordings directory, default: %s" % DEFAULT_REC_DIR) 215 | 216 | parser.add_option ("-l", "--voxforge-login", dest="vf_login", type = "string", default=DEFAULT_VF_LOGIN, 217 | 
help="voxforge login (used in recording filename generation), default: %s" % DEFAULT_VF_LOGIN) 218 | 219 | (options, args) = parser.parse_args() 220 | 221 | if options.verbose: 222 | logging.basicConfig(level=logging.DEBUG) 223 | else: 224 | logging.basicConfig(level=logging.INFO) 225 | 226 | kaldi_model_dir = options.model_dir 227 | kaldi_model = options.model 228 | 229 | vf_login = options.vf_login 230 | recordings_dir = options.recordings_dir 231 | 232 | # 233 | # setup kaldi decoder 234 | # 235 | 236 | start_time = time() 237 | logging.info('%s loading model from %s ...' % (kaldi_model, kaldi_model_dir)) 238 | nnet3_model = KaldiNNet3OnlineModel (kaldi_model_dir, kaldi_model) 239 | logging.info('%s loading model... done. took %fs.' % (kaldi_model, time()-start_time)) 240 | decoder = KaldiNNet3OnlineDecoder (nnet3_model) 241 | 242 | # 243 | # run HTTP server 244 | # 245 | 246 | try: 247 | server = HTTPServer((options.host, options.port), SpeechHandler) 248 | logging.info('listening for HTTP requests on %s:%d' % (options.host, options.port)) 249 | 250 | # wait forever for incoming http requests 251 | server.serve_forever() 252 | 253 | except KeyboardInterrupt: 254 | logging.error('^C received, shutting down the web server') 255 | server.socket.close() 256 | 257 | -------------------------------------------------------------------------------- /examples/chain_incremental.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | # 5 | # Copyright 2016, 2017, 2018 Guenter Bartsch 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 
9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | # KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED 15 | # WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, 16 | # MERCHANTABLITY OR NON-INFRINGEMENT. 17 | # See the Apache 2 License for the specific language governing permissions and 18 | # limitations under the License. 19 | # 20 | # 21 | # slightly more advanced demonstration program for kaldiasr online nnet3 22 | # decoding where we stream audio frames incrementally to the decoder 23 | # 24 | 25 | import sys 26 | import os 27 | import wave 28 | import struct 29 | import numpy as np 30 | 31 | from time import time 32 | 33 | from kaldiasr.nnet3 import KaldiNNet3OnlineModel, KaldiNNet3OnlineDecoder 34 | 35 | # this is useful for benchmarking purposes 36 | NUM_DECODER_RUNS = 1 37 | 38 | MODELDIR = 'data/models/kaldi-generic-en-tdnn_sp-latest' 39 | # MODELDIR = 'data/models/kaldi-generic-de-tdnn_sp-latest' 40 | WAVFILE = 'data/dw961.wav' 41 | # WAVFILE = 'data/gsp1.wav' 42 | 43 | print '%s loading model...' % MODELDIR 44 | time_start = time() 45 | kaldi_model = KaldiNNet3OnlineModel (MODELDIR) 46 | print '%s loading model... done, took %fs.' % (MODELDIR, time()-time_start) 47 | 48 | print '%s creating decoder...' % MODELDIR 49 | time_start = time() 50 | decoder = KaldiNNet3OnlineDecoder (kaldi_model) 51 | print '%s creating decoder... done, took %fs.' % (MODELDIR, time()-time_start) 52 | 53 | for i in range(NUM_DECODER_RUNS): 54 | 55 | time_start = time() 56 | 57 | print 'decoding %s...' 
% WAVFILE 58 | wavf = wave.open(WAVFILE, 'rb') 59 | 60 | # check format 61 | assert wavf.getnchannels()==1 62 | assert wavf.getsampwidth()==2 63 | 64 | # process file in 250ms chunks 65 | 66 | chunk_frames = 250 * wavf.getframerate() / 1000 67 | tot_frames = wavf.getnframes() 68 | 69 | num_frames = 0 70 | while num_frames < tot_frames: 71 | 72 | finalize = False 73 | if (num_frames + chunk_frames) < tot_frames: 74 | nframes = chunk_frames 75 | else: 76 | nframes = tot_frames - num_frames 77 | finalize = True 78 | 79 | frames = wavf.readframes(nframes) 80 | num_frames += nframes 81 | samples = struct.unpack_from('<%dh' % nframes, frames) 82 | 83 | decoder.decode(wavf.getframerate(), np.array(samples, dtype=np.float32), finalize) 84 | 85 | s, l = decoder.get_decoded_string() 86 | 87 | print "%6.3fs: %5d frames (%6.3fs) decoded. %s" % (time()-time_start, num_frames, float(num_frames) / float(wavf.getframerate()), s) 88 | 89 | wavf.close() 90 | 91 | s, l = decoder.get_decoded_string() 92 | print 93 | print "*****************************************************************" 94 | print "**", WAVFILE 95 | print "**", s 96 | print "** %s likelihood:" % MODELDIR, l 97 | print "*****************************************************************" 98 | print 99 | print "%s decoding took %8.2fs" % (MODELDIR, time() - time_start ) 100 | 101 | -------------------------------------------------------------------------------- /examples/chain_live.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | # 5 | # Copyright 2018 Guenter Bartsch 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 
9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 18 | # 19 | 20 | # 21 | # example program for kaldi live nnet3 chain online decoding 22 | # 23 | # configured for embedded systems (e.g. an rpi3) with models 24 | # installed in /opt/kaldi/model/ 25 | # 26 | 27 | import traceback 28 | import logging 29 | import datetime 30 | 31 | from time import time 32 | from nltools import misc 33 | from nltools.pulserecorder import PulseRecorder 34 | from nltools.vad import VAD, BUFFER_DURATION 35 | from nltools.asr import ASR, ASR_ENGINE_NNET3 36 | from optparse import OptionParser 37 | 38 | PROC_TITLE = 'kaldi_live_demo' 39 | 40 | DEFAULT_VOLUME = 150 41 | DEFAULT_AGGRESSIVENESS = 2 42 | 43 | # DEFAULT_MODEL_DIR = '/opt/kaldi/model/kaldi-generic-de-tdnn_250' 44 | DEFAULT_MODEL_DIR = '/opt/kaldi/model/kaldi-generic-en-tdnn_250' 45 | DEFAULT_ACOUSTIC_SCALE = 1.0 46 | DEFAULT_BEAM = 7.0 47 | DEFAULT_FRAME_SUBSAMPLING_FACTOR = 3 48 | 49 | STREAM_ID = 'mic' 50 | 51 | # 52 | # init 53 | # 54 | 55 | misc.init_app(PROC_TITLE) 56 | logging.basicConfig(level=logging.INFO) 57 | 58 | print "Kaldi live demo V0.3" 59 | 60 | # 61 | # cmdline, logging 62 | # 63 | 64 | parser = OptionParser("usage: %prog [options]") 65 | 66 | parser.add_option ("-a", "--aggressiveness", dest="aggressiveness", type = "int", default=DEFAULT_AGGRESSIVENESS, 67 | help="VAD aggressiveness, default: %d" % DEFAULT_AGGRESSIVENESS) 68 | 69 | parser.add_option ("-m", "--model-dir", dest="model_dir", type = "string", default=DEFAULT_MODEL_DIR, 70 | help="kaldi model directory, default: %s" % DEFAULT_MODEL_DIR) 71 | 72 | 
parser.add_option ("-v", "--verbose", action="store_true", dest="verbose", 73 | help="verbose output") 74 | 75 | parser.add_option ("-s", "--source", dest="source", type = "string", default=None, 76 | help="pulseaudio source, default: auto-detect") 77 | 78 | parser.add_option ("-V", "--volume", dest="volume", type = "int", default=DEFAULT_VOLUME, 79 | help="broker port, default: %d" % DEFAULT_VOLUME) 80 | 81 | (options, args) = parser.parse_args() 82 | 83 | if options.verbose: 84 | logging.basicConfig(level=logging.DEBUG) 85 | else: 86 | logging.basicConfig(level=logging.INFO) 87 | 88 | source = options.source 89 | volume = options.volume 90 | aggressiveness = options.aggressiveness 91 | model_dir = options.model_dir 92 | 93 | # 94 | # pulseaudio recorder 95 | # 96 | 97 | rec = PulseRecorder (source_name=source, volume=volume) 98 | 99 | # 100 | # VAD 101 | # 102 | 103 | vad = VAD(aggressiveness=aggressiveness) 104 | 105 | # 106 | # ASR 107 | # 108 | 109 | print "Loading model from %s ..." % model_dir 110 | 111 | asr = ASR(engine = ASR_ENGINE_NNET3, model_dir = model_dir, 112 | kaldi_beam = DEFAULT_BEAM, kaldi_acoustic_scale = DEFAULT_ACOUSTIC_SCALE, 113 | kaldi_frame_subsampling_factor = DEFAULT_FRAME_SUBSAMPLING_FACTOR) 114 | 115 | 116 | # 117 | # main 118 | # 119 | 120 | rec.start_recording() 121 | 122 | print "Please speak." 
123 | 124 | while True: 125 | 126 | samples = rec.get_samples() 127 | 128 | audio, finalize = vad.process_audio(samples) 129 | 130 | if not audio: 131 | continue 132 | 133 | logging.debug ('decoding audio len=%d finalize=%s audio=%s' % (len(audio), repr(finalize), audio[0].__class__)) 134 | 135 | user_utt, confidence = asr.decode(audio, finalize, stream_id=STREAM_ID) 136 | 137 | print "\r%s " % user_utt, 138 | 139 | if finalize: 140 | print 141 | 142 | 143 | -------------------------------------------------------------------------------- /examples/chain_online.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | # 5 | # Copyright 2016, 2017, 2018 Guenter Bartsch 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | # KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED 15 | # WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, 16 | # MERCHANTABLITY OR NON-INFRINGEMENT. 17 | # See the Apache 2 License for the specific language governing permissions and 18 | # limitations under the License. 19 | # 20 | # 21 | # simple demonstration program for kaldiasr online nnet3-chain decoding 22 | # 23 | 24 | import sys 25 | import os 26 | import wave 27 | import struct 28 | import numpy as np 29 | 30 | from time import time 31 | 32 | from kaldiasr.nnet3 import KaldiNNet3OnlineModel, KaldiNNet3OnlineDecoder 33 | 34 | # MODELDIR = 'data/models/kaldi-generic-en-tdnn_sp-latest' 35 | MODELDIR = 'data/models/kaldi-generic-de-tdnn_sp-latest' 36 | WAVFILES = [ 'data/single.wav', 'data/gsp1.wav'] 37 | 38 | print '%s loading model...' 
% MODELDIR 39 | kaldi_model = KaldiNNet3OnlineModel (MODELDIR) 40 | print '%s loading model... done.' % MODELDIR 41 | 42 | decoder = KaldiNNet3OnlineDecoder (kaldi_model) 43 | 44 | for WAVFILE in WAVFILES: 45 | 46 | print 'decoding %s...' % WAVFILE 47 | time_start = time() 48 | if decoder.decode_wav_file(WAVFILE): 49 | print '%s decoding worked!' % MODELDIR 50 | 51 | s,l = decoder.get_decoded_string() 52 | print 53 | print "*****************************************************************" 54 | print "**", WAVFILE 55 | print "**", s 56 | print "** %s likelihood:" % MODELDIR, l 57 | 58 | time_scale = 0.01 59 | words, times, lengths = decoder.get_word_alignment() 60 | 61 | print "** word alignment: :" 62 | for i, word in enumerate(words): 63 | print '** %f\t%f\t%s' % (time_scale * float(times[i]), time_scale*float(times[i] + lengths[i]), word) 64 | 65 | print "*****************************************************************" 66 | print 67 | 68 | else: 69 | print '%s decoding did not work :(' % MODELDIR 70 | 71 | print "%s decoding took %8.2fs" % (MODELDIR, time() - time_start ) 72 | 73 | -------------------------------------------------------------------------------- /examples/chain_wavfile.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | # 5 | # Copyright 2016, 2017, 2018 Guenter Bartsch 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | # KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED 15 | # WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, 16 | # MERCHANTABLITY OR NON-INFRINGEMENT. 
17 | # See the Apache 2 License for the specific language governing permissions and 18 | # limitations under the License. 19 | # 20 | 21 | # 22 | # very simple single WAV file speech recognition (decoding) example 23 | # 24 | 25 | from kaldiasr.nnet3 import KaldiNNet3OnlineModel, KaldiNNet3OnlineDecoder 26 | 27 | MODELDIR = 'data/models/kaldi-generic-en-tdnn_sp-latest' 28 | # MODELDIR = 'data/models/kaldi-generic-de-tdnn_sp-latest' 29 | WAVFILE = 'data/dw961.wav' 30 | # WAVFILE = 'data/fail1.wav' 31 | # WAVFILE = 'data/gsp1.wav' 32 | 33 | print "Loading model from %s ..." % MODELDIR 34 | 35 | kaldi_model = KaldiNNet3OnlineModel (MODELDIR) 36 | decoder = KaldiNNet3OnlineDecoder (kaldi_model) 37 | 38 | print "Decoding %s ..." % WAVFILE 39 | 40 | if decoder.decode_wav_file(WAVFILE): 41 | 42 | s, l = decoder.get_decoded_string() 43 | 44 | print 45 | print u"*****************************************************************" 46 | print u"** %s" % WAVFILE 47 | print u"** %s" % s 48 | print u"** likelihood: %f" % l 49 | print u"*****************************************************************" 50 | print 51 | 52 | else: 53 | 54 | print "***ERROR: decoding of %s failed." % WAVFILE 55 | 56 | -------------------------------------------------------------------------------- /examples/chain_wavfile3.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # 5 | # Copyright 2016, 2017, 2018 Guenter Bartsch 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 
9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | # KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED 15 | # WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, 16 | # MERCHANTABLITY OR NON-INFRINGEMENT. 17 | # See the Apache 2 License for the specific language governing permissions and 18 | # limitations under the License. 19 | # 20 | 21 | # 22 | # very simple single WAV file speech recognition (decoding) example 23 | # 24 | # Python 3 version 25 | # 26 | 27 | from kaldiasr.nnet3 import KaldiNNet3OnlineModel, KaldiNNet3OnlineDecoder 28 | 29 | MODELDIR = 'data/models/kaldi-generic-en-tdnn_sp-latest' 30 | # MODELDIR = 'data/models/kaldi-generic-de-tdnn_sp-latest' 31 | WAVFILE = 'data/dw961.wav' 32 | # WAVFILE = 'data/fail1.wav' 33 | # WAVFILE = 'data/gsp1.wav' 34 | 35 | kaldi_model = KaldiNNet3OnlineModel (MODELDIR, acoustic_scale=1.0, beam=7.0, frame_subsampling_factor=3) 36 | decoder = KaldiNNet3OnlineDecoder (kaldi_model) 37 | 38 | if decoder.decode_wav_file(WAVFILE): 39 | 40 | s, l = decoder.get_decoded_string() 41 | 42 | print() 43 | print("*****************************************************************") 44 | print("** %s" % WAVFILE) 45 | print("** %s" % s) 46 | print("** %s likelihood: %f" % (MODELDIR, l)) 47 | print("*****************************************************************") 48 | print() 49 | 50 | else: 51 | 52 | print("***ERROR: decoding of %s failed." % WAVFILE) 53 | 54 | -------------------------------------------------------------------------------- /examples/gmm_incremental.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | # 5 | # Author: David Zurow, adapted from G. 
Bartsch 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | # KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED 15 | # WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, 16 | # MERCHANTABLITY OR NON-INFRINGEMENT. 17 | # See the Apache 2 License for the specific language governing permissions and 18 | # limitations under the License. 19 | # 20 | # 21 | # slightly more advanced demonstration program for kaldiasr online gmm 22 | # decoding where we stream audio frames incrementally to the decoder 23 | # 24 | 25 | from __future__ import print_function 26 | 27 | import sys 28 | import os 29 | import wave 30 | import struct 31 | import numpy as np 32 | 33 | from time import time 34 | 35 | from kaldiasr.gmm import KaldiGmmOnlineModel, KaldiGmmOnlineDecoder 36 | 37 | # this is useful for benchmarking purposes 38 | NUM_DECODER_RUNS = 1 39 | 40 | # ../../training/kaldi_tmp/exp/ 41 | # |-- tri3b 42 | # | |-- graph 43 | # | | |-- HCLG.fst 44 | # | | |-- disambig_tid.int 45 | # | | |-- num_pdfs 46 | # | | |-- phones 47 | # | | | |-- align_lexicon.int 48 | # | | | |-- align_lexicon.txt 49 | # | | | |-- disambig.int 50 | # | | | |-- disambig.txt 51 | # | | | |-- optional_silence.csl 52 | # | | | |-- optional_silence.int 53 | # | | | |-- optional_silence.txt 54 | # | | | |-- silence.csl 55 | # | | | |-- word_boundary.int 56 | # | | | `-- word_boundary.txt 57 | # | | |-- phones.txt 58 | # | | `-- words.txt 59 | # |-- tri3b_mmi_online 60 | # | |-- cmvn_opts 61 | # | |-- conf 62 | # | | |-- mfcc.conf 63 | # | | |-- online_cmvn.conf 64 | # | | |-- online_decoding.conf 65 | # | | `-- splice.conf 66 | # | |-- final.mat 67 | # | |-- final.mdl 68 | # | |-- 
final.oalimdl 69 | # | |-- final.rescore_mdl 70 | # | |-- fmllr.basis 71 | # | |-- global_cmvn.stats 72 | # | |-- phones.txt 73 | # | `-- splice_opts 74 | 75 | MODELDIR = '../../training/kaldi_tmp/exp/tri3b_mmi_online' 76 | GRAPHDIR = '../../training/kaldi_tmp/exp/tri3b' 77 | WAVFILE = 'data/dw961.wav' 78 | 79 | print('%s loading model...' % MODELDIR) 80 | time_start = time() 81 | kaldi_model = KaldiGmmOnlineModel (MODELDIR, GRAPHDIR) 82 | print('%s loading model... done, took %fs.' % (MODELDIR, time()-time_start)) 83 | 84 | print('%s creating decoder...' % MODELDIR) 85 | time_start = time() 86 | decoder = KaldiGmmOnlineDecoder (kaldi_model) 87 | print('%s creating decoder... done, took %fs.' % (MODELDIR, time()-time_start)) 88 | 89 | for i in range(NUM_DECODER_RUNS): 90 | 91 | time_start = time() 92 | 93 | print('decoding %s...' % WAVFILE) 94 | wavf = wave.open(WAVFILE, 'rb') 95 | 96 | # check format 97 | assert wavf.getnchannels()==1 98 | assert wavf.getsampwidth()==2 99 | 100 | # process file in 250ms chunks 101 | 102 | chunk_frames = int(250 * wavf.getframerate() / 1000) 103 | tot_frames = wavf.getnframes() 104 | 105 | num_frames = 0 106 | while num_frames < tot_frames: 107 | 108 | finalize = False 109 | if (num_frames + chunk_frames) < tot_frames: 110 | nframes = chunk_frames 111 | else: 112 | nframes = tot_frames - num_frames 113 | finalize = True 114 | 115 | frames = wavf.readframes(nframes) 116 | num_frames += nframes 117 | samples = struct.unpack_from('<%dh' % nframes, frames) 118 | 119 | decoder.decode(wavf.getframerate(), np.array(samples, dtype=np.float32), finalize) 120 | 121 | s, l = decoder.get_decoded_string() 122 | 123 | print("%6.3fs: %5d frames (%6.3fs) decoded. 
%s" % (time()-time_start, num_frames, float(num_frames) / float(wavf.getframerate()), s)) 124 | 125 | wavf.close() 126 | 127 | s, l = decoder.get_decoded_string() 128 | print() 129 | print("*****************************************************************") 130 | print("**", WAVFILE) 131 | print("**", s) 132 | print("** %s likelihood:" % MODELDIR, l) 133 | 134 | time_scale = 0.01 135 | words, times, lengths = decoder.get_word_alignment() 136 | print("** word alignment: :") 137 | for i, word in enumerate(words): 138 | print('** %f\t%f\t%s' % (time_scale * float(times[i]), time_scale*float(times[i] + lengths[i]), word)) 139 | 140 | print("*****************************************************************") 141 | print() 142 | print("%s decoding took %8.2fs" % (MODELDIR, time() - time_start )) 143 | -------------------------------------------------------------------------------- /kaldiasr/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gooofy/py-kaldi-asr/331c725d45beff96a789b88a2d690f883379639d/kaldiasr/__init__.py -------------------------------------------------------------------------------- /kaldiasr/gmm.pyx: -------------------------------------------------------------------------------- 1 | # distutils: language = c++ 2 | # distutils: sources = gmm.cpp 3 | 4 | # 5 | # Author: David Zurow, adapted from G. Bartsch 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | # KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED 15 | # WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, 16 | # MERCHANTABLITY OR NON-INFRINGEMENT. 
17 | # See the Apache 2 License for the specific language governing permissions and 18 | # limitations under the License. 19 | # 20 | 21 | import cython 22 | from libcpp.string cimport string 23 | from libcpp.vector cimport vector 24 | import numpy as np 25 | cimport numpy as cnp 26 | import struct 27 | import wave 28 | import os, os.path 29 | import re 30 | from tempfile import NamedTemporaryFile 31 | import subprocess 32 | from cpython.version cimport PY_MAJOR_VERSION 33 | 34 | cdef unicode _text(s): 35 | if type(s) is unicode: 36 | # Fast path for most common case(s). 37 | return s 38 | 39 | elif PY_MAJOR_VERSION < 3 and isinstance(s, bytes): 40 | # Only accept byte strings as text input in Python 2.x, not in Py3. 41 | return (s).decode('utf8') 42 | 43 | elif isinstance(s, unicode): 44 | # We know from the fast path above that 's' can only be a subtype here. 45 | # An evil cast to might still work in some(!) cases, 46 | # depending on what the further processing does. To be safe, 47 | # we can always create a copy instead. 
48 | return unicode(s) 49 | 50 | else: 51 | raise TypeError("Could not convert to unicode.") 52 | 53 | cdef extern from "gmm_wrappers.h" namespace "kaldi": 54 | 55 | cdef cppclass GmmOnlineModelWrapper: 56 | GmmOnlineModelWrapper() except + 57 | GmmOnlineModelWrapper(float, int, int, float, string, string, string, string) except + 58 | 59 | cdef cppclass GmmOnlineDecoderWrapper: 60 | GmmOnlineDecoderWrapper() except + 61 | GmmOnlineDecoderWrapper(GmmOnlineModelWrapper *) except + 62 | 63 | bint decode(float, int, float *, bint) except + 64 | # must mirror the C++ definition: get_decoded_string(std::string &, double &) 65 | void get_decoded_string(string &, double &) except + 66 | bint get_word_alignment(vector[string] &, vector[int] &, vector[int] &) except + 67 | 68 | cdef class KaldiGmmOnlineModel: 69 | 70 | cdef GmmOnlineModelWrapper* model_wrapper 71 | cdef unicode model_dir, graph_dir 72 | cdef object conf_file 73 | 74 | def __cinit__(self, object model_dir, 75 | object graph_dir, 76 | float beam = 7.0, # nnet3: 15.0 77 | int max_active = 7000, 78 | int min_active = 200, 79 | float lattice_beam = 8.0): 80 | 81 | self.model_dir = _text(model_dir) 82 | self.graph_dir = _text(graph_dir) 83 | 84 | cdef unicode config = u'%s/conf/online_decoding.conf' % self.model_dir 85 | cdef unicode word_symbol_table = u'%s/graph/words.txt' % self.graph_dir 86 | cdef unicode fst_in_str = u'%s/graph/HCLG.fst' % self.graph_dir 87 | cdef unicode align_lex_filename = u'%s/graph/phones/align_lexicon.int' % self.graph_dir 88 | 89 | # 90 | # make sure all model files required exist 91 | # 92 | 93 | for filename in [config, word_symbol_table, fst_in_str, align_lex_filename]: 94 | if not os.path.isfile(filename.encode('utf8')): 95 | raise Exception ('%s not found.'
% filename) 96 | if not os.access(filename.encode('utf8'), os.R_OK): 97 | raise Exception ('%s is not readable' % filename) 98 | 99 | # 100 | # generate .conf file from existing one, modifying paths 101 | # 102 | 103 | self.conf_file = NamedTemporaryFile(prefix=u'py_online_decoding_', suffix=u'.conf', delete=True) 104 | # print(self.conf_file.name) 105 | with open(config) as file: 106 | for line in file: 107 | # modify any path, then write 108 | line = re.sub(r'=(.*/.*)', 109 | lambda match: '=' + os.path.join(self.model_dir, '..', '..', match.group(1)), 110 | line) 111 | self.conf_file.write(line.encode('utf8')) 112 | self.conf_file.flush() 113 | # subprocess.run('cat ' + self.conf_file.name, shell=True) 114 | 115 | # 116 | # instantiate our C++ wrapper class 117 | # 118 | 119 | self.model_wrapper = new GmmOnlineModelWrapper(beam, 120 | max_active, 121 | min_active, 122 | lattice_beam, 123 | word_symbol_table.encode('utf8'), 124 | fst_in_str.encode('utf8'), 125 | self.conf_file.name.encode('utf8'), 126 | align_lex_filename.encode('utf8')) 127 | 128 | def __dealloc__(self): 129 | if self.conf_file: 130 | self.conf_file.close() 131 | if self.model_wrapper: 132 | del self.model_wrapper 133 | 134 | cdef class KaldiGmmOnlineDecoder: 135 | 136 | cdef GmmOnlineDecoderWrapper* decoder_wrapper 137 | cdef KaldiGmmOnlineModel model # strong ref: the C++ decoder keeps a raw pointer into model.model_wrapper 138 | def __cinit__(self, KaldiGmmOnlineModel model not None): 139 | self.model = model # keep the model (and its C++ wrapper) alive at least as long as this decoder 140 | # 141 | # instantiate our C++ wrapper class 142 | # 143 | 144 | self.decoder_wrapper = new GmmOnlineDecoderWrapper(model.model_wrapper) 145 | 146 | def __dealloc__(self): 147 | del self.decoder_wrapper 148 | 149 | def decode(self, samp_freq, cnp.ndarray[float, ndim=1, mode="c"] samples not None, finalize): 150 | return self.decoder_wrapper.decode(samp_freq, samples.shape[0], <float *> samples.data, finalize) 151 | 152 | def get_decoded_string(self): 153 | cdef string decoded_string 154 | cdef double likelihood=0.0 155 | self.decoder_wrapper.get_decoded_string(decoded_string, likelihood) 156 | return
decoded_string.decode('utf8'), likelihood 157 | 158 | def get_word_alignment(self): 159 | cdef vector[string] words 160 | cdef vector[int] times 161 | cdef vector[int] lengths 162 | if not self.decoder_wrapper.get_word_alignment(words, times, lengths): 163 | return None 164 | return words, times, lengths 165 | 166 | # 167 | # various convenience functions below 168 | # 169 | 170 | def decode_wav_file(self, object wavfile): 171 | 172 | wavf = wave.open(wavfile, 'rb') 173 | 174 | # check format 175 | assert wavf.getnchannels()==1 176 | assert wavf.getsampwidth()==2 177 | assert wavf.getnframes()>0 178 | 179 | # read the whole file into memory, for now 180 | num_frames = wavf.getnframes() 181 | frames = wavf.readframes(num_frames) 182 | 183 | samples = struct.unpack_from('<%dh' % num_frames, frames) 184 | 185 | wavf.close() 186 | 187 | return self.decode(wavf.getframerate(), np.array(samples, dtype=np.float32), True) 188 | 189 | -------------------------------------------------------------------------------- /kaldiasr/gmm_wrappers.cpp: -------------------------------------------------------------------------------- 1 | // gmm_wrappers.cpp 2 | // 3 | // Author: David Zurow, adapted from G. Bartsch 4 | // 5 | // based on Kaldi's decoder/decoder-wrappers.cc 6 | 7 | // Copyright 2014 Johns Hopkins University (author: Daniel Povey) 8 | 9 | // See ../../COPYING for clarification regarding multiple authors 10 | // 11 | // Licensed under the Apache License, Version 2.0 (the "License"); 12 | // you may not use this file except in compliance with the License. 13 | // You may obtain a copy of the License at 14 | // 15 | // http://www.apache.org/licenses/LICENSE-2.0 16 | // 17 | // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 18 | // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED 19 | // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, 20 | // MERCHANTABLITY OR NON-INFRINGEMENT. 
21 | // See the Apache 2 License for the specific language governing permissions and 22 | // limitations under the License. 23 | // 24 | 25 | #include "gmm_wrappers.h" 26 | 27 | #include "feat/wave-reader.h" 28 | #include "online2/online-feature-pipeline.h" 29 | #include "online2/online-gmm-decoding.h" 30 | #include "online2/onlinebin-util.h" 31 | #include "online2/online-timing.h" 32 | #include "online2/online-endpoint.h" 33 | #include "fstext/fstext-lib.h" 34 | #include "lat/lattice-functions.h" 35 | #include "lat/word-align-lattice-lexicon.h" 36 | 37 | #define VERBOSE 0 38 | 39 | namespace kaldi { 40 | 41 | /* 42 | * GmmOnlineDecoderWrapper 43 | */ 44 | 45 | GmmOnlineDecoderWrapper::GmmOnlineDecoderWrapper(GmmOnlineModelWrapper *aModel) : model(aModel) { 46 | decoder = NULL; 47 | adaptation_state = NULL; 48 | 49 | tot_frames = 0; 50 | tot_frames_decoded = 0; 51 | 52 | 53 | } 54 | 55 | GmmOnlineDecoderWrapper::~GmmOnlineDecoderWrapper() { 56 | free_decoder(); 57 | } 58 | 59 | void GmmOnlineDecoderWrapper::start_decoding(void) { 60 | #if VERBOSE 61 | KALDI_LOG << "start_decoding..." 
; 62 | KALDI_LOG << "max_active :" << model->decode_config.faster_decoder_opts.max_active; 63 | KALDI_LOG << "min_active :" << model->decode_config.faster_decoder_opts.min_active; 64 | KALDI_LOG << "beam :" << model->decode_config.faster_decoder_opts.beam; 65 | KALDI_LOG << "lattice_beam:" << model->decode_config.faster_decoder_opts.lattice_beam; 66 | #endif 67 | free_decoder(); 68 | #if VERBOSE 69 | KALDI_LOG << "alloc: OnlineGmmAdaptationState"; 70 | #endif 71 | adaptation_state = new OnlineGmmAdaptationState (); 72 | #if VERBOSE 73 | KALDI_LOG << "alloc: SingleUtteranceGmmDecoder"; 74 | #endif 75 | decoder = new SingleUtteranceGmmDecoder (model->decode_config, 76 | *model->gmm_models, 77 | *model->feature_pipeline_prototype, 78 | *model->decode_fst,//ok 79 | *adaptation_state); 80 | #if VERBOSE 81 | KALDI_LOG << "start_decoding...done" ; 82 | #endif 83 | } 84 | 85 | void GmmOnlineDecoderWrapper::free_decoder(void) { 86 | if (decoder) { 87 | #if VERBOSE 88 | KALDI_LOG << "free_decoder"; 89 | #endif 90 | delete decoder; 91 | decoder = NULL; 92 | } 93 | if (adaptation_state) { 94 | delete adaptation_state; 95 | adaptation_state = NULL; 96 | } 97 | } 98 | 99 | void GmmOnlineDecoderWrapper::get_decoded_string(std::string &decoded_string, double &likelihood) { 100 | 101 | //std::string decoded_string; 102 | //double likelihood; 103 | 104 | Lattice best_path_lat; 105 | 106 | decoded_string = ""; 107 | 108 | if (decoder) { 109 | 110 | // decoding is not finished yet, so we will look up the best partial result so far 111 | 112 | // if (decoder->NumFramesDecoded() == 0) { 113 | // likelihood = 0.0; 114 | // return; 115 | // } 116 | 117 | decoder->GetBestPath(false, &best_path_lat); 118 | 119 | } else { 120 | ConvertLattice(best_path_clat, &best_path_lat); 121 | } 122 | 123 | std::vector words; 124 | std::vector alignment; 125 | LatticeWeight weight; 126 | int32 num_frames; 127 | GetLinearSymbolSequence(best_path_lat, &alignment, &words, &weight); 128 | num_frames = 
alignment.size(); 129 | likelihood = -(weight.Value1() + weight.Value2()) / num_frames; 130 | 131 | for (size_t i = 0; i < words.size(); i++) { 132 | std::string s = model->word_syms->Find(words[i]); 133 | if (s == "") 134 | KALDI_ERR << "Word-id " << words[i] << " not in symbol table."; 135 | decoded_string += s + ' '; 136 | } 137 | } 138 | 139 | bool GmmOnlineDecoderWrapper::get_word_alignment(std::vector &words, 140 | std::vector ×, 141 | std::vector &lengths) { 142 | 143 | WordAlignLatticeLexiconInfo lexicon_info(model->word_alignment_lexicon); 144 | 145 | #if VERBOSE 146 | KALDI_LOG << "word alignment starts..."; 147 | #endif 148 | CompactLattice aligned_clat; 149 | WordAlignLatticeLexiconOpts opts; 150 | 151 | bool ok = WordAlignLatticeLexicon(best_path_clat, model->gmm_models->GetTransitionModel(), lexicon_info, opts, &aligned_clat); 152 | 153 | if (!ok) { 154 | KALDI_WARN << "Lattice did not align correctly"; 155 | return false; 156 | } else { 157 | if (aligned_clat.Start() == fst::kNoStateId) { 158 | KALDI_WARN << "Lattice was empty"; 159 | return false; 160 | } else { 161 | #if VERBOSE 162 | KALDI_LOG << "Aligned lattice."; 163 | #endif 164 | TopSortCompactLatticeIfNeeded(&aligned_clat); 165 | 166 | // lattice-1best 167 | 168 | CompactLattice best_path_aligned; 169 | CompactLatticeShortestPath(aligned_clat, &best_path_aligned); 170 | 171 | // nbest-to-ctm 172 | 173 | std::vector word_idxs; 174 | if (!CompactLatticeToWordAlignment(best_path_aligned, &word_idxs, ×, &lengths)) { 175 | KALDI_WARN << "CompactLatticeToWordAlignment failed."; 176 | return false; 177 | } 178 | 179 | // lexicon lookup 180 | words.clear(); 181 | for (size_t i = 0; i < word_idxs.size(); i++) { 182 | std::string s = model->word_syms->Find(word_idxs[i]); 183 | if (s == "") { 184 | KALDI_ERR << "Word-id " << word_idxs[i] << " not in symbol table."; 185 | } 186 | words.push_back(s); 187 | } 188 | } 189 | } 190 | return true; 191 | } 192 | 193 | 194 | 195 | bool 
GmmOnlineDecoderWrapper::decode(BaseFloat samp_freq, int32 num_frames, BaseFloat *frames, bool finalize) { 196 | 197 | using fst::VectorFst; 198 | 199 | if (!decoder) { 200 | start_decoding(); 201 | } 202 | 203 | Vector wave_part(num_frames, kUndefined); 204 | for (int i=0; iFeaturePipeline().AcceptWaveform(samp_freq, wave_part); 213 | 214 | if (finalize) { 215 | // no more input. flush out last frames 216 | decoder->FeaturePipeline().InputFinished(); 217 | } 218 | 219 | decoder->AdvanceDecoding(); 220 | 221 | if (finalize) { 222 | decoder->FinalizeDecoding(); 223 | 224 | CompactLattice clat; 225 | bool end_of_utterance = true; 226 | decoder->EstimateFmllr(end_of_utterance); 227 | bool rescore_if_needed = true; 228 | decoder->GetLattice(rescore_if_needed, end_of_utterance, &clat); 229 | 230 | if (clat.NumStates() == 0) { 231 | KALDI_WARN << "Empty lattice."; 232 | return false; 233 | } 234 | 235 | CompactLatticeShortestPath(clat, &best_path_clat); 236 | 237 | tot_frames_decoded = tot_frames; 238 | tot_frames = 0; 239 | 240 | free_decoder(); 241 | 242 | } 243 | 244 | return true; 245 | } 246 | 247 | 248 | /* 249 | * GmmOnlineModelWrapper 250 | */ 251 | 252 | // typedef void (*LogHandler)(const LogMessageEnvelope &envelope, 253 | // const char *message); 254 | void silent_log_handler (const LogMessageEnvelope &envelope, 255 | const char *message) { 256 | // nothing - this handler simply keeps silent 257 | } 258 | 259 | GmmOnlineModelWrapper::GmmOnlineModelWrapper(BaseFloat beam, 260 | int32 max_active, 261 | int32 min_active, 262 | BaseFloat lattice_beam, 263 | std::string &word_syms_filename, 264 | std::string &fst_in_str, 265 | std::string &config, 266 | std::string &align_lex_filename) 267 | 268 | { 269 | 270 | using namespace kaldi; 271 | using namespace fst; 272 | 273 | typedef kaldi::int32 int32; 274 | typedef kaldi::int64 int64; 275 | 276 | #if VERBOSE 277 | KALDI_LOG << "fst_in_str: " << fst_in_str; 278 | KALDI_LOG << "config: " << config; 279 | KALDI_LOG << 
"align_lex_filename: " << align_lex_filename; 280 | #else 281 | // silence kaldi output as well 282 | SetLogHandler(silent_log_handler); 283 | #endif 284 | 285 | ParseOptions po(""); 286 | feature_cmdline_config.Register(&po); 287 | decode_config.Register(&po); 288 | endpoint_config.Register(&po); 289 | po.ReadConfigFile(config); 290 | 291 | decode_config.faster_decoder_opts.max_active = max_active; 292 | decode_config.faster_decoder_opts.min_active = min_active; 293 | decode_config.faster_decoder_opts.beam = beam; 294 | decode_config.faster_decoder_opts.lattice_beam = lattice_beam; 295 | 296 | feature_config = new OnlineFeaturePipelineConfig(feature_cmdline_config); 297 | feature_pipeline_prototype = new OnlineFeaturePipeline(*this->feature_config); 298 | 299 | // load model... 300 | gmm_models = new OnlineGmmDecodingModels(decode_config); 301 | 302 | // Input FST is just one FST, not a table of FSTs. 303 | decode_fst = fst::ReadFstKaldiGeneric(fst_in_str); 304 | 305 | word_syms = NULL; 306 | if (word_syms_filename != "") 307 | if (!(word_syms = fst::SymbolTable::ReadText(word_syms_filename))) 308 | KALDI_ERR << "Could not read symbol table from file " 309 | << word_syms_filename; 310 | 311 | #if VERBOSE 312 | KALDI_LOG << "loading word alignment lexicon..."; 313 | #endif 314 | { 315 | bool binary_in; 316 | Input ki(align_lex_filename, &binary_in); 317 | KALDI_ASSERT(!binary_in && "Not expecting binary file for lexicon"); 318 | if (!ReadLexiconForWordAlign(ki.Stream(), &word_alignment_lexicon)) { 319 | KALDI_ERR << "Error reading alignment lexicon from " 320 | << align_lex_filename; 321 | } 322 | } 323 | } 324 | 325 | GmmOnlineModelWrapper::~GmmOnlineModelWrapper() { 326 | delete feature_config; 327 | delete feature_pipeline_prototype; 328 | delete gmm_models; 329 | } 330 | 331 | } 332 | 333 | -------------------------------------------------------------------------------- /kaldiasr/gmm_wrappers.h: 
-------------------------------------------------------------------------------- 1 | // gmm_wrappers.h 2 | // 3 | // Author: David Zurow, adapted from G. Bartsch 4 | // 5 | // based on Kaldi's decoder/decoder-wrappers.cc 6 | 7 | // Copyright 2014 Johns Hopkins University (author: Daniel Povey) 8 | 9 | // See ../../COPYING for clarification regarding multiple authors 10 | // 11 | // Licensed under the Apache License, Version 2.0 (the "License"); 12 | // you may not use this file except in compliance with the License. 13 | // You may obtain a copy of the License at 14 | // 15 | // http://www.apache.org/licenses/LICENSE-2.0 16 | // 17 | // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 18 | // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED 19 | // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, 20 | // MERCHANTABLITY OR NON-INFRINGEMENT. 21 | // See the Apache 2 License for the specific language governing permissions and 22 | // limitations under the License. 
23 | // 24 | 25 | #include "feat/wave-reader.h" 26 | #include "online2/online-feature-pipeline.h" 27 | #include "online2/online-gmm-decoding.h" 28 | #include "online2/onlinebin-util.h" 29 | #include "online2/online-timing.h" 30 | #include "online2/online-endpoint.h" 31 | #include "fstext/fstext-lib.h" 32 | #include "lat/lattice-functions.h" 33 | #include "lat/word-align-lattice-lexicon.h" 34 | 35 | 36 | namespace kaldi { 37 | class GmmOnlineModelWrapper { 38 | friend class GmmOnlineDecoderWrapper; 39 | public: 40 | 41 | GmmOnlineModelWrapper(BaseFloat beam, 42 | int32 max_active, 43 | int32 min_active, 44 | BaseFloat lattice_beam, 45 | std::string &word_syms_filename, 46 | std::string &fst_in_str, 47 | std::string &config, 48 | std::string &align_lex_filename 49 | ); 50 | ~GmmOnlineModelWrapper(); 51 | 52 | private: 53 | 54 | fst::SymbolTable *word_syms; 55 | 56 | OnlineGmmDecodingConfig decode_config; 57 | 58 | OnlineFeaturePipelineCommandLineConfig feature_cmdline_config; 59 | OnlineFeaturePipelineConfig *feature_config; 60 | OnlineFeaturePipeline *feature_pipeline_prototype; 61 | OnlineEndpointConfig endpoint_config; 62 | 63 | OnlineGmmDecodingModels *gmm_models; 64 | fst::Fst *decode_fst; 65 | 66 | // word alignment: 67 | std::vector > word_alignment_lexicon; 68 | }; 69 | 70 | class GmmOnlineDecoderWrapper { 71 | public: 72 | 73 | GmmOnlineDecoderWrapper(GmmOnlineModelWrapper *aModel); 74 | ~GmmOnlineDecoderWrapper(); 75 | 76 | bool decode(BaseFloat samp_freq, 77 | int32 num_frames, 78 | BaseFloat *frames, 79 | bool finalize); 80 | 81 | void get_decoded_string(std::string &decoded_string, 82 | double &likelihood); 83 | bool get_word_alignment(std::vector &words, 84 | std::vector ×, 85 | std::vector &lengths); 86 | 87 | private: 88 | 89 | void start_decoding(void); 90 | void free_decoder(void); 91 | 92 | GmmOnlineModelWrapper *model; 93 | 94 | OnlineGmmAdaptationState *adaptation_state; 95 | SingleUtteranceGmmDecoder *decoder; 96 | 97 | int32 tot_frames, 
tot_frames_decoded; 98 | 99 | // decoding result: 100 | CompactLattice best_path_clat; 101 | 102 | }; 103 | 104 | 105 | 106 | } 107 | 108 | -------------------------------------------------------------------------------- /kaldiasr/nnet3.pyx: -------------------------------------------------------------------------------- 1 | # distutils: language = c++ 2 | # distutils: sources = nnet3.cpp 3 | 4 | # 5 | # Copyright 2016, 2017, 2018 G. Bartsch 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | # KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED 15 | # WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, 16 | # MERCHANTABLITY OR NON-INFRINGEMENT. 17 | # See the Apache 2 License for the specific language governing permissions and 18 | # limitations under the License. 19 | # 20 | 21 | import cython 22 | from libcpp.string cimport string 23 | from libcpp.vector cimport vector 24 | import numpy as np 25 | cimport numpy as cnp 26 | import struct 27 | import wave 28 | import os 29 | from tempfile import NamedTemporaryFile 30 | from cpython.version cimport PY_MAJOR_VERSION 31 | 32 | cdef unicode _text(s): 33 | if type(s) is unicode: 34 | # Fast path for most common case(s). 35 | return s 36 | 37 | elif PY_MAJOR_VERSION < 3 and isinstance(s, bytes): 38 | # Only accept byte strings as text input in Python 2.x, not in Py3. 39 | return (s).decode('utf8') 40 | 41 | elif isinstance(s, unicode): 42 | # We know from the fast path above that 's' can only be a subtype here. 43 | # An evil cast to might still work in some(!) cases, 44 | # depending on what the further processing does. To be safe, 45 | # we can always create a copy instead. 
46 | return unicode(s) 47 | 48 | else: 49 | raise TypeError("Could not convert to unicode.") 50 | 51 | cdef extern from "nnet3_wrappers.h" namespace "kaldi": 52 | 53 | cdef cppclass NNet3OnlineModelWrapper: 54 | NNet3OnlineModelWrapper() except + 55 | NNet3OnlineModelWrapper(float, int, int, float, float, int, string, string, string, string, string, string) except + 56 | 57 | cdef cppclass NNet3OnlineDecoderWrapper: 58 | NNet3OnlineDecoderWrapper() except + 59 | NNet3OnlineDecoderWrapper(NNet3OnlineModelWrapper *) except + 60 | 61 | bint decode(float, int, float *, bint) except + 62 | 63 | void get_decoded_string(string &, float &) except + 64 | bint get_word_alignment(vector[string] &, vector[int] &, vector[int] &) except + 65 | 66 | cdef class KaldiNNet3OnlineModel: 67 | 68 | cdef NNet3OnlineModelWrapper* model_wrapper 69 | cdef unicode modeldir, model 70 | cdef object ie_conf_f 71 | 72 | def __cinit__(self, object modeldir, 73 | object model = None, 74 | float beam = 7.0, # nnet3: 15.0 75 | int max_active = 7000, 76 | int min_active = 200, 77 | float lattice_beam = 8.0, 78 | float acoustic_scale = 1.0, # nnet3: 0.1 79 | int frame_subsampling_factor = 3, # neet3: 1 80 | 81 | int num_gselect = 5, 82 | float min_post = 0.025, 83 | float posterior_scale = 0.1, 84 | int max_count = 0, 85 | int online_ivector_period = 10): 86 | 87 | self.modeldir = _text(modeldir) 88 | if model is None: 89 | self.model = _text('model') 90 | else: 91 | self.model = _text(model) 92 | 93 | cdef unicode mfcc_config = u'%s/conf/mfcc_hires.conf' % self.modeldir 94 | cdef unicode word_symbol_table = u'%s/%s/graph/words.txt' % (self.modeldir, self.model) 95 | cdef unicode model_in_filename = u'%s/%s/final.mdl' % (self.modeldir, self.model) 96 | cdef unicode splice_conf_filename = u'%s/ivectors_test_hires/conf/splice.conf' % self.modeldir 97 | cdef unicode fst_in_str = u'%s/%s/graph/HCLG.fst' % (self.modeldir, self.model) 98 | cdef unicode align_lex_filename = 
u'%s/%s/graph/phones/align_lexicon.int' % (self.modeldir, self.model) 99 | 100 | # 101 | # make sure all model files required exist 102 | # 103 | 104 | for conff in [mfcc_config, word_symbol_table, model_in_filename, splice_conf_filename, fst_in_str, align_lex_filename]: 105 | if not os.path.isfile(conff.encode('utf8')): 106 | raise Exception ('%s not found.' % conff) 107 | if not os.access(conff.encode('utf8'), os.R_OK): 108 | raise Exception ('%s is not readable' % conff) 109 | 110 | # 111 | # generate ivector_extractor.conf 112 | # 113 | 114 | self.ie_conf_f = NamedTemporaryFile(prefix=u'ivector_extractor_', suffix=u'.conf', delete=True) 115 | 116 | self.ie_conf_f.write((u"--cmvn-config=%s/conf/online_cmvn.conf\n" % self.modeldir).encode('utf8')) 117 | self.ie_conf_f.write((u"--ivector-period=%d\n" % online_ivector_period).encode('utf8')) 118 | self.ie_conf_f.write((u"--splice-config=%s\n" % splice_conf_filename).encode('utf8')) 119 | self.ie_conf_f.write((u"--lda-matrix=%s/extractor/final.mat\n" % self.modeldir).encode('utf8')) 120 | self.ie_conf_f.write((u"--global-cmvn-stats=%s/extractor/global_cmvn.stats\n" % self.modeldir).encode('utf8')) 121 | self.ie_conf_f.write((u"--diag-ubm=%s/extractor/final.dubm\n" % self.modeldir).encode('utf8')) 122 | self.ie_conf_f.write((u"--ivector-extractor=%s/extractor/final.ie\n" % self.modeldir).encode('utf8')) 123 | self.ie_conf_f.write((u"--num-gselect=%d\n" % num_gselect).encode('utf8')) 124 | self.ie_conf_f.write((u"--min-post=%f\n" % min_post).encode('utf8')) 125 | self.ie_conf_f.write((u"--posterior-scale=%f\n" % posterior_scale).encode('utf8')) 126 | self.ie_conf_f.write((u"--max-remembered-frames=1000\n").encode('utf8')) 127 | self.ie_conf_f.write((u"--max-count=%d\n" % max_count).encode('utf8')) 128 | self.ie_conf_f.flush() 129 | 130 | # 131 | # instantiate our C++ wrapper class 132 | # 133 | 134 | self.model_wrapper = new NNet3OnlineModelWrapper(beam, 135 | max_active, 136 | min_active, 137 | lattice_beam, 138 | 
acoustic_scale, 139 | frame_subsampling_factor, 140 | word_symbol_table.encode('utf8'), 141 | model_in_filename.encode('utf8'), 142 | fst_in_str.encode('utf8'), 143 | mfcc_config.encode('utf8'), 144 | self.ie_conf_f.name.encode('utf8'), 145 | align_lex_filename.encode('utf8')) 146 | 147 | def __dealloc__(self): 148 | if self.ie_conf_f: 149 | self.ie_conf_f.close() 150 | if self.model_wrapper: 151 | del self.model_wrapper 152 | 153 | cdef class KaldiNNet3OnlineDecoder: 154 | 155 | cdef NNet3OnlineDecoderWrapper* decoder_wrapper 156 | cdef object ie_conf_f 157 | 158 | def __cinit__(self, KaldiNNet3OnlineModel model): 159 | 160 | # 161 | # instantiate our C++ wrapper class 162 | # 163 | 164 | self.decoder_wrapper = new NNet3OnlineDecoderWrapper(model.model_wrapper) 165 | 166 | def __dealloc__(self): 167 | del self.decoder_wrapper 168 | 169 | def decode(self, samp_freq, cnp.ndarray[float, ndim=1, mode="c"] samples not None, finalize): 170 | return self.decoder_wrapper.decode(samp_freq, samples.shape[0], samples.data, finalize) 171 | 172 | def get_decoded_string(self): 173 | cdef string decoded_string 174 | cdef double likelihood=0.0 175 | self.decoder_wrapper.get_decoded_string(decoded_string, likelihood) 176 | return decoded_string.decode('utf8'), likelihood 177 | 178 | def get_word_alignment(self): 179 | cdef vector[string] words 180 | cdef vector[int] times 181 | cdef vector[int] lengths 182 | if not self.decoder_wrapper.get_word_alignment(words, times, lengths): 183 | return None 184 | return words, times, lengths 185 | 186 | # 187 | # various convenience functions below 188 | # 189 | 190 | def decode_wav_file(self, object wavfile): 191 | 192 | wavf = wave.open(wavfile, 'rb') 193 | 194 | # check format 195 | assert wavf.getnchannels()==1 196 | assert wavf.getsampwidth()==2 197 | assert wavf.getnframes()>0 198 | 199 | # read the whole file into memory, for now 200 | num_frames = wavf.getnframes() 201 | frames = wavf.readframes(num_frames) 202 | 203 | samples = 
struct.unpack_from('<%dh' % num_frames, frames) 204 | 205 | wavf.close() 206 | 207 | return self.decode(wavf.getframerate(), np.array(samples, dtype=np.float32), True) 208 | 209 | -------------------------------------------------------------------------------- /kaldiasr/nnet3_wrappers.cpp: -------------------------------------------------------------------------------- 1 | // nnet3_wrappers.cpp 2 | // 3 | // Copyright 2016, 2017 G. Bartsch 4 | // 5 | // based on Kaldi's decoder/decoder-wrappers.cc 6 | 7 | // Copyright 2014 Johns Hopkins University (author: Daniel Povey) 8 | 9 | // See ../../COPYING for clarification regarding multiple authors 10 | // 11 | // Licensed under the Apache License, Version 2.0 (the "License"); 12 | // you may not use this file except in compliance with the License. 13 | // You may obtain a copy of the License at 14 | // 15 | // http://www.apache.org/licenses/LICENSE-2.0 16 | // 17 | // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 18 | // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED 19 | // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, 20 | // MERCHANTABLITY OR NON-INFRINGEMENT. 21 | // See the Apache 2 License for the specific language governing permissions and 22 | // limitations under the License. 
23 | // 24 | 25 | #include "nnet3_wrappers.h" 26 | 27 | #include "lat/lattice-functions.h" 28 | #include "lat/word-align-lattice-lexicon.h" 29 | #include "nnet3/nnet-utils.h" 30 | 31 | #define VERBOSE 0 32 | 33 | namespace kaldi { 34 | 35 | /* 36 | * NNet3OnlineDecoderWrapper 37 | */ 38 | 39 | NNet3OnlineDecoderWrapper::NNet3OnlineDecoderWrapper(NNet3OnlineModelWrapper *aModel) : model(aModel) { 40 | decoder = NULL; 41 | silence_weighting = NULL; 42 | feature_pipeline = NULL; 43 | adaptation_state = NULL; 44 | decodable_info = NULL; 45 | 46 | tot_frames = 0; 47 | tot_frames_decoded = 0; 48 | 49 | #if VERBOSE 50 | KALDI_LOG << "alloc: OnlineIvectorExtractorAdaptationState"; 51 | #endif 52 | adaptation_state = new OnlineIvectorExtractorAdaptationState (model->feature_info->ivector_extractor_info); 53 | 54 | #if VERBOSE 55 | KALDI_LOG << "alloc: OnlineSilenceWeighting"; 56 | #endif 57 | silence_weighting = new OnlineSilenceWeighting (model->trans_model, 58 | model->feature_info->silence_weighting_config, 59 | model->decodable_opts.frame_subsampling_factor); 60 | 61 | #if VERBOSE 62 | KALDI_LOG << "alloc: nnet3::DecodableNnetSimpleLoopedInfo"; 63 | #endif 64 | decodable_info = new nnet3::DecodableNnetSimpleLoopedInfo(model->decodable_opts, &model->am_nnet); 65 | } 66 | 67 | NNet3OnlineDecoderWrapper::~NNet3OnlineDecoderWrapper() { 68 | free_decoder(); 69 | if (silence_weighting) { 70 | delete silence_weighting ; 71 | silence_weighting = NULL; 72 | } 73 | if (adaptation_state) { 74 | delete adaptation_state ; 75 | adaptation_state = NULL; 76 | } 77 | if (decodable_info) { 78 | delete decodable_info; 79 | decodable_info = NULL; 80 | } 81 | } 82 | 83 | void NNet3OnlineDecoderWrapper::start_decoding(void) { 84 | #if VERBOSE 85 | KALDI_LOG << "start_decoding..." 
; 86 | KALDI_LOG << "max_active :" << model->lattice_faster_decoder_config.max_active; 87 | KALDI_LOG << "min_active :" << model->lattice_faster_decoder_config.min_active; 88 | KALDI_LOG << "beam :" << model->lattice_faster_decoder_config.beam; 89 | KALDI_LOG << "lattice_beam:" << model->lattice_faster_decoder_config.lattice_beam; 90 | #endif 91 | free_decoder(); 92 | #if VERBOSE 93 | KALDI_LOG << "alloc: OnlineNnet2FeaturePipeline"; 94 | #endif 95 | feature_pipeline = new OnlineNnet2FeaturePipeline (*model->feature_info); 96 | feature_pipeline->SetAdaptationState(*adaptation_state); 97 | #if VERBOSE 98 | KALDI_LOG << "alloc: SingleUtteranceNnet3Decoder"; 99 | #endif 100 | decoder = new SingleUtteranceNnet3Decoder (model->lattice_faster_decoder_config, 101 | model->trans_model, 102 | *decodable_info, 103 | *model->decode_fst, 104 | feature_pipeline); 105 | #if VERBOSE 106 | KALDI_LOG << "start_decoding...done" ; 107 | #endif 108 | } 109 | 110 | void NNet3OnlineDecoderWrapper::free_decoder(void) { 111 | if (decoder) { 112 | #if VERBOSE 113 | KALDI_LOG << "free_decoder"; 114 | #endif 115 | delete decoder ; 116 | decoder = NULL; 117 | } 118 | if (feature_pipeline) { 119 | delete feature_pipeline ; 120 | feature_pipeline = NULL; 121 | } 122 | } 123 | 124 | void NNet3OnlineDecoderWrapper::get_decoded_string(std::string &decoded_string, double &likelihood) { 125 | 126 | //std::string decoded_string; 127 | //double likelihood; 128 | 129 | Lattice best_path_lat; 130 | 131 | decoded_string = ""; 132 | 133 | if (decoder) { 134 | 135 | // decoding is not finished yet, so we will look up the best partial result so far 136 | 137 | if (decoder->NumFramesDecoded() == 0) { 138 | likelihood = 0.0; 139 | return; 140 | } 141 | 142 | decoder->GetBestPath(false, &best_path_lat); 143 | 144 | } else { 145 | ConvertLattice(best_path_clat, &best_path_lat); 146 | } 147 | 148 | std::vector words; 149 | std::vector alignment; 150 | LatticeWeight weight; 151 | int32 num_frames; 152 | 
GetLinearSymbolSequence(best_path_lat, &alignment, &words, &weight); 153 | num_frames = alignment.size(); 154 | likelihood = -(weight.Value1() + weight.Value2()) / num_frames; 155 | 156 | for (size_t i = 0; i < words.size(); i++) { 157 | std::string s = model->word_syms->Find(words[i]); 158 | if (s == "") 159 | KALDI_ERR << "Word-id " << words[i] << " not in symbol table."; 160 | decoded_string += s + ' '; 161 | } 162 | } 163 | 164 | bool NNet3OnlineDecoderWrapper::get_word_alignment(std::vector &words, 165 | std::vector ×, 166 | std::vector &lengths) { 167 | 168 | WordAlignLatticeLexiconInfo lexicon_info(model->word_alignment_lexicon); 169 | 170 | #if VERBOSE 171 | KALDI_LOG << "word alignment starts..."; 172 | #endif 173 | CompactLattice aligned_clat; 174 | WordAlignLatticeLexiconOpts opts; 175 | 176 | bool ok = WordAlignLatticeLexicon(best_path_clat, model->trans_model, lexicon_info, opts, &aligned_clat); 177 | 178 | if (!ok) { 179 | KALDI_WARN << "Lattice did not align correctly"; 180 | return false; 181 | } else { 182 | if (aligned_clat.Start() == fst::kNoStateId) { 183 | KALDI_WARN << "Lattice was empty"; 184 | return false; 185 | } else { 186 | #if VERBOSE 187 | KALDI_LOG << "Aligned lattice."; 188 | #endif 189 | TopSortCompactLatticeIfNeeded(&aligned_clat); 190 | 191 | // lattice-1best 192 | 193 | CompactLattice best_path_aligned; 194 | CompactLatticeShortestPath(aligned_clat, &best_path_aligned); 195 | 196 | // nbest-to-ctm 197 | 198 | std::vector word_idxs; 199 | if (!CompactLatticeToWordAlignment(best_path_aligned, &word_idxs, ×, &lengths)) { 200 | KALDI_WARN << "CompactLatticeToWordAlignment failed."; 201 | return false; 202 | } 203 | 204 | // lexicon lookup 205 | words.clear(); 206 | for (size_t i = 0; i < word_idxs.size(); i++) { 207 | std::string s = model->word_syms->Find(word_idxs[i]); 208 | if (s == "") { 209 | KALDI_ERR << "Word-id " << word_idxs[i] << " not in symbol table."; 210 | } 211 | words.push_back(s); 212 | } 213 | } 214 | } 215 | return 
true; 216 | } 217 | 218 | 219 | 220 | bool NNet3OnlineDecoderWrapper::decode(BaseFloat samp_freq, int32 num_frames, BaseFloat *frames, bool finalize) { 221 | 222 | using fst::VectorFst; 223 | 224 | if (!decoder) { 225 | start_decoding(); 226 | } 227 | 228 | Vector wave_part(num_frames, kUndefined); 229 | for (int i=0; iAcceptWaveform(samp_freq, wave_part); 238 | 239 | if (finalize) { 240 | // no more input. flush out last frames 241 | feature_pipeline->InputFinished(); 242 | } 243 | 244 | if (silence_weighting->Active() && feature_pipeline->IvectorFeature() != NULL) { 245 | silence_weighting->ComputeCurrentTraceback(decoder->Decoder()); 246 | silence_weighting->GetDeltaWeights(feature_pipeline->NumFramesReady(), 247 | &delta_weights); 248 | feature_pipeline->IvectorFeature()->UpdateFrameWeights(delta_weights); 249 | } 250 | 251 | decoder->AdvanceDecoding(); 252 | 253 | if (finalize) { 254 | decoder->FinalizeDecoding(); 255 | 256 | CompactLattice clat; 257 | bool end_of_utterance = true; 258 | decoder->GetLattice(end_of_utterance, &clat); 259 | 260 | if (clat.NumStates() == 0) { 261 | KALDI_WARN << "Empty lattice."; 262 | return false; 263 | } 264 | 265 | CompactLatticeShortestPath(clat, &best_path_clat); 266 | 267 | tot_frames_decoded = tot_frames; 268 | tot_frames = 0; 269 | 270 | free_decoder(); 271 | 272 | } 273 | 274 | return true; 275 | } 276 | 277 | 278 | /* 279 | * NNet3OnlineModelWrapper 280 | */ 281 | 282 | // typedef void (*LogHandler)(const LogMessageEnvelope &envelope, 283 | // const char *message); 284 | void silent_log_handler (const LogMessageEnvelope &envelope, 285 | const char *message) { 286 | // nothing - this handler simply keeps silent 287 | } 288 | 289 | NNet3OnlineModelWrapper::NNet3OnlineModelWrapper(BaseFloat beam, 290 | int32 max_active, 291 | int32 min_active, 292 | BaseFloat lattice_beam, 293 | BaseFloat acoustic_scale, 294 | int32 frame_subsampling_factor, 295 | std::string &word_syms_filename, 296 | std::string &model_in_filename, 297 
| std::string &fst_in_str, 298 | std::string &mfcc_config, 299 | std::string &ie_conf_filename, 300 | std::string &align_lex_filename) 301 | 302 | { 303 | 304 | using namespace kaldi; 305 | using namespace fst; 306 | 307 | typedef kaldi::int32 int32; 308 | typedef kaldi::int64 int64; 309 | 310 | #if VERBOSE 311 | KALDI_LOG << "model_in_filename: " << model_in_filename; 312 | KALDI_LOG << "fst_in_str: " << fst_in_str; 313 | KALDI_LOG << "mfcc_config: " << mfcc_config; 314 | KALDI_LOG << "ie_conf_filename: " << ie_conf_filename; 315 | KALDI_LOG << "align_lex_filename: " << align_lex_filename; 316 | #else 317 | // silence kaldi output as well 318 | SetLogHandler(silent_log_handler); 319 | #endif 320 | 321 | feature_config.mfcc_config = mfcc_config; 322 | feature_config.ivector_extraction_config = ie_conf_filename; 323 | 324 | lattice_faster_decoder_config.max_active = max_active; 325 | lattice_faster_decoder_config.min_active = min_active; 326 | lattice_faster_decoder_config.beam = beam; 327 | lattice_faster_decoder_config.lattice_beam = lattice_beam; 328 | decodable_opts.acoustic_scale = acoustic_scale; 329 | decodable_opts.frame_subsampling_factor = frame_subsampling_factor; 330 | 331 | feature_info = new OnlineNnet2FeaturePipelineInfo(this->feature_config); 332 | 333 | // load model... 334 | { 335 | bool binary; 336 | Input ki(model_in_filename, &binary); 337 | this->trans_model.Read(ki.Stream(), binary); 338 | this->am_nnet.Read(ki.Stream(), binary); 339 | SetBatchnormTestMode(true, &(this->am_nnet.GetNnet())); 340 | SetDropoutTestMode(true, &(this->am_nnet.GetNnet())); 341 | nnet3::CollapseModel(nnet3::CollapseModelConfig(), &(this->am_nnet.GetNnet())); 342 | } 343 | 344 | // Input FST is just one FST, not a table of FSTs. 
345 | decode_fst = fst::ReadFstKaldiGeneric(fst_in_str); 346 | 347 | word_syms = NULL; 348 | if (word_syms_filename != "") 349 | if (!(word_syms = fst::SymbolTable::ReadText(word_syms_filename))) 350 | KALDI_ERR << "Could not read symbol table from file " 351 | << word_syms_filename; 352 | 353 | #if VERBOSE 354 | KALDI_LOG << "loading word alignment lexicon..."; 355 | #endif 356 | { 357 | bool binary_in; 358 | Input ki(align_lex_filename, &binary_in); 359 | KALDI_ASSERT(!binary_in && "Not expecting binary file for lexicon"); 360 | if (!ReadLexiconForWordAlign(ki.Stream(), &word_alignment_lexicon)) { 361 | KALDI_ERR << "Error reading alignment lexicon from " 362 | << align_lex_filename; 363 | } 364 | } 365 | } 366 | 367 | NNet3OnlineModelWrapper::~NNet3OnlineModelWrapper() { 368 | delete feature_info; 369 | } 370 | 371 | } 372 | 373 | -------------------------------------------------------------------------------- /kaldiasr/nnet3_wrappers.h: -------------------------------------------------------------------------------- 1 | // nnet3_wrappers.h 2 | // 3 | // Copyright 2016, 2017 G. Bartsch 4 | // 5 | // based on Kaldi's decoder/decoder-wrappers.cc 6 | 7 | // Copyright 2014 Johns Hopkins University (author: Daniel Povey) 8 | 9 | // See ../../COPYING for clarification regarding multiple authors 10 | // 11 | // Licensed under the Apache License, Version 2.0 (the "License"); 12 | // you may not use this file except in compliance with the License. 13 | // You may obtain a copy of the License at 14 | // 15 | // http://www.apache.org/licenses/LICENSE-2.0 16 | // 17 | // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 18 | // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED 19 | // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, 20 | // MERCHANTABLITY OR NON-INFRINGEMENT. 21 | // See the Apache 2 License for the specific language governing permissions and 22 | // limitations under the License. 
23 | // 24 | 25 | #include "base/kaldi-common.h" 26 | #include "util/common-utils.h" 27 | #include "fstext/fstext-lib.h" 28 | #include "nnet3/nnet-am-decodable-simple.h" 29 | #include "online2/online-nnet3-decoding.h" 30 | #include "online2/online-nnet2-feature-pipeline.h" 31 | #include "decoder/lattice-faster-decoder.h" 32 | #include "decoder/lattice-faster-decoder.h" 33 | #include "nnet3/decodable-simple-looped.h" 34 | 35 | namespace kaldi { 36 | class NNet3OnlineModelWrapper { 37 | friend class NNet3OnlineDecoderWrapper; 38 | public: 39 | 40 | NNet3OnlineModelWrapper(BaseFloat beam, 41 | int32 max_active, 42 | int32 min_active, 43 | BaseFloat lattice_beam, 44 | BaseFloat acoustic_scale, 45 | int32 frame_subsampling_factor, 46 | std::string &word_syms_filename, 47 | std::string &model_in_filename, 48 | std::string &fst_in_str, 49 | std::string &mfcc_config, 50 | std::string &ie_conf_filename, 51 | std::string &align_lex_filename 52 | ) ; 53 | ~NNet3OnlineModelWrapper(); 54 | 55 | private: 56 | 57 | fst::SymbolTable *word_syms; 58 | 59 | // feature_config includes configuration for the iVector adaptation, 60 | // as well as the basic features. 
61 | OnlineNnet2FeaturePipelineConfig feature_config; 62 | LatticeFasterDecoderConfig lattice_faster_decoder_config; 63 | 64 | OnlineNnet2FeaturePipelineInfo *feature_info; 65 | 66 | nnet3::AmNnetSimple am_nnet; 67 | nnet3::NnetSimpleLoopedComputationOptions decodable_opts; 68 | 69 | TransitionModel trans_model; 70 | //fst::VectorFst *decode_fst; 71 | fst::Fst *decode_fst; 72 | std::string *ie_conf_filename; 73 | 74 | // word alignment: 75 | std::vector > word_alignment_lexicon; 76 | }; 77 | 78 | class NNet3OnlineDecoderWrapper { 79 | public: 80 | 81 | NNet3OnlineDecoderWrapper(NNet3OnlineModelWrapper *aModel); 82 | ~NNet3OnlineDecoderWrapper(); 83 | 84 | bool decode(BaseFloat samp_freq, 85 | int32 num_frames, 86 | BaseFloat *frames, 87 | bool finalize); 88 | 89 | void get_decoded_string(std::string &decoded_string, 90 | double &likelihood); 91 | bool get_word_alignment(std::vector &words, 92 | std::vector ×, 93 | std::vector &lengths); 94 | 95 | private: 96 | 97 | void start_decoding(void); 98 | void free_decoder(void); 99 | 100 | NNet3OnlineModelWrapper *model; 101 | 102 | OnlineIvectorExtractorAdaptationState *adaptation_state; 103 | OnlineNnet2FeaturePipeline *feature_pipeline; 104 | OnlineSilenceWeighting *silence_weighting; 105 | nnet3::DecodableNnetSimpleLoopedInfo *decodable_info; 106 | SingleUtteranceNnet3Decoder *decoder; 107 | 108 | std::vector > delta_weights; 109 | int32 tot_frames, tot_frames_decoded; 110 | 111 | // decoding result: 112 | CompactLattice best_path_clat; 113 | 114 | }; 115 | 116 | 117 | 118 | } 119 | 120 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from setuptools import setup, Extension 3 | import numpy 4 | import subprocess 5 | import sys 6 | import os 7 | 8 | try: 9 | from Cython.Distutils import build_ext 10 | except ImportError: 11 | raise Exception("*** 
cython is needed to build this extension.") 12 | 13 | cmdclass = {} 14 | ext_modules = [] 15 | 16 | 17 | def getstatusoutput(command): 18 | process = subprocess.Popen(command, stdout=subprocess.PIPE) 19 | out, _ = process.communicate() 20 | return (process.returncode, out) 21 | 22 | 23 | def find_dependencies(): 24 | kw = {} 25 | 26 | flag_map = {'-I': 'include_dirs', '-L': 'library_dirs', '-l': 'libraries'} 27 | 28 | # 29 | # find atlas library (try pkgconfig, if that fails look at usual places) 30 | # 31 | 32 | print("looking for atlas library, trying pkg-config first...") 33 | 34 | # Try pky-config values for atlas 35 | status, output = getstatusoutput( 36 | ["pkg-config", "--libs", "--cflags", "blas-atlas"]) 37 | 38 | # Try pkg-config values for atlas 39 | if status != 0: 40 | status, output = getstatusoutput( 41 | ["pkg-config", "--libs", "--cflags", "atlas"]) 42 | 43 | if status != 0: 44 | 45 | print("looking for atlas library, trying hard-coded paths...") 46 | 47 | found = False 48 | 49 | for libdir in ['/usr/lib', '/usr/lib64', '/usr/lib/x86_64-linux-gnu', 50 | '/usr/lib/i386-linux-gnu']: 51 | if os.path.isfile('%s/libatlas.so.3' % libdir): 52 | found = True 53 | break 54 | if not found: 55 | raise Exception('Failed to find libatlas.so.3 on your system.') 56 | 57 | kw.setdefault('libraries', []).append('%s/atlas.so.3' % libdir) 58 | kw.setdefault('libraries', []).append('%s/cblas.so.3' % libdir) 59 | kw.setdefault('libraries', []).append('%s/f77blas.so.3' % libdir) 60 | kw.setdefault('libraries', []).append('%s/lapack_atlas.so.3' % libdir) 61 | 62 | include_path = None 63 | include_found = False 64 | for include_dir in ['/usr/include/atlas', 65 | '/usr/include/x86_64-linux-gnu/atlas']: 66 | if os.path.isdir(include_dir): 67 | include_found = True 68 | include_path = include_dir 69 | 70 | if not include_found: 71 | raise Exception('Failed to find atlas includes your system.') 72 | 73 | kw.setdefault('include_dirs', []).append(include_path) 74 | 75 | 
print("looking for atlas library, found it.") 76 | else: 77 | print("looking for atlas library, pkg-config found it") 78 | for token in output.split(): 79 | token = token.decode('utf8') 80 | kw.setdefault(flag_map.get(token[:2]), []).append(token[2:]) 81 | 82 | # 83 | # pkgconfig: kaldi-asr 84 | # 85 | 86 | status, output = getstatusoutput( 87 | ["pkg-config", "--libs", "--cflags", "kaldi-asr"]) 88 | 89 | if status != 0: 90 | raise Exception("*** failed to find pkgconfig for kaldi-asr") 91 | 92 | for token in output.split(): 93 | token = token.decode('utf8') 94 | 95 | prefix = token[:2] 96 | arg = token[2:] 97 | 98 | # print(repr(token)) 99 | # print(repr(prefix)) 100 | 101 | kw.setdefault(flag_map.get(prefix), []).append(arg) 102 | 103 | # print (repr(kw)) 104 | 105 | return kw 106 | 107 | 108 | # CFLAGS = -Wall -pthread -std=c++11 -DKALDI_DOUBLEPRECISION=0 -Wno-sign-compare \ 109 | # -Wno-unused-local-typedefs -Winit-self -DHAVE_EXECINFO_H=1 -DHAVE_CXXABI_H -DHAVE_ATLAS \ 110 | # `pkg-config --cflags kaldi-asr` -g 111 | 112 | ext_modules += [ 113 | Extension("kaldiasr.nnet3", 114 | sources=["kaldiasr/nnet3.pyx", "kaldiasr/nnet3_wrappers.cpp"], 115 | language="c++", 116 | extra_compile_args=['-Wall', '-pthread', '-std=c++11', 117 | '-DKALDI_DOUBLEPRECISION=0', 118 | '-Wno-sign-compare', 119 | '-Wno-unused-local-typedefs', '-Winit-self', 120 | '-DHAVE_EXECINFO_H=1', '-DHAVE_CXXABI_H', 121 | '-DHAVE_ATLAS', '-g'], 122 | **find_dependencies()), 123 | Extension("kaldiasr.gmm", 124 | sources=["kaldiasr/gmm.pyx", "kaldiasr/gmm_wrappers.cpp"], 125 | language="c++", 126 | extra_compile_args=['-Wall', '-pthread', '-std=c++11', 127 | '-DKALDI_DOUBLEPRECISION=0', 128 | '-Wno-sign-compare', 129 | '-Wno-unused-local-typedefs', '-Winit-self', 130 | '-DHAVE_EXECINFO_H=1', '-DHAVE_CXXABI_H', 131 | '-DHAVE_ATLAS', '-g'], 132 | **find_dependencies()), 133 | ] 134 | cmdclass.update({'build_ext': build_ext}) 135 | for e in ext_modules: 136 | e.cython_directives = {'language_level': 
"3"} 137 | setup( 138 | name='py-kaldi-asr', 139 | version='0.5.2', 140 | description='Simple Python/Cython interface to kaldi-asr nnet3/chain and gmm decoders', 141 | long_description=open('README.md').read(), 142 | author='Guenter Bartsch', 143 | author_email='guenter@zamia.org', 144 | maintainer='Guenter Bartsch', 145 | maintainer_email='guenter@zamia.org', 146 | url='https://github.com/gooofy/py-kaldi-asr', 147 | packages=['kaldiasr'], 148 | cmdclass=cmdclass, 149 | ext_modules=ext_modules, 150 | include_dirs=[numpy.get_include()], 151 | classifiers=[ 152 | 'Operating System :: POSIX :: Linux', 153 | 'License :: OSI Approved :: Apache Software License', 154 | 'Programming Language :: Python :: 2', 155 | 'Programming Language :: Python :: 2.7', 156 | 'Programming Language :: Python :: 3', 157 | 'Programming Language :: Python :: 3.5', 158 | 'Programming Language :: Cython', 159 | 'Programming Language :: C++', 160 | 'Intended Audience :: Developers', 161 | 'Topic :: Software Development :: Libraries :: Python Modules', 162 | 'Topic :: Multimedia :: Sound/Audio :: Speech' 163 | ], 164 | license='Apache', 165 | keywords='kaldi asr', 166 | include_package_data=True, 167 | install_requires=['numpy', 'cython', ], 168 | ) 169 | 170 | --------------------------------------------------------------------------------