├── .bmpfile ├── .clang-format ├── .dockerignore ├── .gitignore ├── .gitmodules ├── CMakeLists.txt ├── Dockerfile ├── LICENSE ├── README.md ├── build └── .gitignore ├── include └── kaldiserve │ ├── config.hpp │ ├── decoder.hpp │ ├── model.hpp │ ├── types.hpp │ ├── utils.hpp │ └── vendor │ └── cpptoml.h ├── plugins └── grpc │ ├── Dockerfile │ ├── Dockerfile.lb │ ├── Makefile │ ├── README.md │ ├── build │ ├── .gitignore │ └── config │ │ ├── consul │ │ ├── etc │ │ │ ├── consul │ │ │ │ └── conf.d │ │ │ │ │ └── bootstrap │ │ │ │ │ └── config.json │ │ │ ├── cont-init.d │ │ │ │ ├── 30-consul │ │ │ │ └── 40-consul │ │ │ └── services.d │ │ │ │ ├── consul │ │ │ │ └── run │ │ │ │ └── resolver │ │ │ │ └── run │ │ └── usr │ │ │ └── bin │ │ │ ├── consul-available │ │ │ ├── consul-debug │ │ │ ├── consul-join │ │ │ ├── consul-join-wan │ │ │ ├── consul-node-id │ │ │ └── container-find │ │ └── supervisord │ │ ├── consul-template.conf │ │ ├── consul.conf │ │ ├── grpc_server.conf │ │ ├── register.conf │ │ ├── supervisor-nginx.conf │ │ └── supervisord.conf │ ├── client │ ├── .gitignore │ ├── Makefile │ ├── README.md │ ├── kaldi_serve │ │ ├── __init__.py │ │ ├── core.py │ │ ├── kaldi_serve_pb2.py │ │ ├── kaldi_serve_pb2_grpc.py │ │ └── utils.py │ ├── poetry.lock │ ├── pyproject.toml │ ├── scripts │ │ ├── batch_decode.py │ │ ├── example_client.py │ │ └── list_models.py │ └── tests │ │ ├── conftest.py │ │ ├── resources │ │ └── hi │ │ │ └── .gitignore │ │ └── test_hi.yaml │ ├── examples │ └── aspire │ │ ├── README.md │ │ ├── model │ │ ├── conf │ │ │ ├── ivector_extractor.conf │ │ │ ├── mfcc.conf │ │ │ ├── online.conf │ │ │ ├── online_cmvn.conf │ │ │ └── splice.conf │ │ ├── ivector_extractor │ │ │ ├── final.dubm │ │ │ ├── final.ie │ │ │ ├── final.mat │ │ │ ├── global_cmvn.stats │ │ │ ├── online_cmvn.conf │ │ │ └── splice_opts │ │ ├── word_boundary.int │ │ └── words.txt │ │ ├── model_spec.toml │ │ ├── run_server.sh │ │ └── utils │ │ ├── parse_options.sh │ │ └── setup_aspire_chain_model.sh │ ├── protos │ ├── kaldi_serve.grpc.pb.cc │ ├── kaldi_serve.grpc.pb.h │ ├── kaldi_serve.pb.cc │ ├── kaldi_serve.pb.h │ └── kaldi_serve.proto │ ├── registrar │ └── main.go │ └── src │ ├── app.cc │ ├── config.hpp │ ├── server.hpp │ └── vendor │ └── CLI11.hpp ├── python ├── CMakeLists.txt ├── Dockerfile ├── README.md ├── kaldiserve │ └── __init__.py ├── kaldiserve_pybind │ ├── decoder.cpp │ ├── kaldiserve_pybind.cpp │ ├── kaldiserve_pybind.h │ ├── model.cpp │ ├── types.cpp │ └── utils.cpp ├── scripts │ ├── batch_transcribe.py │ ├── requirements.txt │ └── transcribe.py └── setup.py ├── resources └── model-spec.toml └── src ├── CMakeLists.txt ├── config.cpp ├── decoder ├── decoder-common.cpp ├── decoder-factory.cpp ├── decoder-queue.cpp └── decoder.cpp ├── model └── model-chain.cpp └── utils └── utils-io.cpp /.bmpfile: -------------------------------------------------------------------------------- 1 | ;; -*- mode: emacs-lisp -*- 2 | 3 | (("./src/app.cc" . 
"VERSION \"\\([0-9]\.[0-9]\.[0-9]\\)\"")) -------------------------------------------------------------------------------- /.clang-format: -------------------------------------------------------------------------------- 1 | UseTab: false 2 | IndentWidth: 4 3 | AllowShortIfStatementsOnASingleLine: false 4 | IndentCaseLabels: false 5 | ColumnLimit: 0 6 | BreakBeforeBraces: Custom 7 | BraceWrapping: 8 | AfterFunction: false -------------------------------------------------------------------------------- /.dockerignore: -------------------------------------------------------------------------------- 1 | tests/ 2 | .pytest_cache 3 | .pyproject.lock 4 | .vscode 5 | models/ 6 | 7 | # gitlab ci 8 | skaffold.yaml 9 | tox.ini 10 | .git 11 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /models 2 | /.ccls-cache 3 | *.so 4 | *.o 5 | protos/*.pb* 6 | .vscode 7 | notes.md 8 | 9 | audio 10 | 11 | /python_stream/kaldi/__pycache__ 12 | /.mypy_cache 13 | models.toml 14 | 15 | examples/*/model 16 | 17 | dev* 18 | 19 | .DS_Store 20 | .idea 21 | *.egg-info/ 22 | *.pyc 23 | 24 | # pyenv 25 | python*-env 26 | .python-version 27 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "3rdparty/pybind11"] 2 | path = 3rdparty/pybind11 3 | url = https://github.com/pybind/pybind11.git 4 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.13 FATAL_ERROR) 2 | 3 | set(CMAKE_C_COMPILER "gcc") 4 | set(CMAKE_CXX_COMPILER "g++") 5 | 6 | # Project settings 7 | project(kaldiserve VERSION 1.0.0 LANGUAGES C CXX) 8 | 9 | # Config options 10 | option(BUILD_SHARED_LIB "Build shared library" ON) 11 | option(BUILD_PYTHON_MODULE "Build the python module" OFF) 12 | option(BUILD_PYBIND11 "Build pybind11 for python bindings" OFF) 13 | 14 | # CXX compiler options 15 | set(CMAKE_POSITION_INDEPENDENT_CODE ON) 16 | set(CMAKE_CXX_FLAGS "-std=c++11 -DKALDI_DOUBLEPRECISION=0 -Wno-sign-compare -Wno-unused-local-typedefs -Wno-unused-variable -Winit-self -O2") 17 | 18 | # kaldi config 19 | set(KALDI_ROOT /opt/kaldi CACHE STRING "Path to Kaldi root directory") 20 | 21 | # Build shared library 22 | if(BUILD_SHARED_LIB) 23 | find_package(Boost REQUIRED) 24 | include_directories(${Boost_INCLUDE_DIRS}) 25 | 26 | add_subdirectory(src) 27 | endif() 28 | 29 | # Build python port 30 | if (BUILD_PYTHON_MODULE) 31 | # Pybind11 32 | if (BUILD_PYBIND11) 33 | set(BUILD_PYBIND11 ON) 34 | add_subdirectory(3rdparty/pybind11) 35 | else() 36 | find_package(pybind11 REQUIRED) 37 | endif() 38 | 39 | add_subdirectory(python) 40 | endif() -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM kaldiasr/kaldi:latest 2 | 3 | # build latest cmake 4 | WORKDIR /root 5 | 6 | RUN apt-get update && \ 7 | apt-get upgrade -y && \ 8 | apt-get install -y \ 9 | libssl-dev \ 10 | cmake 11 | 12 | RUN wget https://github.com/Kitware/CMake/releases/download/v3.17.3/cmake-3.17.3.tar.gz && \ 13 | tar -xvf cmake-3.17.3.tar.gz 14 | 15 | WORKDIR /root/cmake-3.17.3 16 | 17 | # using an older cmake to build a newer cmake (>=3.13) 18 | RUN 
cmake . && \ 19 | make -j$(nproc) && \ 20 | make install 21 | 22 | # install c++ std & boost libs 23 | RUN apt-get update && \ 24 | apt-get install -y \ 25 | g++ \ 26 | make \ 27 | automake \ 28 | libc++-dev \ 29 | libboost-all-dev 30 | 31 | WORKDIR /root/kaldi-serve 32 | COPY . . 33 | 34 | # build libkaldiserve.so 35 | RUN cd build/ && \ 36 | cmake .. -DBUILD_SHARED_LIBS=ON -DBUILD_PYTHON_MODULE=OFF && \ 37 | make -j$(nproc) VERBOSE=1 && \ 38 | cd /root/kaldi-serve 39 | 40 | # KALDISERVE HEADERS & LIB 41 | RUN cp build/src/libkaldiserve.so* /usr/local/lib/ 42 | RUN cp -r include/kaldiserve /usr/include/ 43 | 44 | WORKDIR /root 45 | 46 | # cleanup 47 | RUN rm -rf kaldi-serve cmake-* -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "{}" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright {yyyy} {name of copyright owner} 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Kaldi-Serve 2 | 3 | ![GitHub tag (latest by date)](https://img.shields.io/github/v/tag/Vernacular-ai/kaldi-serve?style=flat-square) ![GitHub](https://img.shields.io/github/license/Vernacular-ai/kaldi-serve?style=flat-square) 4 | 5 | A plug-and-play abstraction over [Kaldi](https://kaldi-asr.org/) ASR toolkit, designed for ease of deployment and optimal runtime performance. 
6 | 7 | **Key Features**: 8 | 9 | - Real-time streaming (uni & bi-directional) audio recognition. 10 | - Thread-safe concurrent Decoder queue for server environments. 11 | - RNNLM lattice rescoring. 12 | - N-best alternatives with AM/LM costs, word-level timings and confidence scores. 13 | - Easy extensibility for custom applications. 14 | 15 | ## Installation 16 | 17 | ### Dependencies 18 | 19 | Make sure you have the following dependencies installed on your system before beginning the build process: 20 | 21 | * g++ compiler (>=4.7) that supports the C++11 standard 22 | * [CMake](https://cmake.org/install/) (>=3.13) 23 | * [Kaldi](https://kaldi-asr.org/) 24 | * [Boost C++](https://www.boost.org/) libraries 25 | 26 | ### Build from Source 27 | 28 | Let's build the shared library: 29 | 30 | ```bash 31 | cd build/ 32 | cmake .. 33 | make -j$(nproc) 34 | ``` 35 | 36 | You will find the built shared library in `build/src/`, ready to be linked against custom applications. 37 | 38 | #### Python bindings 39 | 40 | We also provide Python bindings for the library. You can find the build instructions [here](./python). 41 | 42 | ### Docker Image 43 | 44 | #### Using pre-built images 45 | 46 | You can also pull a pre-built docker image from our [Docker Hub repository](https://hub.docker.com/repository/docker/vernacularai/kaldi-serve): 47 | 48 | ```bash 49 | docker pull vernacularai/kaldi-serve:latest 50 | docker run -it -v /path/to/my/app:/home/app vernacularai/kaldi-serve:latest 51 | ``` 52 | 53 | You will find our headers in `/usr/include/kaldiserve` and the shared library `libkaldiserve.so` in `/usr/local/lib`. 54 | 55 | #### Building the image 56 | 57 | You can build the docker image using the [Dockerfile](./Dockerfile) provided. 58 | 59 | ```bash 60 | docker build -t kaldi-serve:lib . 61 | ``` 62 | 63 | ## Getting Started 64 | 65 | 68 | 69 | ### Usage 70 | 71 | You can include the [headers](./include) and link the shared library produced by the build against your application, and start using it.
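As a quick illustration of the API, here is a minimal n-best transcription program (an illustrative sketch, not part of the library; `model-spec.toml` and `audio.wav` are placeholder paths, and the spec file follows the sample in [resources](./resources/model-spec.toml) — see the headers under [include/kaldiserve](./include/kaldiserve) for the full interface):

```cpp
// transcribe.cc -- minimal sketch: 5-best transcription of one WAV file
// build (assumption): g++ -std=c++11 transcribe.cc -lkaldiserve ... (plus Kaldi include paths)
#include <fstream>
#include <iostream>
#include <vector>

#include <kaldiserve/decoder.hpp>
#include <kaldiserve/types.hpp>
#include <kaldiserve/utils.hpp>

int main() {
    // read model specifications from the TOML config
    std::vector<kaldiserve::ModelSpec> model_specs;
    kaldiserve::parse_model_specs("model-spec.toml", model_specs);

    // thread-safe queue holding `n_decoders` decoders for this model
    kaldiserve::DecoderQueue decoder_queue(model_specs[0]);

    // borrow a decoder, run it over the audio, then return it to the pool
    kaldiserve::Decoder *decoder = decoder_queue.acquire();
    std::ifstream wav_stream("audio.wav", std::ios::binary);

    decoder->start_decoding();
    decoder->decode_wav_audio(wav_stream);

    kaldiserve::utterance_results_t results;
    decoder->get_decoded_results(5, results);  // 5-best alternatives

    decoder->free_decoder();
    decoder_queue.release(decoder);

    for (auto const &alt : results)
        std::cout << alt.transcript << " (" << alt.confidence << ")" << std::endl;
    return 0;
}
```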
72 | 73 | ### Plugins 74 | 75 | It's also worth noting that there are a few [plugins](./plugins) that use the library, which we actively maintain and will keep adding to: 76 | - [gRPC Server](./plugins/grpc) 77 | 78 | ## License 79 | 80 | This project is licensed under the Apache License version 2.0. Please see [LICENSE](./LICENSE) for more details. -------------------------------------------------------------------------------- /build/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /include/kaldiserve/config.hpp: -------------------------------------------------------------------------------- 1 | // Configuration options. 2 | #pragma once 3 | 4 | // stl includes 5 | #include <memory> 6 | #include <string> 7 | 8 | // definitions 9 | #define VERSION "1.0.0" 10 | #define ENDL '\n' 11 | // #define DEBUG false 12 | 13 | 14 | namespace kaldiserve { 15 | 16 | template <typename T, typename ...Args> 17 | std::unique_ptr<T> make_uniq(Args&&... args) { 18 | return std::unique_ptr<T>(new T(std::forward<Args>(args)...)); 19 | } 20 | 21 | static bool DEBUG = false; 22 | 23 | // prints library version 24 | void print_version(); 25 | 26 | // returns current timestamp 27 | std::string timestamp_now(); 28 | 29 | } // namespace kaldiserve -------------------------------------------------------------------------------- /include/kaldiserve/decoder.hpp: -------------------------------------------------------------------------------- 1 | // Decoding graph and operations. 2 | #pragma once 3 | 4 | // stl includes 5 | #include <condition_variable> 6 | #include <iostream> 7 | #include <memory> 8 | #include <mutex> 9 | #include <queue> 10 | #include <string> 11 | #include <vector> 12 | 13 | // kaldi includes 14 | #include "base/kaldi-common.h" 15 | #include "rnnlm/rnnlm-lattice-rescoring.h" 16 | #include "util/common-utils.h" 17 | #include "lat/compose-lattice-pruned.h" 18 | #include "feat/wave-reader.h" 19 | #include "fstext/fstext-lib.h" 20 | #include "lat/kaldi-lattice.h" 21 | #include "lat/lattice-functions.h" 22 | #include "lat/word-align-lattice.h" 23 | #include "lat/sausages.h" 24 | #include "nnet3/nnet-utils.h" 25 | #include "online2/online-endpoint.h" 26 | #include "online2/online-nnet2-feature-pipeline.h" 27 | #include "online2/online-nnet3-decoding.h" 28 | #include "online2/onlinebin-util.h" 29 | #include "util/kaldi-thread.h" 30 | 31 | // local includes 32 | #include "config.hpp" 33 | #include "types.hpp" 34 | #include "model.hpp" 35 | #include "utils.hpp" 36 | 37 | 38 | namespace kaldiserve { 39 | 40 | // Forward declare class for friendship (hack for now) 41 | class ChainModel; 42 | 43 | 44 | class Decoder final { 45 | 46 | public: 47 | explicit Decoder(ChainModel *const model); 48 | 49 | ~Decoder() noexcept; 50 | 51 | // SETUP METHODS 52 | void start_decoding(const std::string &uuid="") noexcept; 53 | 54 | void free_decoder() noexcept; 55 | 56 | // STREAMING METHODS 57 | 58 | // decode an intermediate frame/chunk of a wav audio stream 59 | void decode_stream_wav_chunk(std::istream &wav_stream); 60 | 61 | // decode an intermediate frame/chunk of a raw headerless wav audio stream 62 | void decode_stream_raw_wav_chunk(std::istream &wav_stream, 63 | const float &samp_freq, 64 | const int &data_bytes); 65 | 66 | // NON-STREAMING METHODS 67 | 68 | // decodes an (independent) wav audio stream 69 | // internally chunks a wav audio stream and decodes them 70 | void decode_wav_audio(std::istream &wav_stream, 71 | const float &chunk_size=1); 72 | 73 | // decodes an (independent) raw headerless wav audio stream 74 | // internally chunks a wav audio stream and decodes them 75 | void decode_raw_wav_audio(std::istream &wav_stream, 76 | const float &samp_freq, 77 | const int &data_bytes, 78 | const float &chunk_size=1); 79 | 80 | // LATTICE DECODING METHODS 81 | 82 | // get the final utterances based on the compact lattice 83 | void get_decoded_results(const int &n_best, 84 | utterance_results_t &results, 85 | const bool &word_level=false, 86 | const bool &bidi_streaming=false); 87 | 88 | DecoderOptions options{false, false}; 89 | 90 | private: 91 | // decodes an intermediate wavepart 92 | void _decode_wave(kaldi::SubVector<kaldi::BaseFloat> &wave_part, 93 | std::vector<std::pair<int32, kaldi::BaseFloat>> &delta_weights, 94 | const kaldi::BaseFloat &samp_freq); 95 | 96 | // gets the final decoded transcripts from lattice 97 | void _find_alternatives(kaldi::CompactLattice &clat, 98 | const std::size_t &n_best, 99 | utterance_results_t &results, 100 | const bool &word_level) const; 101 | 102 | // model vars 103 | ChainModel *model_; 104 | 105 | // decoder vars (per utterance) 106 | kaldi::SingleUtteranceNnet3Decoder *decoder_; 107 | kaldi::OnlineNnet2FeaturePipeline *feature_pipeline_; 108 | kaldi::OnlineSilenceWeighting *silence_weighting_; 109 | kaldi::OnlineIvectorExtractorAdaptationState *adaptation_state_; 110 | 111 | // req-specific vars 112 | std::string uuid_; 113 | }; 114 |
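The streaming methods above are meant to be called once per audio chunk as it arrives; a rough sketch of a uni-directional streaming loop (illustrative only, assuming each element of `chunks` carries a WAV-formatted chunk of the same utterance):

```cpp
// illustrative sketch: feed chunks as they arrive, then read the n-best results
#include <sstream>
#include <vector>

#include <kaldiserve/decoder.hpp>

void stream_utterance(kaldiserve::Decoder *decoder,
                      std::vector<std::istringstream> &chunks) {
    decoder->start_decoding("some-uuid");         // per-utterance setup

    for (auto &chunk : chunks)
        decoder->decode_stream_wav_chunk(chunk);  // incremental decoding

    kaldiserve::utterance_results_t results;
    decoder->get_decoded_results(10, results,     // 10-best
                                 /*word_level=*/false,
                                 /*bidi_streaming=*/false);

    decoder->free_decoder();                      // release per-utterance state
}
```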
115 | 116 | // Factory for creating decoders with shared decoding graph and model parameters 117 | // Caches the graph and params to be able to produce decoders on demand. 118 | class DecoderFactory final { 119 | 120 | public: 121 | ModelSpec model_spec; 122 | 123 | explicit DecoderFactory(const ModelSpec &model_spec); 124 | 125 | inline Decoder *produce() const { 126 | return new Decoder(model_.get()); 127 | } 128 | 129 | // friendly alias for the producer method 130 | inline Decoder *operator()() const { 131 | return produce(); 132 | } 133 | 134 | private: 135 | std::unique_ptr<ChainModel> model_; 136 | }; 137 | 138 | 139 | // Decoder Queue for providing thread safety to multiple request handler 140 | // threads producing and consuming decoder instances on demand. 141 | class DecoderQueue final { 142 | 143 | public: 144 | explicit DecoderQueue(const ModelSpec &); 145 | 146 | DecoderQueue(const DecoderQueue &) = delete; // disable copying 147 | 148 | DecoderQueue &operator=(const DecoderQueue &) = delete; // disable assignment 149 | 150 | ~DecoderQueue(); 151 | 152 | // friendly alias for `pop` 153 | inline Decoder *acquire() { 154 | return pop_(); 155 | } 156 | 157 | // friendly alias for `push` 158 | inline void release(Decoder *const decoder) { 159 | return push_(decoder); 160 | } 161 | 162 | private: 163 | // Push method that supports multi-threaded thread-safe concurrency 164 | // pushes a decoder object onto the queue 165 | void push_(Decoder *const); 166 | 167 | // Pop method that supports multi-threaded thread-safe concurrency 168 | // pops a decoder object from the queue 169 | Decoder *pop_(); 170 | 171 | // underlying STL "unsafe" queue for storing decoder objects 172 | std::queue<Decoder *> queue_; 173 | // custom mutex to make queue "thread-safe" 174 | std::mutex mutex_; 175 | // helper for holding mutex and notification on waiting threads when concerned resources are available 176 | std::condition_variable cond_; 177 | // factory for producing new decoders on demand 178 | std::unique_ptr<DecoderFactory> decoder_factory_; 179 | }; 180 | 181 |
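The acquire/release pair gives blocking semantics: a handler thread calling `acquire` waits on the condition variable until one of the pooled decoders is free. A sketch of how concurrent request handlers might share one queue (illustrative only; the handler signature and naming are assumptions, not part of the library):

```cpp
// illustrative sketch: request handler threads sharing one DecoderQueue
#include <istream>

#include <kaldiserve/decoder.hpp>

void handle_request(kaldiserve::DecoderQueue &queue, std::istream &audio) {
    // blocks until one of the `n_decoders` pooled decoders is free
    kaldiserve::Decoder *decoder = queue.acquire();

    decoder->start_decoding();
    decoder->decode_wav_audio(audio);

    kaldiserve::utterance_results_t results;
    decoder->get_decoded_results(3, results);

    decoder->free_decoder();
    queue.release(decoder);  // returns the decoder and signals a waiting thread
}
```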
182 | void find_alternatives(kaldi::CompactLattice &clat, 183 | const std::size_t &n_best, 184 | utterance_results_t &results, 185 | const bool &word_level, 186 | ChainModel *const model, 187 | const DecoderOptions &options); 188 | 189 | 190 | // Find confidence by merging lm and am scores. Taken from 191 | // https://github.com/dialogflow/asr-server/blob/master/src/OnlineDecoder.cc#L90 192 | // NOTE: This might not be very useful for us right now. Depending on the 193 | // situation, we might actually want to weigh components differently. 194 | static inline double calculate_confidence(const float &lm_score, const float &am_score, const int &n_words) noexcept { 195 | return std::max(0.0, std::min(1.0, -0.0001466488 * (2.388449 * lm_score + am_score) / (n_words + 1) + 0.956)); 196 | } 197 | 198 | 199 | static inline void print_wav_info(const kaldi::WaveInfo &wave_info) noexcept { 200 | std::cout << "sample freq: " << wave_info.SampFreq() << ENDL 201 | << "sample count: " << wave_info.SampleCount() << ENDL 202 | << "num channels: " << wave_info.NumChannels() << ENDL 203 | << "reverse bytes: " << wave_info.ReverseBytes() << ENDL 204 | << "data bytes: " << wave_info.DataBytes() << ENDL 205 | << "is streamed: " << wave_info.IsStreamed() << ENDL 206 | << "block align: " << wave_info.BlockAlign() << ENDL; 207 | } 208 | 209 | 210 | static void read_raw_wav_stream(std::istream &wav_stream, 211 | const size_t &data_bytes, 212 | kaldi::Matrix<kaldi::BaseFloat> &wav_data, 213 | const size_t &num_channels = 1, 214 | const size_t &sample_width = 2) { 215 | const size_t bits_per_sample = sample_width * 8; 216 | const size_t block_align = num_channels * sample_width; 217 | 218 | std::vector<char> buffer(data_bytes); 219 | wav_stream.read(&buffer[0], data_bytes); 220 | 221 | if (wav_stream.bad()) 222 | KALDI_ERR << "WaveData: file read error"; 223 | 224 | if (buffer.size() == 0) 225 | KALDI_ERR << "WaveData: empty file (no data)"; 226 | 227 | if (buffer.size() < data_bytes) { 228 | KALDI_WARN << "Expected " << data_bytes << " bytes of wave data, " 229 | << "but read only " << buffer.size() << " bytes. " 230 | << "Truncated file?"; 231 | } 232 | 233 | uint16 *data_ptr = reinterpret_cast<uint16 *>(&buffer[0]); 234 | 235 | // The matrix is arranged row per channel, column per sample. 236 | wav_data.Resize(num_channels, data_bytes / block_align); 237 | for (uint32 i = 0; i < wav_data.NumCols(); ++i) { 238 | for (uint32 j = 0; j < wav_data.NumRows(); ++j) { 239 | int16 k = *data_ptr++; 240 | wav_data(j, i) = k; 241 | } 242 | } 243 | } 244 | 245 | } // namespace kaldiserve
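As a concrete check of the confidence formula above against real output: the sample response in the gRPC plugin's README (plugins/grpc/README.md) reports lm_score = 131.33058, am_score = -374.5963 and a 2-word transcript, and -0.0001466488 × (2.388449 × 131.33058 − 374.5963) / (2 + 1) + 0.956 ≈ 0.95897794 is exactly the confidence shown there. A tiny sanity-check program (illustrative sketch):

```cpp
// reproduce the confidence value from the sample output in
// plugins/grpc/README.md ("हेलो दुनिया" -> 2 words)
#include <cassert>
#include <cmath>

#include <kaldiserve/decoder.hpp>

int main() {
    // calculate_confidence(lm_score, am_score, n_words)
    double conf = kaldiserve::calculate_confidence(131.33058f, -374.5963f, 2);
    assert(std::fabs(conf - 0.95897794) < 1e-6);  // matches the reported value
    return 0;
}
```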
-------------------------------------------------------------------------------- /include/kaldiserve/model.hpp: -------------------------------------------------------------------------------- 1 | // model.hpp - Model Wrapper Interface 2 | #pragma once 3 | 4 | // stl includes 5 | #include <memory> 6 | 7 | // kaldi includes 8 | #include "base/kaldi-common.h" 9 | #include "util/common-utils.h" 10 | #include "rnnlm/rnnlm-lattice-rescoring.h" 11 | #include "fstext/fstext-lib.h" 12 | #include "nnet3/nnet-utils.h" 13 | #include "online2/online-nnet2-feature-pipeline.h" 14 | #include "online2/online-nnet3-decoding.h" 15 | #include "online2/onlinebin-util.h" 16 | #include "util/kaldi-thread.h" 17 | 18 | // local includes 19 | #include "config.hpp" 20 | #include "types.hpp" 21 | #include "decoder.hpp" 22 | #include "utils.hpp" 23 | 24 | 25 | namespace kaldiserve { 26 | 27 | // Chain (DNN-HMM NNet3) Model is a data class that holds all the 28 | // immutable ASR Model components that can be shared across Decoder instances. 29 | class ChainModel final { 30 | 31 | public: 32 | explicit ChainModel(const ModelSpec &model_spec); 33 | 34 | // Model Config 35 | ModelSpec model_spec; 36 | 37 | // HCLG.fst graph 38 | std::unique_ptr<fst::Fst<fst::StdArc>> decode_fst; 39 | 40 | // NNet3 AM 41 | kaldi::nnet3::AmNnetSimple am_nnet; 42 | // Transition Model (HMM) 43 | kaldi::TransitionModel trans_model; 44 | 45 | // Word Symbols table (int->word) 46 | std::unique_ptr<fst::SymbolTable> word_syms; 47 | 48 | // Online Feature Pipeline options 49 | std::unique_ptr<kaldi::OnlineNnet2FeaturePipelineInfo> feature_info; 50 | // NNet3 decodable info (precomputed state for the AM) 51 | std::unique_ptr<kaldi::nnet3::DecodableNnetSimpleLoopedInfo> decodable_info; 52 | 53 | kaldi::LatticeFasterDecoderConfig lattice_faster_decoder_config; 54 | kaldi::nnet3::NnetSimpleLoopedComputationOptions decodable_opts; 55 | 56 | // Word Boundary info (for word level timings) 57 | std::unique_ptr<kaldi::WordBoundaryInfo> wb_info; 58 | 59 | // NNet3 RNNLM 60 | kaldi::nnet3::Nnet rnnlm; 61 | // Word Embeddings matrix 62 | kaldi::CuMatrix<kaldi::BaseFloat> word_embedding_mat; 63 | // Original G.fst LM 64 | std::unique_ptr<fst::BackoffDeterministicOnDemandFst<fst::StdArc>> lm_to_subtract_fst; 65 | // RNNLM info object (encapsulates RNNLM, Word Embeddings and RNNLM options) 66 | std::unique_ptr<kaldi::rnnlm::RnnlmComputeStateInfo> rnnlm_info; 67 | 68 | // RNNLM interpolation weight 69 | kaldi::BaseFloat rnnlm_weight; 70 | // RNNLM options 71 | kaldi::rnnlm::RnnlmComputeStateComputationOptions rnnlm_opts; 72 | // LM composition options 73 | kaldi::ComposeLatticePrunedOptions compose_opts; 74 | }; 75 | 76 | } // namespace kaldiserve -------------------------------------------------------------------------------- /include/kaldiserve/types.hpp: -------------------------------------------------------------------------------- 1 | // Common type definitions. 2 | #pragma once 3 | 4 | // stl includes 5 | #include <memory> 6 | #include <string> 7 | #include <utility> 8 | #include <vector> 9 | 10 | #include "config.hpp" 11 | 12 | 13 | namespace kaldiserve { 14 | 15 | // Model Specification for Kaldi ASR 16 | // contains model config for a particular model. 17 | struct ModelSpec { 18 | std::string name; 19 | std::string language_code; 20 | std::string path; 21 | int n_decoders = 1; 22 | 23 | // decoding parameters 24 | int min_active = 200; 25 | int max_active = 7000; 26 | int frame_subsampling_factor = 3; 27 | float beam = 16.0; 28 | float lattice_beam = 6.0; 29 | float acoustic_scale = 1.0; 30 | float silence_weight = 1.0; 31 | 32 | // rnnlm config 33 | int max_ngram_order = 3; 34 | float rnnlm_weight = 0.5; 35 | std::string bos_index = "1"; 36 | std::string eos_index = "2"; 37 | }; 38 | 39 | struct Word { 40 | float start_time, end_time, confidence; 41 | std::string word; 42 | }; 43 | 44 | // An alternative defines a single hypothesis and certain details about the 45 | // parse (only scores for now). 46 | struct Alternative { 47 | std::string transcript; 48 | double confidence; 49 | float am_score, lm_score; 50 | std::vector<Word> words; 51 | }; 52 | 53 | // Options for decoder 54 | struct DecoderOptions { 55 | bool enable_word_level; 56 | bool enable_rnnlm; 57 | }; 58 | 59 | // Result for one continuous utterance 60 | using utterance_results_t = std::vector<Alternative>; 61 | 62 | // a pair of model_name and language_code 63 | using model_id_t = std::pair<std::string, std::string>; 64 | 65 | } // namespace kaldiserve
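When word-level output is enabled (`DecoderOptions::enable_word_level`, surfaced via `get_decoded_results(..., word_level=true, ...)`), each `Alternative` also carries its `words` with per-word timings and confidences. A small sketch of consuming these types (illustrative only):

```cpp
// illustrative sketch: printing word-level timings from decoded results
#include <cstdio>

#include <kaldiserve/types.hpp>

void print_word_timings(const kaldiserve::utterance_results_t &results) {
    for (const kaldiserve::Alternative &alt : results) {
        std::printf("%s (confidence %.3f)\n", alt.transcript.c_str(), alt.confidence);
        for (const kaldiserve::Word &word : alt.words)
            std::printf("  %.2fs-%.2fs %s (%.3f)\n",
                        word.start_time, word.end_time,
                        word.word.c_str(), word.confidence);
    }
}
```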
-------------------------------------------------------------------------------- /include/kaldiserve/utils.hpp: -------------------------------------------------------------------------------- 1 | // Utility functions. 2 | #pragma once 3 | 4 | // stl includes 5 | #include <string> 6 | #include <vector> 7 | 8 | // local includes 9 | #include "config.hpp" 10 | #include "types.hpp" 11 | 12 | 13 | namespace kaldiserve { 14 | 15 | // If the provided path is relative, expand by prefixing the root_path 16 | std::string expand_relative_path(std::string path, std::string root_path); 17 | 18 | std::string join_path(std::string a, std::string b); 19 | 20 | bool exists(std::string path); 21 | 22 | // Fills a list of model specifications from the config 23 | void parse_model_specs(const std::string &toml_path, std::vector<ModelSpec> &model_specs); 24 | 25 | // Joins vector of strings together using a separator token 26 | void string_join(const std::vector<std::string> &strings, std::string separator, std::string &output); 27 | 28 | } // namespace kaldiserve -------------------------------------------------------------------------------- /plugins/grpc/Dockerfile: -------------------------------------------------------------------------------- 1 | # Stage 1: dev/build 2 | FROM vernacularai/kaldi-serve:latest as builder 3 | 4 | # gRPC Pre-requisites - https://github.com/grpc/grpc/blob/master/BUILDING.md 5 | RUN apt-get update && \ 6 | apt-get upgrade -y && \ 7 | apt-get install -y \ 8 | build-essential \ 9 | autoconf \ 10 | libtool \ 11 | pkg-config \ 12 | libgflags-dev \ 13 | libgtest-dev \ 14 | clang \ 15 | libc++-dev \ 16 | libboost-all-dev \ 17 | curl \ 18 | vim 19 | 20 | # Install gRPC 21 | RUN cd /root/ && \ 22 | git clone -b v1.28.1 https://github.com/grpc/grpc && \ 23 | cd /root/grpc/ && \ 24 | git submodule update --init && \ 25 | make -j$(nproc) && \ 26 | make install 27 | 28 | # Install Protobuf v3 29 | RUN cd /root/grpc/third_party/protobuf && make install 30 | 31 | WORKDIR /root/kaldi-serve 32 | COPY . . 33 | 34 | WORKDIR /root/kaldi-serve/plugins/grpc 35 | ENV LD_LIBRARY_PATH="/opt/kaldi/tools/openfst/lib:/opt/kaldi/src/lib" 36 | RUN make clean && \ 37 | make KALDI_ROOT="/opt/kaldi" KALDISERVE_INCLUDE="/usr/include" -j$(nproc) 38 | 39 | RUN bash -c "mkdir /so-files/; cp /opt/intel/mkl/lib/intel64/lib*.so /so-files/" 40 | 41 | # Stage 2: prod 42 | FROM debian:jessie-slim 43 | WORKDIR /home/app 44 | 45 | COPY --from=builder /root/kaldi-serve/plugins/grpc/build/kaldi_serve_app .
46 | 47 | # LIBS 48 | COPY --from=builder /usr/lib/x86_64-linux-gnu/libssl.so* /usr/local/lib/ 49 | COPY --from=builder /usr/lib/x86_64-linux-gnu/libcrypto.so* /usr/local/lib/ 50 | 51 | # CPP LIBS 52 | COPY --from=builder /usr/lib/x86_64-linux-gnu/libstdc++.so* /usr/local/lib/ 53 | 54 | # BOOST LIBS 55 | COPY --from=builder /usr/lib/x86_64-linux-gnu/libboost_system.so* /usr/local/lib/ 56 | COPY --from=builder /usr/lib/x86_64-linux-gnu/libboost_filesystem.so* /usr/local/lib/ 57 | 58 | # GRPC LIBS 59 | COPY --from=builder /usr/local/lib/libgrpc++.so* /usr/local/lib/ 60 | COPY --from=builder /usr/local/lib/libgrpc++_reflection.so* /usr/local/lib/ 61 | COPY --from=builder /usr/local/lib/libgrpc.so* /usr/local/lib/ 62 | COPY --from=builder /usr/local/lib/libgpr.so* /usr/local/lib/ 63 | COPY --from=builder /usr/local/lib/libupb.so* /usr/local/lib/ 64 | 65 | # INTEL MKL 66 | COPY --from=builder /so-files /opt/intel/mkl/lib/intel64 67 | 68 | # KALDI LIBS 69 | COPY --from=builder /opt/kaldi/tools/openfst/lib/libfst.so.10 /opt/kaldi/tools/openfst/lib/ 70 | COPY --from=builder /opt/kaldi/src/lib/libkaldi-*.so /opt/kaldi/src/lib/ 71 | 72 | # KALDISERVE LIB 73 | COPY --from=builder /usr/local/lib/libkaldiserve.so* /usr/local/lib/ 74 | 75 | ENV LD_LIBRARY_PATH="/usr/local/lib:/opt/kaldi/tools/openfst/lib:/opt/kaldi/src/lib" 76 | 77 | CMD [ "./kaldi_serve_app" ] 78 | -------------------------------------------------------------------------------- /plugins/grpc/Dockerfile.lb: -------------------------------------------------------------------------------- 1 | # Stage 1: dev/build 2 | FROM vernacularai/kaldi-serve:latest as builder 3 | 4 | # gRPC Pre-requisites - https://github.com/grpc/grpc/blob/master/BUILDING.md 5 | RUN apt-get update && \ 6 | apt-get upgrade -y && \ 7 | apt-get install -y \ 8 | build-essential \ 9 | autoconf \ 10 | libtool \ 11 | pkg-config \ 12 | libgflags-dev \ 13 | libgtest-dev \ 14 | clang \ 15 | libc++-dev \ 16 | libboost-all-dev \ 17 | curl \ 18 | vim 19 | 20 | # Install gRPC 21 | RUN cd /root/ && \ 22 | git clone -b v1.28.1 https://github.com/grpc/grpc && \ 23 | cd /root/grpc/ && \ 24 | git submodule update --init && \ 25 | make -j$(nproc) && \ 26 | make install 27 | 28 | # Install Protobuf v3 29 | RUN cd /root/grpc/third_party/protobuf && make install 30 | 31 | WORKDIR /root/kaldi-serve 32 | COPY . . 33 | 34 | WORKDIR /root/kaldi-serve/plugins/grpc 35 | ENV LD_LIBRARY_PATH="/opt/kaldi/tools/openfst/lib:/opt/kaldi/src/lib" 36 | RUN rm -f ./protos/*.pb.* ./protos/*.o&& \ 37 | make KALDI_ROOT="/opt/kaldi" KALDISERVE_INCLUDE="/usr/include" -j$(nproc) 38 | 39 | RUN bash -c "mkdir /so-files/; cp /opt/intel/mkl/lib/intel64/lib*.so /so-files/" 40 | 41 | # Stage 2: registrar 42 | 43 | FROM golang:1.12.5 as builder2 44 | 45 | RUN mkdir -p $GOPATH/src/kaldi_serve/ 46 | 47 | WORKDIR $GOPATH/src/kaldi_serve/ 48 | COPY . . 
49 | 50 | ENV GO111MODULE on 51 | 52 | RUN CGO_ENABLED=0 GOOS=linux go build -a -installsuffix cgo -o build/register plugins/grpc/registrar/main.go 53 | 54 | # Stage 3: prod 55 | FROM debian:jessie-slim 56 | 57 | COPY --from=builder /root/kaldi-serve/plugins/grpc/build/kaldi_serve_app /bin/ 58 | COPY --from=builder2 /go/src/kaldi_serve/build/register /bin/ 59 | 60 | # LIBS 61 | COPY --from=builder /usr/lib/x86_64-linux-gnu/libssl.so* /usr/local/lib/ 62 | COPY --from=builder /usr/lib/x86_64-linux-gnu/libcrypto.so* /usr/local/lib/ 63 | 64 | # CPP LIBS 65 | COPY --from=builder /usr/lib/x86_64-linux-gnu/libstdc++.so* /usr/local/lib/ 66 | 67 | # BOOST LIBS 68 | COPY --from=builder /usr/lib/x86_64-linux-gnu/libboost_system.so* /usr/local/lib/ 69 | COPY --from=builder /usr/lib/x86_64-linux-gnu/libboost_filesystem.so* /usr/local/lib/ 70 | 71 | # GRPC LIBS 72 | COPY --from=builder /usr/local/lib/libgrpc++.so* /usr/local/lib/ 73 | COPY --from=builder /usr/local/lib/libgrpc++_reflection.so* /usr/local/lib/ 74 | COPY --from=builder /usr/local/lib/libgrpc.so* /usr/local/lib/ 75 | COPY --from=builder /usr/local/lib/libgpr.so* /usr/local/lib/ 76 | COPY --from=builder /usr/local/lib/libupb.so* /usr/local/lib/ 77 | 78 | # INTEL MKL 79 | COPY --from=builder /so-files /opt/intel/mkl/lib/intel64 80 | 81 | # KALDI LIBS 82 | COPY --from=builder /opt/kaldi/tools/openfst/lib/libfst.so.10 /opt/kaldi/tools/openfst/lib/ 83 | COPY --from=builder /opt/kaldi/src/lib/libkaldi-*.so /opt/kaldi/src/lib/ 84 | 85 | # KALDISERVE LIB 86 | COPY --from=builder /usr/local/lib/libkaldiserve.so* /usr/local/lib/ 87 | 88 | ENV LD_LIBRARY_PATH="/usr/local/lib:/opt/kaldi/tools/openfst/lib:/opt/kaldi/src/lib" 89 | 90 | ADD plugins/grpc/build/config/supervisord /etc/ 91 | 92 | ENV CONSUL_VERSION=1.9.1 \ 93 | CONSUL_DOMAIN=consul \ 94 | CONSUL_DATA_DIR=/data/consul \ 95 | CONSUL_CONFIG_DIR=/etc/consul/conf.d/bootstrap \ 96 | CONSUL_SERVER_NAME=consul \ 97 | CONSUL_DC=dc1 \ 98 | CONSUL_CLIENT=0.0.0.0 \ 99 | CONSUL_RETRY_INTERVAL=5s 100 | 101 | # Download and install Consul 102 | RUN apt-get update && \ 103 | apt-get install curl util-linux unzip supervisor -y && \ 104 | mkdir -p /var/log/supervisor/ && \ 105 | mkdir -p /etc/supervisor/ && \ 106 | curl -sSLo /tmp/consul.zip https://releases.hashicorp.com/consul/${CONSUL_VERSION}/consul_${CONSUL_VERSION}_linux_amd64.zip && \ 107 | unzip -d /bin /tmp/consul.zip && \ 108 | rm /tmp/consul.zip && \ 109 | apt-get autoremove --purge curl unzip -y && \ 110 | groupadd --system consul && \ 111 | useradd -s /sbin/nologin --system -g consul consul && \ 112 | mkdir -p /data/consul && \ 113 | chown -R consul:consul /data/consul && \ 114 | rm -rf /tmp/* /var/cache/apt/lists/* && \ 115 | mkdir /etc/consul.d/ 116 | 117 | # Add the files 118 | COPY plugins/grpc/build/config/consul / 119 | 120 | # Supervisor files 121 | COPY plugins/grpc/build/config/supervisord/grpc_server.conf /etc/supervisor/conf.d/grpc_server.conf 122 | COPY plugins/grpc/build/config/supervisord/consul.conf /etc/supervisor/conf.d/consul.conf 123 | COPY plugins/grpc/build/config/supervisord/register.conf /etc/supervisor/conf.d/register.conf 124 | COPY plugins/grpc/build/config/supervisord/supervisord.conf /etc/supervisor/supervisord.conf 125 | 126 | VOLUME ["/data/consul"] 127 | 128 | # Same exposed ports as consul 129 | EXPOSE 8300 8301 8301/udp 8302 8302/udp 8400 8500 8600 8600/udp 53 53/udp 130 | 131 | ENTRYPOINT ["supervisord", "-c", "/etc/supervisor/supervisord.conf"] 132 | # Command to run: docker run -p 8500:8500 -e
CONSUL_SERVER=127.0.0.1 -d -t base_img_test 133 | -------------------------------------------------------------------------------- /plugins/grpc/Makefile: -------------------------------------------------------------------------------- 1 | KALDI_ROOT=/opt/kaldi 2 | KALDISERVE_INCLUDE=../../include 3 | 4 | HOST_SYSTEM = $(shell uname | cut -f 1 -d_) 5 | SYSTEM ?= $(HOST_SYSTEM) 6 | CXX = g++ 7 | CPPFLAGS += `pkg-config --cflags protobuf grpc` 8 | CXXFLAGS += -std=c++11 -O3 9 | 10 | LDFLAGS += -L/usr/local/lib `pkg-config --libs protobuf grpc++` \ 11 | -Wl,--no-as-needed -lgrpc++_reflection -Wl,--as-needed -ldl 12 | 13 | INCLUDES = -I${KALDISERVE_INCLUDE} -I${KALDI_ROOT}/src -I${KALDI_ROOT}/tools/openfst/include 14 | LIBS = -rdynamic -lboost_system -lkaldiserve -static-libstdc++ 15 | 16 | PROTOC = protoc 17 | GRPC_CPP_PLUGIN = grpc_cpp_plugin 18 | GRPC_CPP_PLUGIN_PATH ?= `which $(GRPC_CPP_PLUGIN)` 19 | 20 | PROTOS_PATH = ./protos 21 | 22 | vpath %.proto $(PROTOS_PATH) 23 | 24 | all: system-check build/kaldi_serve_app 25 | 26 | build/kaldi_serve_app: $(PROTOS_PATH)/kaldi_serve.pb.o $(PROTOS_PATH)/kaldi_serve.grpc.pb.o build/kaldi_serve_app.o 27 | $(CXX) $^ $(LDFLAGS) $(LIBS) -o $@ 28 | 29 | build/kaldi_serve_app.o: src/app.cc $(wildcard src/*.hpp) 30 | $(CXX) $(CXXFLAGS) $(INCLUDES) -I $(PROTOS_PATH) -c src/app.cc -o $@ 31 | 32 | .PRECIOUS: %.grpc.pb.cc 33 | %.grpc.pb.cc: %.proto 34 | $(PROTOC) -I $(PROTOS_PATH) --grpc_out=$(PROTOS_PATH) --plugin=protoc-gen-grpc=$(GRPC_CPP_PLUGIN_PATH) $< 35 | 36 | .PRECIOUS: %.pb.cc 37 | %.pb.cc: %.proto 38 | $(PROTOC) -I $(PROTOS_PATH) --cpp_out=$(PROTOS_PATH) $< 39 | 40 | clean: 41 | rm -f ./build/* $(PROTOS_PATH)/*.pb.* $(PROTOS_PATH)/*.o 42 | 43 | # The following is to test your system and ensure a smoother experience. 44 | # They are by no means necessary to actually compile a grpc-enabled software. 45 | 46 | PROTOC_CMD = which $(PROTOC) 47 | PROTOC_CHECK_CMD = $(PROTOC) --version | grep -q libprotoc.3 48 | PLUGIN_CHECK_CMD = which $(GRPC_CPP_PLUGIN) 49 | HAS_PROTOC = $(shell $(PROTOC_CMD) > /dev/null && echo true || echo false) 50 | ifeq ($(HAS_PROTOC),true) 51 | HAS_VALID_PROTOC = $(shell $(PROTOC_CHECK_CMD) 2> /dev/null && echo true || echo false) 52 | endif 53 | HAS_PLUGIN = $(shell $(PLUGIN_CHECK_CMD) > /dev/null && echo true || echo false) 54 | 55 | SYSTEM_OK = false 56 | ifeq ($(HAS_VALID_PROTOC),true) 57 | ifeq ($(HAS_PLUGIN),true) 58 | SYSTEM_OK = true 59 | endif 60 | endif 61 | 62 | system-check: 63 | ifneq ($(HAS_VALID_PROTOC),true) 64 | @echo " DEPENDENCY ERROR" 65 | @echo 66 | @echo "You don't have protoc 3.0.0 installed in your path." 67 | @echo "Please install Google protocol buffers 3.0.0 and its compiler." 68 | @echo "You can find it here:" 69 | @echo 70 | @echo " https://github.com/google/protobuf/releases/tag/v3.0.0" 71 | @echo 72 | @echo "Here is what I get when trying to evaluate your version of protoc:" 73 | @echo 74 | -$(PROTOC) --version 75 | @echo 76 | @echo 77 | endif 78 | ifneq ($(HAS_PLUGIN),true) 79 | @echo " DEPENDENCY ERROR" 80 | @echo 81 | @echo "You don't have the grpc c++ protobuf plugin installed in your path." 82 | @echo "Please install grpc. 
You can find it here:" 83 | @echo 84 | @echo " https://github.com/grpc/grpc" 85 | @echo 86 | @echo "Here is what I get when trying to detect if you have the plugin:" 87 | @echo 88 | -which $(GRPC_CPP_PLUGIN) 89 | @echo 90 | @echo 91 | endif 92 | ifneq ($(SYSTEM_OK),true) 93 | @false 94 | endif 95 | -------------------------------------------------------------------------------- /plugins/grpc/README.md: -------------------------------------------------------------------------------- 1 | # Kaldi-Serve gRPC Plugin 2 | 3 | [gRPC](https://grpc.io/) server & client components for [Kaldi](https://kaldi-asr.org/) based ASR. 4 | 5 | ## Installation 6 | 7 | ### Dependencies 8 | 9 | Make sure you have the following dependencies installed on your system before beginning the build process: 10 | 11 | * g++ compiler (>=4.7) that supports C++11 std 12 | * [CMake](https://cmake.org/install/) (>=3.13) 13 | * [Kaldi](https://kaldi-asr.org/) 14 | * [gRPC](https://github.com/grpc/grpc) 15 | * [Boost C++](https://www.boost.org/) libraries 16 | 17 | ### Build from source 18 | 19 | Make sure you have built the kaldiserve shared library and placed in `/usr/local/lib`. Let's build the server application using the kaldiserve library: 20 | 21 | ```bash 22 | make KALDI_ROOT="/path/to/local/repo/for/kaldi/" -j${nproc} 23 | ``` 24 | 25 | Run `make clean` to clear old build files. 26 | 27 | ### Docker Image 28 | 29 | #### Using pre-built images 30 | 31 | You can also pull a pre-built docker image from our [Docker Hub repository](https://hub.docker.com/repository/docker/vernacularai/kaldi-serve): 32 | 33 | ```bash 34 | docker pull vernacularai/kaldi-serve:latest-grpc 35 | docker run -it -p 5016:5016 -v /models:/home/app/models vernacularai/kaldi-serve:latest-grpc resources/model-spec.toml 36 | ``` 37 | 38 | You will find the built server application `kaldi_serve_app` in `/home/app`. 39 | 40 | #### Building the image 41 | 42 | You can build the docker image using the [Dockerfile](./Dockerfile) provided. 43 | 44 | ```bash 45 | cd ../../ 46 | docker build -t kaldi-serve:grpc -f plugins/grpc/Dockerfile . 47 | ``` 48 | 49 | You will get a stripped down production ready image from running the above command as we use multi-stage docker builds. In case you need a **development** image, build the image as follows: 50 | 51 | ```bash 52 | docker build --target builder -t kaldi-serve:grpc-dev -f plugins/grpc/Dockerfile . 53 | ``` 54 | 55 | ## Getting Started 56 | 57 | ### Server 58 | 59 | For running the server, you need to first specify model config in a toml which 60 | tells the program which models to load, where to look for etc. Structure of 61 | `model_spec_toml` file is specified in a sample in [resources](../../resources/model-spec.toml). 62 | 63 | ```bash 64 | ./kaldi_serve_app --help 65 | 66 | Kaldi gRPC server 67 | Usage: ./kaldi_serve_app [OPTIONS] model_spec_toml 68 | 69 | Positionals: 70 | model_spec_toml TEXT:FILE REQUIRED 71 | Path to toml specifying models to load 72 | 73 | Options: 74 | -h,--help Print this help message and exit 75 | -v,--version Show program version and exit 76 | -d,--debug Enable debug request logging 77 | ``` 78 | 79 | Please also see our [Aspire example](./examples/aspire) on how to get a server up and running with your models. 80 | 81 | #### Python Client 82 | 83 | A [Python gRPC client](./client) is also provided with a few example scripts (client SDK needs to be installed via [poetry](https://github.com/python-poetry/poetry)). 
For simple microphone testing, you can do something like the following (make sure the server is running on the same machine on the specified port, default: 5016): 84 | 85 | ```bash 86 | cd client/ 87 | poetry run python scripts/example_client.py mic --n-secs=5 --model=general --lang=hi 88 | ``` 89 | 90 | The output should look something like the following: 91 | ```bash 92 | { 93 | "results": [ 94 | { 95 | "alternatives": [ 96 | { 97 | "transcript": "हेलो दुनिया", 98 | "confidence": 0.95897794, 99 | "am_score": -374.5963, 100 | "lm_score": 131.33058 101 | }, 102 | { 103 | "transcript": "हैलो दुनिया", 104 | "confidence": 0.95882875, 105 | "am_score": -372.76187, 106 | "lm_score": 131.84035 107 | } 108 | ] 109 | } 110 | ] 111 | } 112 | ``` -------------------------------------------------------------------------------- /plugins/grpc/build/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore 3 | !config/**/* 4 | 5 | !*/ -------------------------------------------------------------------------------- /plugins/grpc/build/config/consul/etc/consul/conf.d/bootstrap/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "bootstrap_expect": 3, 3 | "server": true, 4 | "data_dir": "/data/consul", 5 | "disable_update_check": true 6 | } 7 | -------------------------------------------------------------------------------- /plugins/grpc/build/config/consul/etc/cont-init.d/30-consul: -------------------------------------------------------------------------------- 1 | #!/usr/bin/with-contenv sh 2 | 3 | # Unless this has already been defined, set it. 4 | if [ -z "$CONSUL_BOOTSTRAP_LOG_FILE" ]; then 5 | printf "/var/log/consul-bootstrap/consul-bootstrap.log" > /var/run/s6/container_environment/CONSUL_BOOTSTRAP_LOG_FILE 6 | fi 7 | 8 | # Unless this has already been defined, set it. 9 | if [ -z "$CONSUL_RUNAS" ]; then 10 | printf "consul" > /var/run/s6/container_environment/CONSUL_RUNAS 11 | fi 12 | -------------------------------------------------------------------------------- /plugins/grpc/build/config/consul/etc/cont-init.d/40-consul: -------------------------------------------------------------------------------- 1 | #!/usr/bin/with-contenv sh 2 | 3 | mkdir -p `dirname $CONSUL_BOOTSTRAP_LOG_FILE` 4 | -------------------------------------------------------------------------------- /plugins/grpc/build/config/consul/etc/services.d/consul/run: -------------------------------------------------------------------------------- 1 | #!/usr/bin/with-contenv sh 2 | 3 | # 4 | # This script will wait until another Consul server is available before starting Consul. 5 | # 6 | 7 | # Bind to the external (LAN) IP. 8 | BIND=`host-ip` 9 | 10 | # Wait until another Consul server is available. 11 | consul-available 12 | 13 | # We have another Consul server available. Join it. 14 | RETRY_JOIN=`consul-join` 15 | 16 | # Provide an entry point for Consul WAN customisation. 17 | CONSUL_JOIN_WAN=`consul-join-wan` 18 | 19 | # Retrieve a UUID for -node-id. 20 | NODE_ID=`consul-node-id` 21 | 22 | # Start Consul. 
23 | s6-setuidgid $CONSUL_RUNAS consul agent -node-id $NODE_ID -data-dir $CONSUL_DATA_DIR -config-dir $CONSUL_CONFIG_DIR -bind $BIND -advertise $BIND -client $CONSUL_CLIENT $RETRY_JOIN -retry-interval $CONSUL_RETRY_INTERVAL -domain $CONSUL_DOMAIN -datacenter $CONSUL_DC $CONSUL_JOIN_WAN 24 | -------------------------------------------------------------------------------- /plugins/grpc/build/config/consul/etc/services.d/resolver/run: -------------------------------------------------------------------------------- 1 | #!/usr/bin/with-contenv sh 2 | 3 | # Retrieve the external (LAN) IP, as this is what Consul's :8600 port has been bound to. 4 | BIND=`host-ip` 5 | 6 | # Start go-dnsmasq. 7 | s6-setuidgid $GO_DNSMASQ_RUNAS go-dnsmasq --default-resolver --ndots "1" --fwd-ndots "0" --stubzones=".$CONSUL_DOMAIN/$BIND:8600" --hostsfile=/etc/hosts >> $GO_DNSMASQ_LOG_FILE 2>&1 8 | -------------------------------------------------------------------------------- /plugins/grpc/build/config/consul/usr/bin/consul-available: -------------------------------------------------------------------------------- 1 | #!/usr/bin/with-contenv sh 2 | 3 | # 4 | # This script will wait indefinitely until a DNS query for `consul` (or $CONSUL_SERVER_NAME) can be resolved. 5 | # 6 | # The script will exit 0 once a DNS query has resolved. 7 | # 8 | 9 | # Allow the sleep timeframe to be customised. Default to 1. 10 | SLEEP=${CONSUL_AVAILABLE_SLEEP:-1} 11 | 12 | # Wait until consul-join finds another Consul server. 13 | until consul-join > /dev/null 2>&1 14 | do 15 | 16 | # Output debug messages if required. 17 | consul-debug "[consul-available] sleep ($SLEEP)..." 18 | 19 | sleep $SLEEP 20 | 21 | done 22 | 23 | consul-debug "[consul-available] consul is available" 24 | 25 | # We have another Consul server available. Exit 0! 26 | exit 0 27 | -------------------------------------------------------------------------------- /plugins/grpc/build/config/consul/usr/bin/consul-debug: -------------------------------------------------------------------------------- 1 | #!/usr/bin/with-contenv sh 2 | 3 | # 4 | # This script will take its arguments and append them to a log file. 5 | # 6 | 7 | # Output debug messages if required. 8 | if [ "$CONSUL_BOOTSTRAP_DEBUG" = "true" ]; then 9 | echo "$@" >> $CONSUL_BOOTSTRAP_LOG_FILE 10 | fi 11 | -------------------------------------------------------------------------------- /plugins/grpc/build/config/consul/usr/bin/consul-join: -------------------------------------------------------------------------------- 1 | #!/usr/bin/with-contenv sh 2 | 3 | # 4 | # This script will return a list of IP addresses (that resolved from a DNS query for `consul` 5 | # or $CONSUL_SERVER_NAME) prefixed with -retry-join. If none can be found the script will exit 1. 6 | # 7 | # For example: 8 | # -retry-join 192.168.0.8 -retry-join 192.168.0.10 -retry-join 192.168.0.19 9 | # 10 | 11 | # Get the IP of the current container. We don't want this container to match itself. 12 | HOST_IP=`host-ip` 13 | 14 | # Look for an IP that doesn't match that of this container. 15 | JOIN=`container-find | grep -v "$HOST_IP" | awk '{ printf " -retry-join " $0}'` 16 | 17 | # If we didn't find another IP, exit 1. 18 | if test -z "$JOIN"; then 19 | exit 1; 20 | fi 21 | 22 | # Output debug messages if required. 23 | consul-debug "[consul-join] $JOIN" 24 | 25 | # We found at least one. Exit 0!
26 | echo "$JOIN" 27 | exit 0 28 | -------------------------------------------------------------------------------- /plugins/grpc/build/config/consul/usr/bin/consul-join-wan: -------------------------------------------------------------------------------- 1 | #!/usr/bin/with-contenv sh 2 | 3 | # 4 | # This script provides an entry point to customise the `-retry-join-wan` and `-advertise-wan` 5 | # values via the environment variables `CONSUL_JOIN_WAN` and `CONSUL_ADVERTISE_WAN`. 6 | # 7 | # If those ENV variables are not set however, this file does nothing. 8 | # 9 | 10 | EXTRA="" 11 | 12 | if [ "$CONSUL_JOIN_WAN" ]; then 13 | EXTRA=$EXTRA" -retry-join-wan $CONSUL_JOIN_WAN" 14 | fi 15 | 16 | if [ "$CONSUL_ADVERTISE_WAN" ]; then 17 | EXTRA=$EXTRA" -advertise-wan $CONSUL_ADVERTISE_WAN" 18 | fi 19 | 20 | consul-debug "[consul-join] $JOIN" 21 | 22 | echo "$EXTRA" 23 | exit 0 24 | -------------------------------------------------------------------------------- /plugins/grpc/build/config/consul/usr/bin/consul-node-id: -------------------------------------------------------------------------------- 1 | #!/usr/bin/with-contenv sh 2 | 3 | # 4 | # This script provides an entry point to customise the `-node-id` arguments 5 | # passed to consul when starting in agent mode. 6 | # 7 | 8 | # Based on many tests, the following produced the most reliable UUIDs. 9 | # The built-in one Consul provides intermittently clashed, as did `cat /proc/sys/kernel/random/uuid`. 10 | NODE_ID=$(uuidgen) 11 | 12 | # Output debug messages if required. 13 | consul-debug "[consul-node-id] node-id: $NODE_ID..." 14 | 15 | echo $NODE_ID 16 | exit 0 17 | -------------------------------------------------------------------------------- /plugins/grpc/build/config/consul/usr/bin/container-find: -------------------------------------------------------------------------------- 1 | #!/usr/bin/with-contenv sh 2 | 3 | # 4 | # This script is responsible for returning multiple IP addresses for the specific 5 | # container name (defaulting to $CONSUL_SERVER_NAME (consul)). 
6 | # 7 | # Return format should be one IP address per line: 8 | # 9 | # 192.168.0.8 10 | # 192.168.0.10 11 | # 192.168.0.19 12 | # 13 | 14 | dig +short ${1:-$CONSUL_SERVER_NAME} 15 | -------------------------------------------------------------------------------- /plugins/grpc/build/config/supervisord/consul-template.conf: -------------------------------------------------------------------------------- 1 | [program:consul-template] 2 | priority = 1 3 | command = consul-template -template "/etc/consul-templates/load-balancer.tpl:/etc/nginx/conf.d/load-balancer.conf:supervisorctl restart nginx" 4 | stdout_capture_maxbytes = 1MB 5 | redirect_stderr = true 6 | stdout_logfile = /var/log/supervisor/%(program_name)s.log 7 | -------------------------------------------------------------------------------- /plugins/grpc/build/config/supervisord/consul.conf: -------------------------------------------------------------------------------- 1 | [program:consul] 2 | command = consul agent -bind '{{ GetInterfaceIP "eth0" }}' -retry-join='%(ENV_CONSUL_SERVER)s' -data-dir /tmp/consul -config-dir /etc/consul.d/ 3 | stdout_capture_maxbytes = 1MB 4 | redirect_stderr = true 5 | stdout_logfile = /var/log/supervisor/%(program_name)s.log -------------------------------------------------------------------------------- /plugins/grpc/build/config/supervisord/grpc_server.conf: -------------------------------------------------------------------------------- 1 | [program:grpc_server] 2 | command = kaldi_serve_app /home/app/models.toml 3 | startsecs = 0 4 | autorestart = false 5 | startretries = 1 6 | redirect_stderr = true 7 | stdout_logfile = /var/log/supervisor/%(program_name)s.log 8 | -------------------------------------------------------------------------------- /plugins/grpc/build/config/supervisord/register.conf: -------------------------------------------------------------------------------- 1 | [program:register] 2 | command = register 3 | startsecs = 0 4 | autorestart = false 5 | startretries = 1 6 | redirect_stderr = true 7 | stdout_logfile = /var/log/supervisor/%(program_name)s.log 8 | -------------------------------------------------------------------------------- /plugins/grpc/build/config/supervisord/supervisor-nginx.conf: -------------------------------------------------------------------------------- 1 | [program:nginx] 2 | command = nginx 3 | autostart = true 4 | autorestart = unexpected 5 | exitcodes = 0 6 | redirect_stderr = true 7 | stdout_logfile = /var/log/supervisor/%(program_name)s.log 8 | -------------------------------------------------------------------------------- /plugins/grpc/build/config/supervisord/supervisord.conf: -------------------------------------------------------------------------------- 1 | [unix_http_server] 2 | file = /tmp/supervisor.sock ; the path to the socket file 3 | 4 | [supervisord] 5 | logfile = /var/log/supervisor/supervisord.log 6 | logfile_maxbytes = 50MB ; max main logfile bytes b4 rotation; default 50MB 7 | logfile_backups = 10 ; # of main logfile backups; 0 means none, default 10 8 | loglevel = info ; log level; default info; others: debug,warn,trace 9 | pidfile = /var/run/supervisord.pid 10 | nodaemon = true 11 | minfds = 1024 ; min. avail startup file descriptors; default 1024 12 | minprocs = 200 ; min. 
avail process descriptors;default 200 13 | directory = /tmp 14 | 15 | [rpcinterface:supervisor] 16 | supervisor.rpcinterface_factory = supervisor.rpcinterface:make_main_rpcinterface 17 | 18 | [supervisorctl] 19 | serverurl = unix:///tmp/supervisor.sock ; use a unix:// URL for a unix socket 20 | 21 | [include] 22 | files = /etc/supervisor/conf.d/*.conf 23 | -------------------------------------------------------------------------------- /plugins/grpc/client/.gitignore: -------------------------------------------------------------------------------- 1 | # Created by https://www.gitignore.io/api/python 2 | # Edit at https://www.gitignore.io/?templates=python 3 | 4 | ### Python ### 5 | # Byte-compiled / optimized / DLL files 6 | __pycache__/ 7 | *.py[cod] 8 | *$py.class 9 | 10 | # C extensions 11 | *.so 12 | 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | pip-wheel-metadata/ 28 | share/python-wheels/ 29 | *.egg-info/ 30 | .installed.cfg 31 | *.egg 32 | MANIFEST 33 | 34 | # PyInstaller 35 | # Usually these files are written by a python script from a template 36 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .nox/ 48 | .coverage 49 | .coverage.* 50 | .cache 51 | nosetests.xml 52 | coverage.xml 53 | *.cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 
95 | #Pipfile.lock 96 | 97 | # celery beat schedule file 98 | celerybeat-schedule 99 | 100 | # SageMath parsed files 101 | *.sage.py 102 | 103 | # Environments 104 | .env 105 | .venv 106 | env/ 107 | venv/ 108 | ENV/ 109 | env.bak/ 110 | venv.bak/ 111 | 112 | # Spyder project settings 113 | .spyderproject 114 | .spyproject 115 | 116 | # Rope project settings 117 | .ropeproject 118 | 119 | # mkdocs documentation 120 | /site 121 | 122 | # mypy 123 | .mypy_cache/ 124 | .dmypy.json 125 | dmypy.json 126 | 127 | # Pyre type checker 128 | .pyre/ 129 | 130 | # End of https://www.gitignore.io/api/python -------------------------------------------------------------------------------- /plugins/grpc/client/Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: clean test 2 | 3 | all: kaldi_serve_pb2_grpc.py 4 | 5 | kaldi_serve_pb2_grpc.py: ../protos/kaldi_serve.proto 6 | poetry run python -m grpc_tools.protoc -I ../protos --python_out=./kaldi_serve --grpc_python_out=./kaldi_serve ../protos/kaldi_serve.proto 7 | # Hack to fix python import path issue in grpc code generation 8 | sed -i 's/import kaldi_serve_pb2 as kaldi__serve__pb2/import kaldi_serve.kaldi_serve_pb2 as kaldi__serve__pb2/' ./kaldi_serve/kaldi_serve_pb2_grpc.py 9 | 10 | clean: 11 | rm ./kaldi_serve/kaldi_serve_pb2_grpc.py ./kaldi_serve/kaldi_serve_pb2.py 12 | 13 | test: 14 | poetry run pytest 15 | -------------------------------------------------------------------------------- /plugins/grpc/client/README.md: -------------------------------------------------------------------------------- 1 | # Kaldi-Serve gRPC Client 2 | 3 | Python gRPC client for kaldi-serve. 4 | 5 | ```bash 6 | poetry install 7 | make # generates the python gRPC modules 8 | # Check out the ./scripts dir for sample code 9 | ``` 10 | 11 | ## Testing 12 | 13 | Use the format specified in `./tests/test_hi.yaml` and run `poetry run pytest`. 14 | -------------------------------------------------------------------------------- /plugins/grpc/client/kaldi_serve/__init__.py: -------------------------------------------------------------------------------- 1 | from kaldi_serve.core import KaldiServeClient 2 | from kaldi_serve.kaldi_serve_pb2 import RecognitionAudio, RecognitionConfig 3 | -------------------------------------------------------------------------------- /plugins/grpc/client/kaldi_serve/core.py: -------------------------------------------------------------------------------- 1 | from google.protobuf.empty_pb2 import Empty 2 | import grpc 3 | 4 | from kaldi_serve.kaldi_serve_pb2 import RecognitionConfig, RecognizeRequest 5 | from kaldi_serve.kaldi_serve_pb2_grpc import KaldiServeStub 6 | from kaldi_serve.kaldi_serve_pb2_grpc import google_dot_protobuf_dot_empty__pb2 as proto_empty 7 | 8 | 9 | class KaldiServeClient(object): 10 | """ 11 | Client for the Kaldi-Serve gRPC API.
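    A minimal usage sketch (assumes a server is reachable at the address
    below, which is just this client's default):

        client = KaldiServeClient("0.0.0.0:5016")
        print(client.list_models())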
12 | 13 | Reference: https://github.com/googleapis/google-cloud-python/blob/3ba1ae73070769854a1f7371305c13752c0374ba/speech/google/cloud/speech_v1/gapic/speech_client.py 14 | """ 15 | 16 | def __init__(self, kaldi_serve_url="0.0.0.0:5016"): 17 | self.channel = grpc.insecure_channel(kaldi_serve_url) 18 | self._client = KaldiServeStub(self.channel) 19 | 20 | def list_models(self, timeout=None): 21 | return self._client.ListModels(proto_empty.Empty(), timeout=timeout) 22 | 23 | def recognize(self, config: RecognitionConfig, audio, uuid: str, timeout=None): 24 | request = RecognizeRequest(config=config, audio=audio, uuid=uuid) 25 | return self._client.Recognize(request, timeout=timeout) 26 | 27 | def streaming_recognize(self, config: RecognitionConfig, audio_chunks_gen, uuid: str, timeout=None): 28 | request_gen = (RecognizeRequest(config=config, audio=chunk, uuid=uuid) for chunk in audio_chunks_gen) 29 | return self._client.StreamingRecognize(request_gen, timeout=timeout) 30 | 31 | def streaming_recognize_raw(self, audio_params_gen, uuid: str, timeout=None): 32 | request_gen = (RecognizeRequest(config=config, audio=chunk, uuid=uuid) for config, chunk in audio_params_gen) 33 | return self._client.StreamingRecognize(request_gen, timeout=timeout) 34 | 35 | def bidi_streaming_recognize(self, config: RecognitionConfig, audio_chunks_gen, uuid: str, timeout=None): 36 | request_gen = (RecognizeRequest(config=config, audio=chunk, uuid=uuid) for chunk in audio_chunks_gen) 37 | return self._client.BidiStreamingRecognize(request_gen, timeout=timeout) 38 | 39 | def bidi_streaming_recognize_raw(self, audio_params_gen, uuid: str, timeout=None): 40 | request_gen = (RecognizeRequest(config=config, audio=chunk, uuid=uuid) for config, chunk in audio_params_gen) 41 | return self._client.BidiStreamingRecognize(request_gen, timeout=timeout) -------------------------------------------------------------------------------- /plugins/grpc/client/kaldi_serve/kaldi_serve_pb2_grpc.py: -------------------------------------------------------------------------------- 1 | # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! 2 | import grpc 3 | 4 | from google.protobuf import empty_pb2 as google_dot_protobuf_dot_empty__pb2 5 | import kaldi_serve.kaldi_serve_pb2 as kaldi__serve__pb2 6 | 7 | 8 | class KaldiServeStub(object): 9 | # missing associated documentation comment in .proto file 10 | pass 11 | 12 | def __init__(self, channel): 13 | """Constructor. 14 | 15 | Args: 16 | channel: A grpc.Channel. 
17 | """ 18 | self.ListModels = channel.unary_unary( 19 | '/kaldi_serve.KaldiServe/ListModels', 20 | request_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString, 21 | response_deserializer=kaldi__serve__pb2.ModelList.FromString, 22 | ) 23 | self.Recognize = channel.unary_unary( 24 | '/kaldi_serve.KaldiServe/Recognize', 25 | request_serializer=kaldi__serve__pb2.RecognizeRequest.SerializeToString, 26 | response_deserializer=kaldi__serve__pb2.RecognizeResponse.FromString, 27 | ) 28 | self.StreamingRecognize = channel.stream_unary( 29 | '/kaldi_serve.KaldiServe/StreamingRecognize', 30 | request_serializer=kaldi__serve__pb2.RecognizeRequest.SerializeToString, 31 | response_deserializer=kaldi__serve__pb2.RecognizeResponse.FromString, 32 | ) 33 | self.BidiStreamingRecognize = channel.stream_stream( 34 | '/kaldi_serve.KaldiServe/BidiStreamingRecognize', 35 | request_serializer=kaldi__serve__pb2.RecognizeRequest.SerializeToString, 36 | response_deserializer=kaldi__serve__pb2.RecognizeResponse.FromString, 37 | ) 38 | 39 | 40 | class KaldiServeServicer(object): 41 | # missing associated documentation comment in .proto file 42 | pass 43 | 44 | def ListModels(self, request, context): 45 | """Lists all the available loaded models 46 | """ 47 | context.set_code(grpc.StatusCode.UNIMPLEMENTED) 48 | context.set_details('Method not implemented!') 49 | raise NotImplementedError('Method not implemented!') 50 | 51 | def Recognize(self, request, context): 52 | """Performs synchronous non-streaming speech recognition. 53 | """ 54 | context.set_code(grpc.StatusCode.UNIMPLEMENTED) 55 | context.set_details('Method not implemented!') 56 | raise NotImplementedError('Method not implemented!') 57 | 58 | def StreamingRecognize(self, request_iterator, context): 59 | """Performs synchronous client-to-server streaming speech recognition: 60 | receive results after all audio has been streamed and processed. 61 | """ 62 | context.set_code(grpc.StatusCode.UNIMPLEMENTED) 63 | context.set_details('Method not implemented!') 64 | raise NotImplementedError('Method not implemented!') 65 | 66 | def BidiStreamingRecognize(self, request_iterator, context): 67 | """Performs synchronous bidirectional streaming speech recognition: 68 | receive results as the audio is being streamed and processed. 
69 | """ 70 | context.set_code(grpc.StatusCode.UNIMPLEMENTED) 71 | context.set_details('Method not implemented!') 72 | raise NotImplementedError('Method not implemented!') 73 | 74 | 75 | def add_KaldiServeServicer_to_server(servicer, server): 76 | rpc_method_handlers = { 77 | 'ListModels': grpc.unary_unary_rpc_method_handler( 78 | servicer.ListModels, 79 | request_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString, 80 | response_serializer=kaldi__serve__pb2.ModelList.SerializeToString, 81 | ), 82 | 'Recognize': grpc.unary_unary_rpc_method_handler( 83 | servicer.Recognize, 84 | request_deserializer=kaldi__serve__pb2.RecognizeRequest.FromString, 85 | response_serializer=kaldi__serve__pb2.RecognizeResponse.SerializeToString, 86 | ), 87 | 'StreamingRecognize': grpc.stream_unary_rpc_method_handler( 88 | servicer.StreamingRecognize, 89 | request_deserializer=kaldi__serve__pb2.RecognizeRequest.FromString, 90 | response_serializer=kaldi__serve__pb2.RecognizeResponse.SerializeToString, 91 | ), 92 | 'BidiStreamingRecognize': grpc.stream_stream_rpc_method_handler( 93 | servicer.BidiStreamingRecognize, 94 | request_deserializer=kaldi__serve__pb2.RecognizeRequest.FromString, 95 | response_serializer=kaldi__serve__pb2.RecognizeResponse.SerializeToString, 96 | ), 97 | } 98 | generic_handler = grpc.method_handlers_generic_handler( 99 | 'kaldi_serve.KaldiServe', rpc_method_handlers) 100 | server.add_generic_rpc_handlers((generic_handler,)) 101 | -------------------------------------------------------------------------------- /plugins/grpc/client/kaldi_serve/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utility functions for working with ASR, audio, devices etc. 3 | """ 4 | 5 | import io 6 | import wave 7 | 8 | import pyaudio 9 | 10 | from pydub import AudioSegment 11 | from pydub.silence import detect_nonsilent 12 | 13 | 14 | def raw_bytes_to_wav(data: bytes, frame_rate: int, channels: int, sample_width: int) -> bytes: 15 | """ 16 | Convert raw PCM bytes to wav bytes (with the initial 44 bytes header) 17 | """ 18 | 19 | out = io.BytesIO() 20 | wf = wave.open(out, "wb") 21 | wf.setnchannels(channels) 22 | wf.setsampwidth(sample_width) 23 | wf.setframerate(frame_rate) 24 | wf.writeframes(data) 25 | wf.close() 26 | return out.getvalue() 27 | 28 | 29 | def chunks_from_mic(secs: int, frame_rate: int, channels: int): 30 | """ 31 | Generate wave audio chunks from microphone worth `secs` seconds. 32 | """ 33 | 34 | p = pyaudio.PyAudio() 35 | sample_format = pyaudio.paInt16 36 | 37 | # ~ 1sec of audio 38 | chunk_size = frame_rate 39 | 40 | stream = p.open(format=sample_format, 41 | channels=channels, 42 | rate=frame_rate, 43 | frames_per_buffer=chunk_size, 44 | input=True) 45 | 46 | sample_width = p.get_sample_size(sample_format) 47 | 48 | print('recording...') 49 | for _ in range(0, int(frame_rate / chunk_size * secs)): 50 | # The right way probably is to not send headers at all and let the 51 | # server side's chunk handler maintain state, taking data from 52 | # metadata. 53 | yield raw_bytes_to_wav(stream.read(chunk_size), frame_rate, channels, sample_width) 54 | 55 | stream.stop_stream() 56 | stream.close() 57 | p.terminate() 58 | 59 | 60 | def chunks_from_file(filename: str, sample_rate=8000, chunk_size=1, raw=False, pcm=False): 61 | """ 62 | Return wav chunks of given size (in seconds) from the file. 
63 | """ 64 | 65 | # TODO: Should remove assumptions about audio properties from here 66 | audio = AudioSegment.from_file(filename, 67 | format="s16le" if pcm else "wav", 68 | frame_rate=sample_rate, channels=1, 69 | sample_width=2) 70 | return chunks_from_audio_segment(audio, chunk_size=chunk_size, raw=pcm or raw) 71 | 72 | def chunks_from_audio_segment(audio: str, chunk_size=1, raw=False): 73 | """ 74 | Return wav chunks of given size (in seconds) from the audio segment. 75 | """ 76 | if audio.duration_seconds <= chunk_size: 77 | if raw: 78 | return [audio.raw_data] 79 | else: 80 | audio_stream = io.BytesIO() 81 | audio.export(audio_stream, format="wav") 82 | return [audio_stream.getvalue()] 83 | 84 | chunks = [] 85 | for i in range(0, len(audio), int(chunk_size * 1000)): 86 | chunk = audio[i: i + chunk_size * 1000] 87 | if raw: 88 | chunks.append(chunk.raw_data) 89 | else: 90 | chunk_stream = io.BytesIO() 91 | chunk.export(chunk_stream, format="wav") 92 | chunks.append(chunk_stream.getvalue()) 93 | 94 | return chunks 95 | 96 | def byte_stream_from_file(filename: str, sample_rate=8000, raw: bool=False): 97 | audio = AudioSegment.from_file(filename, format="wav", 98 | frame_rate=sample_rate, 99 | channels=1, sample_width=2) 100 | 101 | if raw: 102 | return audio.raw_data 103 | 104 | byte_stream = io.BytesIO() 105 | audio.export(byte_stream, format="wav") 106 | return byte_stream.getvalue() -------------------------------------------------------------------------------- /plugins/grpc/client/pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "kaldi_serve" 3 | version = "0.3.0" 4 | description = "Python bindings for kaldi streaming ASR" 5 | authors = [] 6 | 7 | [tool.poetry.dependencies] 8 | python = "^3.6" 9 | pydub = "^0.23.1" 10 | grpcio = "^1.22" 11 | grpcio-tools = "^1.22" 12 | docopt = "^0.6.2" 13 | pyaudio = "^0.2.11" 14 | 15 | [tool.poetry.dev-dependencies] 16 | pytest = "^5.0" 17 | pyyaml = "^5.1" 18 | tqdm = "^4.48.2" 19 | 20 | [build-system] 21 | requires = ["poetry>=0.12"] 22 | build-backend = "poetry.masonry.api" 23 | -------------------------------------------------------------------------------- /plugins/grpc/client/scripts/batch_decode.py: -------------------------------------------------------------------------------- 1 | """ 2 | Script for transcribing audios in batches using Kaldi-Serve ASR server. 3 | 4 | Usage: 5 | batch_decode.py [--model=] [--lang=] [--sample-rate=] [--max-alternatives=] [--num-proc=] [--output-json=] [--raw] [--transcripts-only] 6 | 7 | Options: 8 | --model= Name of the model to hit. [default: general] 9 | --lang= Language code of the model. [default: en] 10 | --sample-rate= Sampling rate to use for the audio. [default: 8000] 11 | --max-alternatives= Number of maximum alternatives to query from the server. [default: 10] 12 | --num-proc= Number of parallel processes. [default: 8] 13 | --output-json= Output json file path for decoded transcriptions. [default: transcripts.json] 14 | --raw Flag that specifies whether to stream raw audio bytes to server. 15 | --transcripts-only Flag that specifies whether or now to keep decoder metadata for transcripts. 
16 | """ 17 | 18 | import json 19 | import random 20 | import traceback 21 | 22 | from typing import List 23 | from docopt import docopt 24 | from tqdm import tqdm 25 | from multiprocessing import Pool 26 | from concurrent.futures import ThreadPoolExecutor 27 | 28 | from kaldi_serve import KaldiServeClient, RecognitionAudio, RecognitionConfig 29 | from kaldi_serve.utils import byte_stream_from_file 30 | 31 | ENCODING = RecognitionConfig.AudioEncoding.LINEAR16 32 | 33 | client = KaldiServeClient() 34 | 35 | def run_multiprocessing(func, tasks, num_processes=None): 36 | with Pool(processes=num_processes) as pool: 37 | results = list(tqdm(pool.imap(func, tasks), total=len(tasks))) 38 | return results 39 | 40 | def run_multithreading(func, tasks, num_workers=None): 41 | with ThreadPoolExecutor(max_workers=num_workers) as executor: 42 | results = list(tqdm(executor.map(func, tasks), total=len(tasks))) 43 | return results 44 | 45 | def parse_response(response): 46 | output = [] 47 | 48 | for res in response.results: 49 | output.append([ 50 | { 51 | "transcript": alt.transcript, 52 | "confidence": alt.confidence, 53 | "am_score": alt.am_score, 54 | "lm_score": alt.lm_score 55 | } 56 | for alt in res.alternatives 57 | ]) 58 | return output 59 | 60 | 61 | def transcribe_audio(audio_stream, model: str, language_code: str, sample_rate=8000, max_alternatives=10, raw: bool=False): 62 | """ 63 | Transcribe the given audio chunks 64 | """ 65 | global client 66 | 67 | try: 68 | audio = RecognitionAudio(content=audio_stream) 69 | 70 | config = RecognitionConfig( 71 | sample_rate_hertz=sample_rate, 72 | encoding=ENCODING, 73 | language_code=language_code, 74 | max_alternatives=max_alternatives, 75 | model=model, 76 | raw=raw, 77 | data_bytes=len(audio_stream) 78 | ) 79 | 80 | response = client.recognize(config, audio, uuid=str(random.randint(1000, 100000)), timeout=1000) 81 | except Exception as e: 82 | print(f"error: {str(e)}") 83 | return [] 84 | 85 | return parse_response(response) 86 | 87 | def stream_and_transcribe(audio_path: str, model: str, language_code: str, sample_rate=8000, max_alternatives=10, raw: bool=False): 88 | try: 89 | audio_stream = byte_stream_from_file(audio_path, sample_rate, raw) 90 | result = transcribe_audio(audio_stream, model, language_code, sample_rate, max_alternatives, raw) 91 | return result 92 | except Exception as e: 93 | print('Error while handling {}'.format(audio_path)) 94 | print(e) 95 | return None 96 | 97 | def stream_and_transcribe_wrapper(args): 98 | return stream_and_transcribe(*args) 99 | 100 | def decode_files(audio_paths: List[str], model: str, language_code: str, 101 | sample_rate=8000, max_alternatives=10, raw: bool=False, 102 | num_proc: int=8): 103 | """ 104 | Decode files using parallel requests 105 | """ 106 | args = [ 107 | (path, model, language_code, sample_rate, max_alternatives, raw) 108 | for path in audio_paths 109 | ] 110 | 111 | results = run_multithreading(stream_and_transcribe_wrapper, args) 112 | 113 | results_dict = {path: response for path, response in list(zip(audio_paths, results)) if response is not None} 114 | return results_dict 115 | 116 | 117 | if __name__ == "__main__": 118 | args = docopt(__doc__) 119 | 120 | # args 121 | model = args["--model"] 122 | language_code = args["--lang"] 123 | sample_rate = int(args["--sample-rate"]) 124 | max_alternatives = int(args["--max-alternatives"]) 125 | raw = args["--raw"] 126 | 127 | num_proc = int(args["--num-proc"]) 128 | output_json = args["--output-json"] 129 | 130 | transcripts_only = 
args["--transcripts-only"] 131 | 132 | audio_paths_file = args[""] 133 | with open(audio_paths_file, "r", encoding="utf-8") as f: 134 | audio_paths = f.read().split("\n") 135 | 136 | audio_paths = list(filter(lambda x: x.endswith(".wav"), audio_paths)) 137 | results_dict = decode_files(audio_paths, model, language_code, sample_rate, max_alternatives, num_proc=num_proc, raw=raw) 138 | 139 | if transcripts_only: 140 | for audio_file, transcripts in results_dict.items(): 141 | transcripts = [[alt["transcript"] if isinstance(alt, dict) else alt[0]["transcript"] for alt in segment] for segment in transcripts] 142 | results_dict[audio_file] = transcripts 143 | 144 | with open(output_json, "w", encoding="utf-8") as f: 145 | f.write(json.dumps(results_dict)) 146 | 147 | -------------------------------------------------------------------------------- /plugins/grpc/client/scripts/example_client.py: -------------------------------------------------------------------------------- 1 | """ 2 | Script for testing out ASR server. 3 | 4 | Usage: 5 | example_client.py mic [--n-secs=] [--model=] [--lang=] [--sample-rate=] [--max-alternatives=] [--stream] [--raw] [--pcm] [--word-level] 6 | example_client.py ... [--model=] [--lang=] [--sample-rate=] [--max-alternatives=] [--stream] [--raw] [--pcm] [--word-level] 7 | 8 | Options: 9 | --n-secs= Number of seconds to record the audio for before making a request. [default: 5] 10 | --model= Name of the model to hit. [default: general] 11 | --lang= Language code of the model. [default: en] 12 | --sample-rate= Sampling rate to use for the audio. [default: 8000] 13 | --max-alternatives= Number of maximum alternatives to query from the server. [default: 10] 14 | --raw Flag that specifies whether to stream raw audio bytes to server. 15 | --pcm Flag that specifies whether to send raw pcm bytes. 16 | --word-level Flag to enable word level features from server. 
17 | """ 18 | import time 19 | import random 20 | import threading 21 | import traceback 22 | from pprint import pprint 23 | from typing import List 24 | 25 | from docopt import docopt 26 | from pydub import AudioSegment 27 | 28 | from kaldi_serve import KaldiServeClient, RecognitionAudio, RecognitionConfig 29 | from kaldi_serve.utils import ( 30 | chunks_from_file, 31 | chunks_from_mic, 32 | raw_bytes_to_wav 33 | ) 34 | 35 | ENCODING = RecognitionConfig.AudioEncoding.LINEAR16 36 | 37 | 38 | def parse_response(response): 39 | output = [] 40 | 41 | for res in response.results: 42 | output.append([ 43 | { 44 | "transcript": alt.transcript, 45 | "confidence": alt.confidence, 46 | "am_score": alt.am_score, 47 | "lm_score": alt.lm_score, 48 | "words": [ 49 | { 50 | "start_time": word.start_time, 51 | "end_time": word.end_time, 52 | "word": word.word, 53 | "confidence": word.confidence 54 | } 55 | for word in alt.words 56 | ] 57 | } 58 | for alt in res.alternatives 59 | ]) 60 | return output 61 | 62 | 63 | def transcribe_chunks_streaming(client, audio_chunks, model: str, language_code: str, 64 | sample_rate=8000, max_alternatives=10, raw: bool=False, 65 | word_level: bool=False, chunk_size: float=0.5): 66 | """ 67 | Transcribe the given audio chunks 68 | """ 69 | 70 | response = {} 71 | 72 | try: 73 | if raw: 74 | config = lambda chunk_len: RecognitionConfig( 75 | sample_rate_hertz=sample_rate, 76 | encoding=ENCODING, 77 | language_code=language_code, 78 | max_alternatives=max_alternatives, 79 | model=model, 80 | raw=True, 81 | word_level=word_level, 82 | data_bytes=chunk_len 83 | ) 84 | 85 | start = [None] 86 | def audio_params_gen(audio_chunks, start): 87 | for chunk in audio_chunks[:-1]: 88 | yield config(len(chunk)), RecognitionAudio(content=chunk) 89 | time.sleep(chunk_size) 90 | start[0] = time.time() 91 | yield config(len(audio_chunks[-1])), RecognitionAudio(content=audio_chunks[-1]) 92 | 93 | response = client.streaming_recognize_raw(audio_params_gen(audio_chunks, start), uuid=str(random.randint(1000, 100000))) 94 | end = time.time() 95 | print(f"{((end - start[0])*1000):.2f}ms") 96 | else: 97 | audio = (RecognitionAudio(content=chunk) for chunk in audio_chunks) 98 | config = RecognitionConfig( 99 | sample_rate_hertz=sample_rate, 100 | encoding=ENCODING, 101 | language_code=language_code, 102 | max_alternatives=max_alternatives, 103 | model=model, 104 | word_level=word_level 105 | ) 106 | response = client.streaming_recognize(config, audio, uuid=str(random.randint(1000, 100000))) 107 | except Exception as e: 108 | traceback.print_exc() 109 | print(f'error: {str(e)}') 110 | 111 | pprint(parse_response(response)) 112 | 113 | def transcribe_chunks_bidi_streaming(client, audio_chunks, model: str, language_code: str, 114 | sample_rate=8000, max_alternatives=10, raw: bool=False, 115 | word_level: bool=False): 116 | """ 117 | Transcribe the given audio chunks 118 | """ 119 | response = {} 120 | 121 | try: 122 | if raw: 123 | config = lambda chunk_len: RecognitionConfig( 124 | sample_rate_hertz=sample_rate, 125 | encoding=ENCODING, 126 | language_code=language_code, 127 | max_alternatives=max_alternatives, 128 | model=model, 129 | raw=True, 130 | data_bytes=chunk_len, 131 | word_level=word_level, 132 | ) 133 | 134 | def audio_params_gen(audio_chunks): 135 | for chunk in audio_chunks: 136 | yield config(len(chunk)), RecognitionAudio(content=chunk) 137 | 138 | response_gen = client.bidi_streaming_recognize_raw(audio_params_gen(audio_chunks), uuid=str(random.randint(1000, 100000))) 139 | else: 140 
| config = RecognitionConfig( 141 | sample_rate_hertz=sample_rate, 142 | encoding=ENCODING, 143 | language_code=language_code, 144 | max_alternatives=max_alternatives, 145 | model=model, 146 | word_level=word_level 147 | ) 148 | 149 | def audio_chunks_gen(audio_chunks): 150 | for chunk in audio_chunks: 151 | yield RecognitionAudio(content=chunk) 152 | 153 | response_gen = client.bidi_streaming_recognize(config, audio_chunks_gen(audio_chunks), uuid=str(random.randint(1000, 100000))) 154 | except Exception as e: 155 | traceback.print_exc() 156 | print(f'error: {str(e)}') 157 | 158 | for response in response_gen: 159 | pprint(parse_response(response)) 160 | 161 | 162 | def decode_files(client, audio_paths: List[str], model: str, language_code: str, 163 | sample_rate=8000, max_alternatives=10, raw: bool=False, 164 | pcm: bool=False, word_level: bool=False, chunk_size: float=0.5): 165 | """ 166 | Decode files using threaded requests 167 | """ 168 | chunked_audios = [chunks_from_file(x, sample_rate=sample_rate, chunk_size=chunk_size, raw=raw, pcm=pcm) for x in audio_paths] 169 | 170 | threads = [ 171 | threading.Thread( 172 | target=transcribe_chunks_streaming, 173 | args=(client, chunks, model, language_code, 174 | sample_rate, max_alternatives, raw, 175 | word_level, chunk_size) 176 | ) 177 | for chunks in chunked_audios 178 | ] 179 | 180 | for thread in threads: 181 | thread.start() 182 | 183 | for thread in threads: 184 | thread.join() 185 | 186 | 187 | if __name__ == "__main__": 188 | args = docopt(__doc__) 189 | client = KaldiServeClient() 190 | 191 | # args 192 | model = args["--model"] 193 | language_code = args["--lang"] 194 | sample_rate = int(args["--sample-rate"]) 195 | max_alternatives = int(args["--max-alternatives"]) 196 | 197 | # flags 198 | raw = args['--raw'] 199 | pcm = args['--pcm'] 200 | word_level = args["--word-level"] 201 | 202 | if args["mic"]: 203 | transcribe_chunks_bidi_streaming(client, chunks_from_mic(int(args["--n-secs"]), sample_rate, 1), 204 | model, language_code, sample_rate, max_alternatives, 205 | raw or pcm, word_level) 206 | else: 207 | decode_files(client, args[""], model, language_code, 208 | sample_rate, max_alternatives, raw, pcm, word_level) 209 | -------------------------------------------------------------------------------- /plugins/grpc/client/scripts/list_models.py: -------------------------------------------------------------------------------- 1 | """Script for listing all active/loaded models on the server.""" 2 | import traceback 3 | 4 | from pprint import pprint 5 | from kaldi_serve import KaldiServeClient 6 | 7 | 8 | def list_models(client): 9 | try: 10 | response = client.list_models() 11 | except Exception as e: 12 | traceback.print_exc() 13 | print(f'error: {str(e)}') 14 | 15 | models = list(map(lambda model: {"name": model.name, "language": model.language_code}, response.models)) 16 | pprint(models) 17 | 18 | 19 | if __name__ == "__main__": 20 | client = KaldiServeClient() 21 | list_models(client) -------------------------------------------------------------------------------- /plugins/grpc/client/tests/conftest.py: -------------------------------------------------------------------------------- 1 | import threading 2 | import time 3 | 4 | import pytest 5 | import yaml 6 | 7 | from kaldi_serve import KaldiServeClient, RecognitionAudio, RecognitionConfig 8 | from kaldi_serve.utils import chunks_from_file 9 | 10 | 11 | def read_items(file_path: str): 12 | """ 13 | Read transcription specs for testing 14 | """ 15 | 16 | with open(file_path) 
as fp: 17 | return yaml.safe_load(fp) 18 | 19 | 20 | def pytest_collect_file(parent, path): 21 | if path.ext == ".yaml" and path.basename.startswith("test"): 22 | return TranscriptionSpecFile(path, parent) 23 | 24 | 25 | class TranscriptionSpecFile(pytest.File): 26 | def collect(self): 27 | client = KaldiServeClient() 28 | for i, item in enumerate(read_items(self.fspath)): 29 | yield TranscriptionItem(f"item-{i}", self, item, client) 30 | 31 | 32 | def dreamer(source_gen, sleep_time: int): 33 | for item in source_gen: 34 | yield item 35 | time.sleep(sleep_time) 36 | 37 | 38 | class TranscriptionItem(pytest.Item): 39 | """ 40 | Each item lists the files to read and send to the server in parallel, 41 | along with the expected transcription for each. 42 | """ 43 | 44 | def __init__(self, name, parent, item, client): 45 | super().__init__(name, parent) 46 | self.audios = [ 47 | (chunks_from_file(audio_spec["file"]), audio_spec["transcription"]) 48 | for audio_spec in item 49 | ] 50 | self.client = client 51 | self.results = [None for _ in item] 52 | 53 | def decode_audio(self, index: int): 54 | # NOTE: These are only assumptions for now so test failures might not 55 | # necessarily mean an error in the model/server. 56 | config = RecognitionConfig( 57 | sample_rate_hertz=8000, 58 | encoding=RecognitionConfig.AudioEncoding.LINEAR16, 59 | language_code="hi", 60 | max_alternatives=10, 61 | model="general" 62 | ) 63 | 64 | audio = dreamer((RecognitionAudio(content=chunk) for chunk in self.audios[index][0]), 1) 65 | self.results[index] = self.client.streaming_recognize(config, audio, uuid="") 66 | 67 | def runtest(self): 68 | threads = [] 69 | for i in range(len(self.audios)): 70 | threads.append(threading.Thread(target=self.decode_audio, args=(i, ))) 71 | 72 | for thread in threads: 73 | thread.start() 74 | 75 | for i, thread in enumerate(threads): 76 | thread.join() 77 | assert self.results[i].results[0].alternatives[0].transcript == self.audios[i][1] 78 | 79 | def reportinfo(self): 80 | return self.fspath, 0, self.name 81 | -------------------------------------------------------------------------------- /plugins/grpc/client/tests/resources/hi/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /plugins/grpc/client/tests/test_hi.yaml: -------------------------------------------------------------------------------- 1 | # Basic correctness tests for our Hindi model (and the server). Note that you 2 | # will need both the model and the audio files listed here. Since this might not 3 | # be the case in all environments (also, models might just not be that stable), we 4 | # only keep this mechanism in place and suggest that users test with their own 5 | # data files and models.
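# Format: each top-level list item below is one test case; every file/transcription
# pair within a case is decoded in parallel, and the top alternative returned by the
# server must match `transcription` exactly (see conftest.py). The wav files
# themselves are gitignored, so supply your own under ./tests/resources/hi/.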
6 | 7 | - - file: ./tests/resources/hi/one_two_three_four.wav 8 | transcription: "एक दो तीन चार" 9 | 10 | - - file: ./tests/resources/hi/five_six_seven_eight.wav 11 | transcription: "पांच छह सात आठ" 12 | 13 | - - file: ./tests/resources/hi/nine_ten_eleven_twelve.wav 14 | transcription: "नौ दस ग्यारह बारह" 15 | 16 | # Multiple items are thrown in parallel to the server 17 | - - file: ./tests/resources/hi/one_two_three_four.wav 18 | transcription: "एक दो तीन चार" 19 | - file: ./tests/resources/hi/five_six_seven_eight.wav 20 | transcription: "पांच छह सात आठ" 21 | - file: ./tests/resources/hi/nine_ten_eleven_twelve.wav 22 | transcription: "नौ दस ग्यारह बारह" 23 | -------------------------------------------------------------------------------- /plugins/grpc/examples/aspire/README.md: -------------------------------------------------------------------------------- 1 | ## ASPIRE Chain Model example 2 | 3 | This is a basic example, built on the Aspire recipe, that shows how to serve a [Kaldi](https://github.com/kaldi-asr/kaldi/) ASR model via [kaldi-serve](https://github.com/Vernacular-ai/kaldi-serve). If you have any queries regarding the technical details of the recipe, please check [here](https://github.com/kaldi-asr/kaldi/tree/master/egs/aspire). 4 | 5 | ### Setup 6 | 7 | The setup script will download the chain model (if needed) and format the files as per our requirements: 8 | 9 | ```bash 10 | ./utils/setup_aspire_chain_model.sh --kaldi-root [KALDI ROOT] 11 | ``` 12 | 13 | ### Serving the model 14 | 15 | You can also run the following script directly; it will call the setup script anyway to validate the model directory structure: 16 | 17 | ```bash 18 | ./run_server.sh --kaldi-root [KALDI ROOT] 19 | ``` 20 | -------------------------------------------------------------------------------- /plugins/grpc/examples/aspire/model/conf/ivector_extractor.conf: -------------------------------------------------------------------------------- 1 | --splice-config=conf/splice.conf 2 | --cmvn-config=conf/online_cmvn.conf 3 | --lda-matrix=ivector_extractor/final.mat 4 | --global-cmvn-stats=ivector_extractor/global_cmvn.stats 5 | --diag-ubm=ivector_extractor/final.dubm 6 | --ivector-extractor=ivector_extractor/final.ie 7 | --num-gselect=5 8 | --min-post=0.025 9 | --posterior-scale=0.1 10 | --max-remembered-frames=1000 11 | --max-count=100 12 | -------------------------------------------------------------------------------- /plugins/grpc/examples/aspire/model/conf/mfcc.conf: -------------------------------------------------------------------------------- 1 | # config for high-resolution MFCC features, intended for neural network training. 2 | # Note: we keep all cepstra, so it has the same info as filterbank features, 3 | # but MFCC is more easily compressible (because less correlated) which is why 4 | # we prefer this method. 5 | --use-energy=false # use average of log energy, not energy. 6 | --sample-frequency=8000 # Switchboard is sampled at 8kHz 7 | --num-mel-bins=40 # similar to Google's setup. 8 | --num-ceps=40 # there is no dimensionality reduction.
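# The next two options restrict the mel filterbank to roughly 40 Hz - 3800 Hz;
# a negative --high-freq is interpreted as an offset below the Nyquist frequency.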
9 | --low-freq=40 # low cutoff frequency for mel bins 10 | --high-freq=-200 # high cutoff frequency, relative to Nyquist of 4000 (=3800) 11 | -------------------------------------------------------------------------------- /plugins/grpc/examples/aspire/model/conf/online.conf: -------------------------------------------------------------------------------- 1 | --feature-type=mfcc 2 | --mfcc-config=conf/mfcc.conf 3 | --ivector-extraction-config=conf/ivector_extractor.conf 4 | --endpoint.silence-phones=1:2:3:4:5:6:7:8:9:10:11:12:13:14:15:16:17:18:19:20 5 | -------------------------------------------------------------------------------- /plugins/grpc/examples/aspire/model/conf/online_cmvn.conf: -------------------------------------------------------------------------------- 1 | # configuration file for apply-cmvn-online, used in the script ../local/run_online_decoding.sh 2 | -------------------------------------------------------------------------------- /plugins/grpc/examples/aspire/model/conf/splice.conf: -------------------------------------------------------------------------------- 1 | --left-context=3 2 | --right-context=3 3 | -------------------------------------------------------------------------------- /plugins/grpc/examples/aspire/model/ivector_extractor/final.dubm: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/skit-ai/kaldi-serve/05b17f663cd4c860621fcf8c9d904f16f4ebc900/plugins/grpc/examples/aspire/model/ivector_extractor/final.dubm -------------------------------------------------------------------------------- /plugins/grpc/examples/aspire/model/ivector_extractor/final.ie: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/skit-ai/kaldi-serve/05b17f663cd4c860621fcf8c9d904f16f4ebc900/plugins/grpc/examples/aspire/model/ivector_extractor/final.ie -------------------------------------------------------------------------------- /plugins/grpc/examples/aspire/model/ivector_extractor/final.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/skit-ai/kaldi-serve/05b17f663cd4c860621fcf8c9d904f16f4ebc900/plugins/grpc/examples/aspire/model/ivector_extractor/final.mat -------------------------------------------------------------------------------- /plugins/grpc/examples/aspire/model/ivector_extractor/global_cmvn.stats: -------------------------------------------------------------------------------- 1 | [ 2 | 8.703383e+10 -1.636128e+10 -1.782775e+10 -2.403552e+10 -2.827338e+10 -1.734756e+10 -1.732724e+10 -9.909365e+09 -1.759744e+10 -6.15036e+09 -1.115416e+10 -8.634972e+09 -9.717826e+09 -5.072965e+09 -5.649734e+09 -3.287935e+09 -3.583767e+09 -1.148717e+09 -1.336033e+09 1.762154e+08 -4.473571e+08 1.471545e+07 -4.178443e+07 -3.075354e+07 1.322115e+08 1.643474e+08 4.94922e+08 4.846454e+08 7.162844e+08 4.813698e+08 1.112786e+09 6.139736e+08 8.726894e+08 6.062539e+08 6.838717e+08 1.580619e+08 -1.913887e+08 -8.36281e+08 -5.430025e+08 -1.006219e+08 8.981191e+08 3 | 8.748708e+12 5.02378e+11 7.31285e+11 9.755332e+11 1.431399e+12 7.046708e+11 7.21216e+11 4.731844e+11 7.282942e+11 3.986182e+11 4.728768e+11 3.803825e+11 3.896125e+11 2.800419e+11 2.521889e+11 1.928312e+11 1.571531e+11 1.069645e+11 7.009651e+10 3.646358e+10 1.776496e+10 6.182638e+09 8.20155e+08 2.116207e+08 2.448784e+09 6.086773e+09 1.178776e+10 1.702884e+10 1.98357e+10 2.134787e+10 2.632029e+10 2.629598e+10 2.520875e+10 2.205661e+10
2.286158e+10 1.919904e+10 1.510108e+10 1.335139e+10 9.674904e+09 7.123607e+09 0 ] 4 | -------------------------------------------------------------------------------- /plugins/grpc/examples/aspire/model/ivector_extractor/online_cmvn.conf: -------------------------------------------------------------------------------- 1 | # configuration file for apply-cmvn-online, used in the script ../local/run_online_decoding.sh 2 | -------------------------------------------------------------------------------- /plugins/grpc/examples/aspire/model/ivector_extractor/splice_opts: -------------------------------------------------------------------------------- 1 | --left-context=3 --right-context=3 2 | -------------------------------------------------------------------------------- /plugins/grpc/examples/aspire/model/word_boundary.int: -------------------------------------------------------------------------------- 1 | 1 nonword 2 | 2 begin 3 | 3 end 4 | 4 internal 5 | 5 singleton 6 | 6 nonword 7 | 7 begin 8 | 8 end 9 | 9 internal 10 | 10 singleton 11 | 11 nonword 12 | 12 begin 13 | 13 end 14 | 14 internal 15 | 15 singleton 16 | 16 nonword 17 | 17 begin 18 | 18 end 19 | 19 internal 20 | 20 singleton 21 | 21 begin 22 | 22 end 23 | 23 internal 24 | 24 singleton 25 | 25 begin 26 | 26 end 27 | 27 internal 28 | 28 singleton 29 | 29 begin 30 | 30 end 31 | 31 internal 32 | 32 singleton 33 | 33 begin 34 | 34 end 35 | 35 internal 36 | 36 singleton 37 | 37 begin 38 | 38 end 39 | 39 internal 40 | 40 singleton 41 | 41 begin 42 | 42 end 43 | 43 internal 44 | 44 singleton 45 | 45 begin 46 | 46 end 47 | 47 internal 48 | 48 singleton 49 | 49 begin 50 | 50 end 51 | 51 internal 52 | 52 singleton 53 | 53 begin 54 | 54 end 55 | 55 internal 56 | 56 singleton 57 | 57 begin 58 | 58 end 59 | 59 internal 60 | 60 singleton 61 | 61 begin 62 | 62 end 63 | 63 internal 64 | 64 singleton 65 | 65 begin 66 | 66 end 67 | 67 internal 68 | 68 singleton 69 | 69 begin 70 | 70 end 71 | 71 internal 72 | 72 singleton 73 | 73 begin 74 | 74 end 75 | 75 internal 76 | 76 singleton 77 | 77 begin 78 | 78 end 79 | 79 internal 80 | 80 singleton 81 | 81 begin 82 | 82 end 83 | 83 internal 84 | 84 singleton 85 | 85 begin 86 | 86 end 87 | 87 internal 88 | 88 singleton 89 | 89 begin 90 | 90 end 91 | 91 internal 92 | 92 singleton 93 | 93 begin 94 | 94 end 95 | 95 internal 96 | 96 singleton 97 | 97 begin 98 | 98 end 99 | 99 internal 100 | 100 singleton 101 | 101 begin 102 | 102 end 103 | 103 internal 104 | 104 singleton 105 | 105 begin 106 | 106 end 107 | 107 internal 108 | 108 singleton 109 | 109 begin 110 | 110 end 111 | 111 internal 112 | 112 singleton 113 | 113 begin 114 | 114 end 115 | 115 internal 116 | 116 singleton 117 | 117 begin 118 | 118 end 119 | 119 internal 120 | 120 singleton 121 | 121 begin 122 | 122 end 123 | 123 internal 124 | 124 singleton 125 | 125 begin 126 | 126 end 127 | 127 internal 128 | 128 singleton 129 | 129 begin 130 | 130 end 131 | 131 internal 132 | 132 singleton 133 | 133 begin 134 | 134 end 135 | 135 internal 136 | 136 singleton 137 | 137 begin 138 | 138 end 139 | 139 internal 140 | 140 singleton 141 | 141 begin 142 | 142 end 143 | 143 internal 144 | 144 singleton 145 | 145 begin 146 | 146 end 147 | 147 internal 148 | 148 singleton 149 | 149 begin 150 | 150 end 151 | 151 internal 152 | 152 singleton 153 | 153 begin 154 | 154 end 155 | 155 internal 156 | 156 singleton 157 | 157 begin 158 | 158 end 159 | 159 internal 160 | 160 singleton 161 | 161 begin 162 | 162 end 163 | 163 internal 164 | 164 singleton 165 | 165 begin 166 | 166 
end 167 | 167 internal 168 | 168 singleton 169 | 169 begin 170 | 170 end 171 | 171 internal 172 | 172 singleton 173 | 173 begin 174 | 174 end 175 | 175 internal 176 | 176 singleton 177 | -------------------------------------------------------------------------------- /plugins/grpc/examples/aspire/model_spec.toml: -------------------------------------------------------------------------------- 1 | [[model]] 2 | name = "aspire" 3 | language_code = "en" 4 | path = "examples/aspire/model" 5 | n_decoders = 1 6 | beam = 15.0 7 | max_active = 7000 8 | lattice_beam = 6.0 9 | acoustic_scale = 1.0 10 | frame_subsampling_factor = 3 11 | -------------------------------------------------------------------------------- /plugins/grpc/examples/aspire/run_server.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cxx="g++" # g++ needs to be >= 8.0 4 | kaldi_root="/opt/kaldi" 5 | 6 | # Safety mechanism (possible running this script with modified arguments) 7 | . utils/parse_options.sh || exit 1 8 | [[ $# -ge 1 ]] && { 9 | echo "Wrong arguments!" 10 | exit 1 11 | } 12 | 13 | . utils/setup_aspire_chain_model.sh --kaldi-root $kaldi_root || exit 1; 14 | cd ../../ 15 | 16 | if ! [ -x build/kaldi_serve_app ]; then 17 | make -j KALDI_ROOT=$kaldi_root CXX=$cxx || exit 1; 18 | fi 19 | 20 | ./build/kaldi_serve_app examples/aspire/model_spec.toml 21 | -------------------------------------------------------------------------------- /plugins/grpc/examples/aspire/utils/parse_options.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Copyright 2012 Johns Hopkins University (Author: Daniel Povey); 4 | # Arnab Ghoshal, Karel Vesely 5 | 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 13 | # KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED 14 | # WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, 15 | # MERCHANTABLITY OR NON-INFRINGEMENT. 16 | # See the Apache 2 License for the specific language governing permissions and 17 | # limitations under the License. 18 | 19 | 20 | # Parse command-line options. 21 | # To be sourced by another script (as in ". parse_options.sh"). 22 | # Option format is: --option-name arg 23 | # and shell variable "option_name" gets set to value "arg." 24 | # The exception is --help, which takes no arguments, but prints the 25 | # $help_message variable (if defined). 26 | 27 | 28 | ### 29 | ### The --config file options have lower priority to command line 30 | ### options, so we need to import them first... 31 | ### 32 | 33 | # Now import all the configs specified by command-line, in left-to-right order 34 | for ((argpos=1; argpos<$#; argpos++)); do 35 | if [ "${!argpos}" == "--config" ]; then 36 | argpos_plus1=$((argpos+1)) 37 | config=${!argpos_plus1} 38 | [ ! -r $config ] && echo "$0: missing config '$config'" && exit 1 39 | . $config # source the config file. 40 | fi 41 | done 42 | 43 | 44 | ### 45 | ### Now we process the command line options 46 | ### 47 | while true; do 48 | [ -z "${1:-}" ] && break; # break if there are no arguments 49 | case "$1" in 50 | # If the enclosing script is called with --help option, print the help 51 | # message and exit. 
Scripts should put help messages in $help_message 52 | --help|-h) if [ -z "$help_message" ]; then echo "No help found." 1>&2; 53 | else printf "$help_message\n" 1>&2 ; fi; 54 | exit 0 ;; 55 | --*=*) echo "$0: options to scripts must be of the form --name value, got '$1'" 56 | exit 1 ;; 57 | # If the first command-line argument begins with "--" (e.g. --foo-bar), 58 | # then work out the variable name as $name, which will equal "foo_bar". 59 | --*) name=`echo "$1" | sed s/^--// | sed s/-/_/g`; 60 | # Next we test whether the variable in question is undefned-- if so it's 61 | # an invalid option and we die. Note: $0 evaluates to the name of the 62 | # enclosing script. 63 | # The test [ -z ${foo_bar+xxx} ] will return true if the variable foo_bar 64 | # is undefined. We then have to wrap this test inside "eval" because 65 | # foo_bar is itself inside a variable ($name). 66 | eval '[ -z "${'$name'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1; 67 | 68 | oldval="`eval echo \\$$name`"; 69 | # Work out whether we seem to be expecting a Boolean argument. 70 | if [ "$oldval" == "true" ] || [ "$oldval" == "false" ]; then 71 | was_bool=true; 72 | else 73 | was_bool=false; 74 | fi 75 | 76 | # Set the variable to the right value-- the escaped quotes make it work if 77 | # the option had spaces, like --cmd "queue.pl -sync y" 78 | eval $name=\"$2\"; 79 | 80 | # Check that Boolean-valued arguments are really Boolean. 81 | if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then 82 | echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2 83 | exit 1; 84 | fi 85 | shift 2; 86 | ;; 87 | *) break; 88 | esac 89 | done 90 | 91 | 92 | # Check for an empty argument to the --cmd option, which can easily occur as a 93 | # result of scripting errors. 94 | [ ! -z "${cmd+xxx}" ] && [ -z "$cmd" ] && echo "$0: empty argument to --cmd option" 1>&2 && exit 1; 95 | 96 | 97 | true; # so this script returns exit code 0. 98 | -------------------------------------------------------------------------------- /plugins/grpc/examples/aspire/utils/setup_aspire_chain_model.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | kaldi_root="/opt/kaldi" 4 | 5 | # Safety mechanism (possible running this script with modified arguments) 6 | . utils/parse_options.sh || exit 1 7 | [[ $# -ge 1 ]] && { 8 | echo "Wrong arguments!" 9 | exit 1 10 | } 11 | 12 | aspire_model_path="exp/tdnn_7b_chain_online" 13 | example_root=$(pwd) 14 | 15 | escape_path () { 16 | echo $(echo $1 | sed -e 's/\.//g' | sed -e 's/\/\//\//g' | sed -e 's/\//\\\//g') 17 | } 18 | 19 | if ! [ -d model/ ]; then 20 | 21 | cd $kaldi_root/egs/aspire/s5 22 | 23 | if ! [ -f aspire_chain_model.tar.bz2 ]; then 24 | echo "downloading ASPIRE Chain Model..." 25 | wget https://kaldi-asr.org/models/1/0001_aspire_chain_model_with_hclg.tar.bz2 -O aspire_chain_model.tar.bz2 26 | fi 27 | 28 | if ! [ -d exp/ ]; then 29 | tar -xvf aspire_chain_model.tar.bz2 30 | fi 31 | 32 | if ! [ -f $aspire_model_path/conf/online.conf ]; then 33 | echo "generating files needed for online decoding" 34 | . steps/online/nnet3/prepare_online_decoding.sh \ 35 | --mfcc-config conf/mfcc_hires.conf data/lang_chain \ 36 | exp/nnet3/extractor exp/chain/tdnn_7b exp/tdnn_7b_chain_online 37 | fi 38 | 39 | if ! 
[ -d model/ ]; then 40 | echo "copying essential files from aspire recipe" 41 | mkdir model/ 42 | cp $aspire_model_path/final.mdl model/ 43 | cp $aspire_model_path/graph_pp/HCLG.fst model/ 44 | cp $aspire_model_path/graph_pp/words.txt model/ 45 | cp $aspire_model_path/graph_pp/phones/word_boundary.int model/ 46 | cp -r $aspire_model_path/conf model/conf 47 | cp -r $aspire_model_path/ivector_extractor model/ivector_extractor 48 | 49 | escaped_model_path="$(escape_path "$kaldi_root/egs/aspire/s5/$aspire_model_path/")" 50 | 51 | sed -i "s/$escaped_model_path//g" "model/conf/online.conf" 52 | sed -i "s/$escaped_model_path//g" "model/conf/ivector_extractor.conf" 53 | fi 54 | 55 | mv model $example_root/model 56 | 57 | cd $example_root 58 | fi 59 | -------------------------------------------------------------------------------- /plugins/grpc/protos/kaldi_serve.grpc.pb.cc: -------------------------------------------------------------------------------- 1 | // Generated by the gRPC C++ plugin. 2 | // If you make any local change, they will be lost. 3 | // source: kaldi_serve.proto 4 | 5 | #include "kaldi_serve.pb.h" 6 | #include "kaldi_serve.grpc.pb.h" 7 | 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | #include 18 | #include 19 | #include 20 | #include 21 | #include 22 | namespace kaldi_serve { 23 | 24 | static const char* KaldiServe_method_names[] = { 25 | "/kaldi_serve.KaldiServe/ListModels", 26 | "/kaldi_serve.KaldiServe/Recognize", 27 | "/kaldi_serve.KaldiServe/StreamingRecognize", 28 | "/kaldi_serve.KaldiServe/BidiStreamingRecognize", 29 | }; 30 | 31 | std::unique_ptr< KaldiServe::Stub> KaldiServe::NewStub(const std::shared_ptr< ::grpc::ChannelInterface>& channel, const ::grpc::StubOptions& options) { 32 | (void)options; 33 | std::unique_ptr< KaldiServe::Stub> stub(new KaldiServe::Stub(channel)); 34 | return stub; 35 | } 36 | 37 | KaldiServe::Stub::Stub(const std::shared_ptr< ::grpc::ChannelInterface>& channel) 38 | : channel_(channel), rpcmethod_ListModels_(KaldiServe_method_names[0], ::grpc::internal::RpcMethod::NORMAL_RPC, channel) 39 | , rpcmethod_Recognize_(KaldiServe_method_names[1], ::grpc::internal::RpcMethod::NORMAL_RPC, channel) 40 | , rpcmethod_StreamingRecognize_(KaldiServe_method_names[2], ::grpc::internal::RpcMethod::CLIENT_STREAMING, channel) 41 | , rpcmethod_BidiStreamingRecognize_(KaldiServe_method_names[3], ::grpc::internal::RpcMethod::BIDI_STREAMING, channel) 42 | {} 43 | 44 | ::grpc::Status KaldiServe::Stub::ListModels(::grpc::ClientContext* context, const ::google::protobuf::Empty& request, ::kaldi_serve::ModelList* response) { 45 | return ::grpc::internal::BlockingUnaryCall(channel_.get(), rpcmethod_ListModels_, context, request, response); 46 | } 47 | 48 | void KaldiServe::Stub::experimental_async::ListModels(::grpc::ClientContext* context, const ::google::protobuf::Empty* request, ::kaldi_serve::ModelList* response, std::function f) { 49 | ::grpc_impl::internal::CallbackUnaryCall(stub_->channel_.get(), stub_->rpcmethod_ListModels_, context, request, response, std::move(f)); 50 | } 51 | 52 | void KaldiServe::Stub::experimental_async::ListModels(::grpc::ClientContext* context, const ::grpc::ByteBuffer* request, ::kaldi_serve::ModelList* response, std::function f) { 53 | ::grpc_impl::internal::CallbackUnaryCall(stub_->channel_.get(), stub_->rpcmethod_ListModels_, context, request, response, std::move(f)); 54 | } 55 | 56 | void 
KaldiServe::Stub::experimental_async::ListModels(::grpc::ClientContext* context, const ::google::protobuf::Empty* request, ::kaldi_serve::ModelList* response, ::grpc::experimental::ClientUnaryReactor* reactor) { 57 | ::grpc_impl::internal::ClientCallbackUnaryFactory::Create(stub_->channel_.get(), stub_->rpcmethod_ListModels_, context, request, response, reactor); 58 | } 59 | 60 | void KaldiServe::Stub::experimental_async::ListModels(::grpc::ClientContext* context, const ::grpc::ByteBuffer* request, ::kaldi_serve::ModelList* response, ::grpc::experimental::ClientUnaryReactor* reactor) { 61 | ::grpc_impl::internal::ClientCallbackUnaryFactory::Create(stub_->channel_.get(), stub_->rpcmethod_ListModels_, context, request, response, reactor); 62 | } 63 | 64 | ::grpc::ClientAsyncResponseReader< ::kaldi_serve::ModelList>* KaldiServe::Stub::AsyncListModelsRaw(::grpc::ClientContext* context, const ::google::protobuf::Empty& request, ::grpc::CompletionQueue* cq) { 65 | return ::grpc_impl::internal::ClientAsyncResponseReaderFactory< ::kaldi_serve::ModelList>::Create(channel_.get(), cq, rpcmethod_ListModels_, context, request, true); 66 | } 67 | 68 | ::grpc::ClientAsyncResponseReader< ::kaldi_serve::ModelList>* KaldiServe::Stub::PrepareAsyncListModelsRaw(::grpc::ClientContext* context, const ::google::protobuf::Empty& request, ::grpc::CompletionQueue* cq) { 69 | return ::grpc_impl::internal::ClientAsyncResponseReaderFactory< ::kaldi_serve::ModelList>::Create(channel_.get(), cq, rpcmethod_ListModels_, context, request, false); 70 | } 71 | 72 | ::grpc::Status KaldiServe::Stub::Recognize(::grpc::ClientContext* context, const ::kaldi_serve::RecognizeRequest& request, ::kaldi_serve::RecognizeResponse* response) { 73 | return ::grpc::internal::BlockingUnaryCall(channel_.get(), rpcmethod_Recognize_, context, request, response); 74 | } 75 | 76 | void KaldiServe::Stub::experimental_async::Recognize(::grpc::ClientContext* context, const ::kaldi_serve::RecognizeRequest* request, ::kaldi_serve::RecognizeResponse* response, std::function f) { 77 | ::grpc_impl::internal::CallbackUnaryCall(stub_->channel_.get(), stub_->rpcmethod_Recognize_, context, request, response, std::move(f)); 78 | } 79 | 80 | void KaldiServe::Stub::experimental_async::Recognize(::grpc::ClientContext* context, const ::grpc::ByteBuffer* request, ::kaldi_serve::RecognizeResponse* response, std::function f) { 81 | ::grpc_impl::internal::CallbackUnaryCall(stub_->channel_.get(), stub_->rpcmethod_Recognize_, context, request, response, std::move(f)); 82 | } 83 | 84 | void KaldiServe::Stub::experimental_async::Recognize(::grpc::ClientContext* context, const ::kaldi_serve::RecognizeRequest* request, ::kaldi_serve::RecognizeResponse* response, ::grpc::experimental::ClientUnaryReactor* reactor) { 85 | ::grpc_impl::internal::ClientCallbackUnaryFactory::Create(stub_->channel_.get(), stub_->rpcmethod_Recognize_, context, request, response, reactor); 86 | } 87 | 88 | void KaldiServe::Stub::experimental_async::Recognize(::grpc::ClientContext* context, const ::grpc::ByteBuffer* request, ::kaldi_serve::RecognizeResponse* response, ::grpc::experimental::ClientUnaryReactor* reactor) { 89 | ::grpc_impl::internal::ClientCallbackUnaryFactory::Create(stub_->channel_.get(), stub_->rpcmethod_Recognize_, context, request, response, reactor); 90 | } 91 | 92 | ::grpc::ClientAsyncResponseReader< ::kaldi_serve::RecognizeResponse>* KaldiServe::Stub::AsyncRecognizeRaw(::grpc::ClientContext* context, const ::kaldi_serve::RecognizeRequest& request, ::grpc::CompletionQueue* cq) 
{ 93 | return ::grpc_impl::internal::ClientAsyncResponseReaderFactory< ::kaldi_serve::RecognizeResponse>::Create(channel_.get(), cq, rpcmethod_Recognize_, context, request, true); 94 | } 95 | 96 | ::grpc::ClientAsyncResponseReader< ::kaldi_serve::RecognizeResponse>* KaldiServe::Stub::PrepareAsyncRecognizeRaw(::grpc::ClientContext* context, const ::kaldi_serve::RecognizeRequest& request, ::grpc::CompletionQueue* cq) { 97 | return ::grpc_impl::internal::ClientAsyncResponseReaderFactory< ::kaldi_serve::RecognizeResponse>::Create(channel_.get(), cq, rpcmethod_Recognize_, context, request, false); 98 | } 99 | 100 | ::grpc::ClientWriter< ::kaldi_serve::RecognizeRequest>* KaldiServe::Stub::StreamingRecognizeRaw(::grpc::ClientContext* context, ::kaldi_serve::RecognizeResponse* response) { 101 | return ::grpc_impl::internal::ClientWriterFactory< ::kaldi_serve::RecognizeRequest>::Create(channel_.get(), rpcmethod_StreamingRecognize_, context, response); 102 | } 103 | 104 | void KaldiServe::Stub::experimental_async::StreamingRecognize(::grpc::ClientContext* context, ::kaldi_serve::RecognizeResponse* response, ::grpc::experimental::ClientWriteReactor< ::kaldi_serve::RecognizeRequest>* reactor) { 105 | ::grpc_impl::internal::ClientCallbackWriterFactory< ::kaldi_serve::RecognizeRequest>::Create(stub_->channel_.get(), stub_->rpcmethod_StreamingRecognize_, context, response, reactor); 106 | } 107 | 108 | ::grpc::ClientAsyncWriter< ::kaldi_serve::RecognizeRequest>* KaldiServe::Stub::AsyncStreamingRecognizeRaw(::grpc::ClientContext* context, ::kaldi_serve::RecognizeResponse* response, ::grpc::CompletionQueue* cq, void* tag) { 109 | return ::grpc_impl::internal::ClientAsyncWriterFactory< ::kaldi_serve::RecognizeRequest>::Create(channel_.get(), cq, rpcmethod_StreamingRecognize_, context, response, true, tag); 110 | } 111 | 112 | ::grpc::ClientAsyncWriter< ::kaldi_serve::RecognizeRequest>* KaldiServe::Stub::PrepareAsyncStreamingRecognizeRaw(::grpc::ClientContext* context, ::kaldi_serve::RecognizeResponse* response, ::grpc::CompletionQueue* cq) { 113 | return ::grpc_impl::internal::ClientAsyncWriterFactory< ::kaldi_serve::RecognizeRequest>::Create(channel_.get(), cq, rpcmethod_StreamingRecognize_, context, response, false, nullptr); 114 | } 115 | 116 | ::grpc::ClientReaderWriter< ::kaldi_serve::RecognizeRequest, ::kaldi_serve::RecognizeResponse>* KaldiServe::Stub::BidiStreamingRecognizeRaw(::grpc::ClientContext* context) { 117 | return ::grpc_impl::internal::ClientReaderWriterFactory< ::kaldi_serve::RecognizeRequest, ::kaldi_serve::RecognizeResponse>::Create(channel_.get(), rpcmethod_BidiStreamingRecognize_, context); 118 | } 119 | 120 | void KaldiServe::Stub::experimental_async::BidiStreamingRecognize(::grpc::ClientContext* context, ::grpc::experimental::ClientBidiReactor< ::kaldi_serve::RecognizeRequest,::kaldi_serve::RecognizeResponse>* reactor) { 121 | ::grpc_impl::internal::ClientCallbackReaderWriterFactory< ::kaldi_serve::RecognizeRequest,::kaldi_serve::RecognizeResponse>::Create(stub_->channel_.get(), stub_->rpcmethod_BidiStreamingRecognize_, context, reactor); 122 | } 123 | 124 | ::grpc::ClientAsyncReaderWriter< ::kaldi_serve::RecognizeRequest, ::kaldi_serve::RecognizeResponse>* KaldiServe::Stub::AsyncBidiStreamingRecognizeRaw(::grpc::ClientContext* context, ::grpc::CompletionQueue* cq, void* tag) { 125 | return ::grpc_impl::internal::ClientAsyncReaderWriterFactory< ::kaldi_serve::RecognizeRequest, ::kaldi_serve::RecognizeResponse>::Create(channel_.get(), cq, rpcmethod_BidiStreamingRecognize_, context, 
true, tag); 126 | } 127 | 128 | ::grpc::ClientAsyncReaderWriter< ::kaldi_serve::RecognizeRequest, ::kaldi_serve::RecognizeResponse>* KaldiServe::Stub::PrepareAsyncBidiStreamingRecognizeRaw(::grpc::ClientContext* context, ::grpc::CompletionQueue* cq) { 129 | return ::grpc_impl::internal::ClientAsyncReaderWriterFactory< ::kaldi_serve::RecognizeRequest, ::kaldi_serve::RecognizeResponse>::Create(channel_.get(), cq, rpcmethod_BidiStreamingRecognize_, context, false, nullptr); 130 | } 131 | 132 | KaldiServe::Service::Service() { 133 | AddMethod(new ::grpc::internal::RpcServiceMethod( 134 | KaldiServe_method_names[0], 135 | ::grpc::internal::RpcMethod::NORMAL_RPC, 136 | new ::grpc::internal::RpcMethodHandler< KaldiServe::Service, ::google::protobuf::Empty, ::kaldi_serve::ModelList>( 137 | std::mem_fn(&KaldiServe::Service::ListModels), this))); 138 | AddMethod(new ::grpc::internal::RpcServiceMethod( 139 | KaldiServe_method_names[1], 140 | ::grpc::internal::RpcMethod::NORMAL_RPC, 141 | new ::grpc::internal::RpcMethodHandler< KaldiServe::Service, ::kaldi_serve::RecognizeRequest, ::kaldi_serve::RecognizeResponse>( 142 | std::mem_fn(&KaldiServe::Service::Recognize), this))); 143 | AddMethod(new ::grpc::internal::RpcServiceMethod( 144 | KaldiServe_method_names[2], 145 | ::grpc::internal::RpcMethod::CLIENT_STREAMING, 146 | new ::grpc::internal::ClientStreamingHandler< KaldiServe::Service, ::kaldi_serve::RecognizeRequest, ::kaldi_serve::RecognizeResponse>( 147 | std::mem_fn(&KaldiServe::Service::StreamingRecognize), this))); 148 | AddMethod(new ::grpc::internal::RpcServiceMethod( 149 | KaldiServe_method_names[3], 150 | ::grpc::internal::RpcMethod::BIDI_STREAMING, 151 | new ::grpc::internal::BidiStreamingHandler< KaldiServe::Service, ::kaldi_serve::RecognizeRequest, ::kaldi_serve::RecognizeResponse>( 152 | std::mem_fn(&KaldiServe::Service::BidiStreamingRecognize), this))); 153 | } 154 | 155 | KaldiServe::Service::~Service() { 156 | } 157 | 158 | ::grpc::Status KaldiServe::Service::ListModels(::grpc::ServerContext* context, const ::google::protobuf::Empty* request, ::kaldi_serve::ModelList* response) { 159 | (void) context; 160 | (void) request; 161 | (void) response; 162 | return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, ""); 163 | } 164 | 165 | ::grpc::Status KaldiServe::Service::Recognize(::grpc::ServerContext* context, const ::kaldi_serve::RecognizeRequest* request, ::kaldi_serve::RecognizeResponse* response) { 166 | (void) context; 167 | (void) request; 168 | (void) response; 169 | return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, ""); 170 | } 171 | 172 | ::grpc::Status KaldiServe::Service::StreamingRecognize(::grpc::ServerContext* context, ::grpc::ServerReader< ::kaldi_serve::RecognizeRequest>* reader, ::kaldi_serve::RecognizeResponse* response) { 173 | (void) context; 174 | (void) reader; 175 | (void) response; 176 | return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, ""); 177 | } 178 | 179 | ::grpc::Status KaldiServe::Service::BidiStreamingRecognize(::grpc::ServerContext* context, ::grpc::ServerReaderWriter< ::kaldi_serve::RecognizeResponse, ::kaldi_serve::RecognizeRequest>* stream) { 180 | (void) context; 181 | (void) stream; 182 | return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, ""); 183 | } 184 | 185 | 186 | } // namespace kaldi_serve 187 | 188 | -------------------------------------------------------------------------------- /plugins/grpc/protos/kaldi_serve.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 
2 | 3 | import "google/protobuf/empty.proto"; 4 | package kaldi_serve; 5 | 6 | service KaldiServe { 7 | // Lists all the available loaded models 8 | rpc ListModels(google.protobuf.Empty) returns (ModelList) {} 9 | 10 | // Performs synchronous non-streaming speech recognition. 11 | rpc Recognize(RecognizeRequest) returns (RecognizeResponse) {} 12 | 13 | // Performs synchronous client-to-server streaming speech recognition: 14 | // receive results after all audio has been streamed and processed. 15 | rpc StreamingRecognize(stream RecognizeRequest) returns (RecognizeResponse) {} 16 | 17 | // Performs synchronous bidirectional streaming speech recognition: 18 | // receive results as the audio is being streamed and processed. 19 | rpc BidiStreamingRecognize(stream RecognizeRequest) returns (stream RecognizeResponse) {} 20 | } 21 | 22 | message ModelList { 23 | repeated Model models = 1; 24 | } 25 | 26 | message Model { 27 | string name = 1; 28 | string language_code = 2; 29 | } 30 | 31 | message RecognizeRequest { 32 | RecognitionConfig config = 1; 33 | RecognitionAudio audio = 2; 34 | string uuid = 3; 35 | } 36 | 37 | message RecognizeResponse { 38 | repeated SpeechRecognitionResult results = 1; 39 | } 40 | 41 | // Provides information to the recognizer that specifies how to process the request 42 | message RecognitionConfig { 43 | enum AudioEncoding { 44 | ENCODING_UNSPECIFIED = 0; 45 | LINEAR16 = 1; 46 | FLAC = 2; 47 | // MULAW = 3; 48 | // AMR = 4; 49 | // AMR_WB = 5; 50 | // OGG_OPUS = 6; 51 | // SPEEX_WITH_HEADER_BYTE = 7; 52 | } 53 | 54 | AudioEncoding encoding = 1; 55 | int32 sample_rate_hertz = 2; // Valid values are: 8000-48000. 56 | string language_code = 3; 57 | int32 max_alternatives = 4; 58 | bool punctuation = 5; 59 | repeated SpeechContext speech_contexts = 6; 60 | int32 audio_channel_count = 7; 61 | // RecognitionMetadata metadata = 9; 62 | string model = 10; 63 | bool raw = 11; 64 | int32 data_bytes = 12; 65 | bool word_level = 13; 66 | } 67 | 68 | // Either `content` or `uri` must be supplied. 69 | message RecognitionAudio { 70 | oneof audio_source { 71 | bytes content = 1; 72 | string uri = 2; 73 | } 74 | } 75 | 76 | message SpeechRecognitionResult { 77 | repeated SpeechRecognitionAlternative alternatives = 1; 78 | } 79 | 80 | message SpeechRecognitionAlternative { 81 | string transcript = 1; 82 | float confidence = 2; 83 | float am_score = 3; 84 | float lm_score = 4; 85 | repeated Word words = 5; 86 | } 87 | 88 | message Word { 89 | float start_time = 1; 90 | float end_time = 2; 91 | string word = 3; 92 | float confidence = 4; 93 | } 94 | 95 | message SpeechContext { 96 | repeated string phrases = 1; 97 | string type = 2; 98 | } 99 | -------------------------------------------------------------------------------- /plugins/grpc/registrar/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "bytes" 5 | "encoding/json" 6 | "errors" 7 | "fmt" 8 | "io/ioutil" 9 | "log" 10 | "net/http" 11 | "os" 12 | "strconv" 13 | "strings" 14 | "time" 15 | ) 16 | 17 | // It will try to register the service into the local Consul agent 10 times, after that it will stop. 18 | // A simple PUT requests to the /v1/agent/service/register of Consul, passing the correct payload. 
19 | // To run the script, you need to set the following ENV vars:
20 | //
21 | // export APP_NAME=fake_service
22 | // export APP_PORT=5001
23 | // export CONSUL_PORT=8500
24 | // export TAGS=svc1
25 | // go run register/main.go
26 | //
27 | // Descriptions of the env vars are as follows:
28 | //
29 | // APP_NAME: Name of the app being registered
30 | // APP_PORT: Port on which the app is running
31 | // CONSUL_PORT: Port on which consul is running. (Usually 8500)
32 | // TAGS: Tags to be used to signify the app
33 | //
34 | // Additionally, there are 2 optional env vars where you can specify the health check API details for the service:
35 | // HEALTH_CHECK_TYPE: Can be "http" or "grpc"
36 | // HEALTH_CHECK_ENDPOINT: API endpoint for the health check.
37 | // Documentation for the same can be found here: https://www.consul.io/api/agent/check#register-check
38 | func main() {
39 | tags := os.Getenv("TAGS")
40 | if tags == "" {
41 | logFatal("tags")
42 | }
43 |
44 | name := os.Getenv("APP_NAME")
45 | if name == "" {
46 | logFatal("app name")
47 | }
48 |
49 | _port := os.Getenv("APP_PORT")
50 | if _port == "" {
51 | logFatal("app port")
52 | }
53 | port, err := strconv.Atoi(_port)
54 | if err != nil {
55 | log.Fatalf("Unable to convert \"%s\" to int: %v", _port, err)
56 | }
57 |
58 | consulPort := os.Getenv("CONSUL_PORT")
59 | if consulPort == "" {
60 | logFatal("consul port")
61 | }
62 |
63 | managerIpAddress := os.Getenv("MANAGER_IP_ADDRESS")
64 | if managerIpAddress == "" {
65 | logFatal("manager ip address")
66 | }
67 |
68 | // Can be either HTTP or GRPC
69 | healthCheckProtocol := os.Getenv("HEALTH_CHECK_TYPE")
70 | healthCheckEndpoint := os.Getenv("HEALTH_CHECK_ENDPOINT")
71 |
72 | // Encode the data
73 | postBody := map[string]interface{}{
74 | "name": name,
75 | "tags": strings.Split(tags, ","),
76 | "address": "",
77 | "port": port,
78 | }
79 |
80 | // Add a health check for the service if specified in the env vars
81 | if healthCheckEndpoint != "" && (healthCheckProtocol == "http" || healthCheckProtocol == "grpc") {
82 | postBody["checks"] = []map[string]string{
83 | {healthCheckProtocol: healthCheckEndpoint, "interval": "5s"},
84 | }
85 | }
86 |
87 | consulRegisterEndpoint := fmt.Sprintf("http://%s:%s/v1/agent/service/register", managerIpAddress, consulPort)
88 |
89 | // Attempt to register the service 10 times before giving up
90 | for i := 0; i < 10; i++ {
91 | if err = registerService(consulRegisterEndpoint, postBody); err != nil {
92 | // Wait for a second and then attempt to register again
93 | time.Sleep(1 * time.Second)
94 | continue
95 | } else {
96 | break
97 | }
98 | }
99 |
100 | if err != nil {
101 | log.Fatalf("Unable to register to consul: %v", err)
102 | }
103 | }
104 |
105 | // Sends a PUT request to Consul's register endpoint to register a service
106 | func registerService(consulRegisterEndpoint string, postBody map[string]interface{}) error {
107 | var b []byte
108 | var err error
109 | if b, err = json.Marshal(postBody); err != nil {
110 | return err
111 | }
112 | responseBody := bytes.NewBuffer(b)
113 | req, err := http.NewRequest("PUT", consulRegisterEndpoint, responseBody)
114 | // Handle error
115 | if err != nil {
116 | return err
117 | }
118 |
119 | httpClient := &http.Client{
120 | Timeout: 21 * time.Second,
121 | }
122 |
123 | req.Header.Set("Accept", "application/json")
124 |
125 | resp, err := httpClient.Do(req)
126 | if err != nil {
127 | return err
128 | }
129 |
130 | defer resp.Body.Close()
131 |
132 | // Read the response body
133 | body, err := ioutil.ReadAll(resp.Body)
134 | if err != nil {
135 | return err
136 | }
137 | sb := string(body)
138 |
139 | if resp.StatusCode != 200 {
140 | return errors.New(fmt.Sprintf("Response code %d from consul. Body: \"%s\"", resp.StatusCode, sb))
141 | } else {
142 | fmt.Println("Service registered with consul successfully!")
143 | }
144 |
145 | return nil
146 | }
147 |
148 | func logFatal(varName string) {
149 | if name, err := os.Hostname(); err == nil {
150 | log.Fatalf("No %v specified for host %v", varName, name)
151 | } else {
152 | log.Fatalf("No %v specified for this host", varName)
153 | }
154 | }
155 |
--------------------------------------------------------------------------------
/plugins/grpc/src/app.cc:
--------------------------------------------------------------------------------
1 | // app.cc - gRPC Application Entry
2 |
3 | // stl includes
4 | #include
5 | #include
6 | #include
7 | #include
8 | #include
9 |
10 | // lib includes
11 | #include
12 | #include
13 |
14 | // local includes
15 | #include "config.hpp"
16 | #include "server.hpp"
17 |
18 | // vendor includes
19 | #include "vendor/CLI11.hpp"
20 |
21 | using namespace kaldiserve;
22 |
23 |
24 | int main(int argc, char *argv[]) {
25 | CLI::App app{"Kaldi gRPC server"};
26 |
27 | std::string model_spec_toml;
28 | app.add_option("model_spec_toml", model_spec_toml, "Path to toml specifying models to load")
29 | ->required()
30 | ->check(CLI::ExistingFile);
31 |
32 | app.add_flag("-d,--debug", DEBUG, "Flag to enable debug mode");
33 |
34 | app.add_flag_callback("-v,--version", print_version, "Show program version and exit");
35 |
36 | CLI11_PARSE(app, argc, argv);
37 |
38 | std::vector<ModelSpec> model_specs;
39 | parse_model_specs(model_spec_toml, model_specs);
40 |
41 | if (model_specs.size() == 0) {
42 | std::cout << ":: No model found in toml for loading" << ENDL;
43 | return 1;
44 | }
45 |
46 | std::cout << ":: Loading " << model_specs.size() << " models" << ENDL;
47 | for (auto const &model_spec : model_specs) {
48 | std::cout << ":: - " << model_spec.name + " (" + model_spec.language_code + ")" << ENDL;
49 | }
50 |
51 | run_server(model_specs);
52 |
53 | return 0;
54 | }
55 |
--------------------------------------------------------------------------------
/plugins/grpc/src/config.hpp:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | // lib includes
4 | #include
5 | #include
6 |
7 | using namespace kaldiserve;
8 |
9 |
10 | constexpr bool BOOST_1_67_X = ((BOOST_VERSION / 100000) >= 1) && (((BOOST_VERSION / 100) % 1000) >= 67);
11 |
12 | // boost headers for hash functions were changed after version 1_67_X
13 | #if BOOST_1_67_X
14 | #include
15 | #else
16 | #include
17 | #endif
18 |
19 |
20 | // hash function for model id type (pair of strings) for unordered_map hashing fn
21 | struct model_id_hash {
22 | std::size_t operator () (const model_id_t &id) const {
23 | std::size_t seed = 0;
24 | boost::hash_combine(seed, id.first);
25 | boost::hash_combine(seed, id.second);
26 | return seed;
27 | }
28 | };
--------------------------------------------------------------------------------
/plugins/grpc/src/server.hpp:
--------------------------------------------------------------------------------
1 | // server.hpp - Server Interface
2 | #pragma once
3 |
4 | // stl includes
5 | #include
6 | #include
7 | #include
8 | #include
9 | #include
10 | #include
11 |
12 | // lib includes
13 | #include
14 |
15 | // kaldi includes
16 | #include
17 |
18 | // gRPC includes
19 | #include
20 | #include
21 | #include
22 | #include
23 | #include
24 |
25 | // local includes
26 | #include "config.hpp"
27 | #include "kaldi_serve.grpc.pb.h"
28 |
29 | using namespace kaldiserve;
30 |
31 |
32 | void add_alternatives_to_response(const utterance_results_t &results,
33 | kaldi_serve::RecognizeResponse *response,
34 | const kaldi_serve::RecognitionConfig &config) noexcept {
35 |
36 | kaldi_serve::SpeechRecognitionResult *sr_result = response->add_results();
37 | kaldi_serve::SpeechRecognitionAlternative *alternative;
38 | kaldi_serve::Word *word;
39 |
40 | // find alternatives on final `lattice` after all chunks have been processed
41 | for (auto const &res : results) {
42 | if (!res.transcript.empty()) {
43 | alternative = sr_result->add_alternatives();
44 | alternative->set_transcript(res.transcript);
45 | alternative->set_confidence(res.confidence);
46 | alternative->set_am_score(res.am_score);
47 | alternative->set_lm_score(res.lm_score);
48 | if (config.word_level()) {
49 | for (auto const &w: res.words) {
50 | word = alternative->add_words();
51 | word->set_start_time(w.start_time);
52 | word->set_end_time(w.end_time);
53 | word->set_word(w.word);
54 | word->set_confidence(w.confidence);
55 | }
56 | }
57 | }
58 | }
59 | }
60 |
61 |
62 | // KaldiServeImpl ::
63 | // Defines the core server logic and request/response handlers.
64 | // Keeps `Decoder` instances cached in a thread-safe
65 | // multiple-producer multiple-consumer (MPMC) queue to handle each
66 | // request with a separate `Decoder`.
67 | class KaldiServeImpl final : public kaldi_serve::KaldiServe::Service {
68 |
69 | private:
70 | // Map of thread-safe Decoder MPMC queues for different languages/models
71 | std::unordered_map<model_id_t, std::unique_ptr<DecoderQueue>, model_id_hash> decoder_queue_map_;
72 |
73 | // Tells if a given model name and language code is available for use.
74 | inline bool is_model_present(const model_id_t &) const noexcept; 75 | 76 | public: 77 | explicit KaldiServeImpl(const std::vector &) noexcept; 78 | 79 | grpc::Status ListModels(grpc::ServerContext *const, 80 | const google::protobuf::Empty *const, 81 | kaldi_serve::ModelList *const) override; 82 | 83 | // Non-Streaming Request Handler RPC service 84 | // Accepts a single `RecognizeRequest` message 85 | // Returns a single `RecognizeResponse` message 86 | grpc::Status Recognize(grpc::ServerContext *const, 87 | const kaldi_serve::RecognizeRequest *const, 88 | kaldi_serve::RecognizeResponse *const) override; 89 | 90 | // Streaming Request Handler RPC service 91 | // Accepts a stream of `RecognizeRequest` messages 92 | // Returns a single `RecognizeResponse` message 93 | grpc::Status StreamingRecognize(grpc::ServerContext *const, 94 | grpc::ServerReader *const, 95 | kaldi_serve::RecognizeResponse *const) override; 96 | 97 | // Bidirectional Streaming Request Handler RPC service 98 | // Accepts a stream of `RecognizeRequest` messages 99 | // Returns a stream of `RecognizeResponse` messages 100 | grpc::Status BidiStreamingRecognize(grpc::ServerContext *const, 101 | grpc::ServerReaderWriter*) override; 102 | }; 103 | 104 | KaldiServeImpl::KaldiServeImpl(const std::vector &model_specs) noexcept { 105 | for (auto const &model_spec : model_specs) { 106 | model_id_t model_id = std::make_pair(model_spec.name, model_spec.language_code); 107 | decoder_queue_map_[model_id] = std::unique_ptr(new DecoderQueue(model_spec)); 108 | } 109 | } 110 | 111 | inline bool KaldiServeImpl::is_model_present(const model_id_t &model_id) const noexcept { 112 | return decoder_queue_map_.find(model_id) != decoder_queue_map_.end(); 113 | } 114 | 115 | grpc::Status KaldiServeImpl::ListModels(grpc::ServerContext *const context, 116 | const google::protobuf::Empty *const request, 117 | kaldi_serve::ModelList *const model_list) { 118 | 119 | kaldi_serve::Model *model; 120 | 121 | for (auto const &model_id : decoder_queue_map_) { 122 | model = model_list->add_models(); 123 | model->set_name(model_id.first.first); 124 | model->set_language_code(model_id.first.second); 125 | } 126 | 127 | return grpc::Status::OK; 128 | } 129 | 130 | grpc::Status KaldiServeImpl::Recognize(grpc::ServerContext *const context, 131 | const kaldi_serve::RecognizeRequest *const request, 132 | kaldi_serve::RecognizeResponse *const response) { 133 | const kaldi_serve::RecognitionConfig config = request->config(); 134 | std::string uuid = request->uuid(); 135 | const int32 n_best = config.max_alternatives(); 136 | const int32 sample_rate_hertz = config.sample_rate_hertz(); 137 | const std::string model_name = config.model(); 138 | const std::string language_code = config.language_code(); 139 | const model_id_t model_id = std::make_pair(model_name, language_code); 140 | 141 | if (!is_model_present(model_id)) { 142 | return grpc::Status(grpc::StatusCode::NOT_FOUND, "Model " + model_name + " (" + language_code + ") not found"); 143 | } 144 | 145 | std::chrono::system_clock::time_point start_time; 146 | if (DEBUG) start_time = std::chrono::system_clock::now(); 147 | 148 | // Decoder Acquisition :: 149 | // - Tries to attain lock and obtain decoder from the queue. 150 | // - Waits here until lock on queue is attained. 151 | // - Each new audio stream gets separate decoder object. 
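    // (A minimal RAII sketch -- `ScopedDecoder` below is illustrative and not
    //  part of this codebase -- that would guarantee the acquire()/release()
    //  pair also balances on the error paths handled in the catch blocks:
    //
    //      struct ScopedDecoder {
    //          DecoderQueue &q;
    //          Decoder *d;
    //          explicit ScopedDecoder(DecoderQueue &q_) : q(q_), d(q_.acquire()) {}
    //          ~ScopedDecoder() { q.release(d); }
    //      };
    //
    //  The handlers here keep the explicit acquire/release calls instead.)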
152 | Decoder *decoder_ = decoder_queue_map_[model_id]->acquire(); 153 | 154 | if (DEBUG) { 155 | std::chrono::system_clock::time_point end_time = std::chrono::system_clock::now(); 156 | auto ms = std::chrono::duration_cast(end_time - start_time); 157 | std::cout << "[" << timestamp_now() << "] uuid: " << uuid << " decoder acquired in: " << ms.count() << "ms" << ENDL; 158 | } 159 | 160 | kaldi_serve::RecognitionAudio audio = request->audio(); 161 | std::stringstream input_stream(audio.content()); 162 | 163 | if (DEBUG) start_time = std::chrono::system_clock::now(); 164 | decoder_->start_decoding(uuid); 165 | 166 | // decode speech signals in chunks 167 | try { 168 | if (config.raw()) { 169 | decoder_->decode_raw_wav_audio(input_stream, sample_rate_hertz, config.data_bytes()); 170 | } else { 171 | decoder_->decode_wav_audio(input_stream); 172 | } 173 | } catch (kaldi::KaldiFatalError &e) { 174 | decoder_queue_map_[model_id]->release(decoder_); 175 | std::string message = std::string(e.what()) + " :: " + std::string(e.KaldiMessage()); 176 | return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, message); 177 | } catch (std::exception &e) { 178 | decoder_queue_map_[model_id]->release(decoder_); 179 | return grpc::Status(grpc::StatusCode::INTERNAL, e.what()); 180 | } 181 | 182 | utterance_results_t k_results_; 183 | decoder_->get_decoded_results(n_best, k_results_, config.word_level()); 184 | 185 | add_alternatives_to_response(k_results_, response, config); 186 | 187 | // Decoder Release :: 188 | // - Releases the lock on the decoder and pushes back into queue. 189 | // - Notifies another request handler thread of availability. 190 | decoder_->free_decoder(); 191 | decoder_queue_map_[model_id]->release(decoder_); 192 | 193 | if (DEBUG) { 194 | std::chrono::system_clock::time_point end_time = std::chrono::system_clock::now(); 195 | // LOG REQUEST RESOLVE TIME --> END 196 | auto ms = std::chrono::duration_cast(end_time - start_time); 197 | std::cout << "[" << timestamp_now() << "] uuid: " << uuid << " request resolved in: " << ms.count() << "ms" << ENDL; 198 | } 199 | 200 | return grpc::Status::OK; 201 | } 202 | 203 | grpc::Status KaldiServeImpl::StreamingRecognize(grpc::ServerContext *const context, 204 | grpc::ServerReader *const reader, 205 | kaldi_serve::RecognizeResponse *const response) { 206 | kaldi_serve::RecognizeRequest request_; 207 | reader->Read(&request_); 208 | 209 | // We first read the request to see if we have the correct model and language to load 210 | // Assuming: config may change mid-way (only `raw` and `data_bytes` fields) 211 | kaldi_serve::RecognitionConfig config = request_.config(); 212 | std::string uuid = request_.uuid(); 213 | const int32 n_best = config.max_alternatives(); 214 | const int32 sample_rate_hertz = config.sample_rate_hertz(); 215 | const std::string model_name = config.model(); 216 | const std::string language_code = config.language_code(); 217 | const model_id_t model_id = std::make_pair(model_name, language_code); 218 | 219 | if (!is_model_present(model_id)) { 220 | return grpc::Status(grpc::StatusCode::NOT_FOUND, "Model " + model_name + " (" + language_code + ") not found"); 221 | } 222 | 223 | std::chrono::system_clock::time_point start_time, start_time_req; 224 | if (DEBUG) start_time = std::chrono::system_clock::now(); 225 | 226 | // Decoder Acquisition :: 227 | // - Tries to attain lock and obtain decoder from the queue. 228 | // - Waits here until lock on queue is attained. 229 | // - Each new audio stream gets separate decoder object. 
230 | Decoder *decoder_ = decoder_queue_map_[model_id]->acquire(); 231 | 232 | if (DEBUG) { 233 | std::chrono::system_clock::time_point end_time = std::chrono::system_clock::now(); 234 | auto ms = std::chrono::duration_cast(end_time - start_time); 235 | std::cout << "[" << timestamp_now() << "] uuid: " << uuid << " decoder acquired in: " << ms.count() << "ms" << ENDL; 236 | } 237 | 238 | int i = 0; 239 | int bytes = 0; 240 | 241 | if (DEBUG) start_time_req = std::chrono::system_clock::now(); 242 | decoder_->start_decoding(uuid); 243 | 244 | // read chunks until end of stream 245 | do { 246 | if (DEBUG) { 247 | // LOG REQUEST RESOLVE TIME --> START (at the last request since that would be the actual latency) 248 | start_time = std::chrono::system_clock::now(); 249 | 250 | i++; 251 | bytes += config.data_bytes(); 252 | 253 | std::stringstream debug_msg; 254 | debug_msg << "[" 255 | << timestamp_now() 256 | << "] uuid: " << uuid 257 | << " chunk #" << i 258 | << " received"; 259 | if (config.raw()) { 260 | debug_msg << " - " << config.data_bytes() 261 | << " bytes (total " << bytes 262 | << ")"; 263 | } 264 | 265 | std::cout << debug_msg.str() << ENDL; 266 | } 267 | config = request_.config(); 268 | kaldi_serve::RecognitionAudio audio = request_.audio(); 269 | std::stringstream input_stream_chunk(audio.content()); 270 | 271 | // decode intermediate speech signals 272 | // Assuming: audio stream has already been chunked into desired length 273 | try { 274 | if (config.raw()) { 275 | decoder_->decode_stream_raw_wav_chunk(input_stream_chunk, sample_rate_hertz, config.data_bytes()); 276 | } else { 277 | decoder_->decode_stream_wav_chunk(input_stream_chunk); 278 | } 279 | } catch (kaldi::KaldiFatalError &e) { 280 | decoder_queue_map_[model_id]->release(decoder_); 281 | std::string message = std::string(e.what()) + " :: " + std::string(e.KaldiMessage()); 282 | return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, message); 283 | } catch (std::exception &e) { 284 | decoder_queue_map_[model_id]->release(decoder_); 285 | return grpc::Status(grpc::StatusCode::INTERNAL, e.what()); 286 | } 287 | 288 | if (DEBUG) { 289 | std::chrono::system_clock::time_point end_time = std::chrono::system_clock::now(); 290 | auto ms = std::chrono::duration_cast(end_time - start_time); 291 | 292 | std::stringstream debug_msg; 293 | debug_msg << "[" 294 | << timestamp_now() 295 | << "] uuid: " << uuid 296 | << " chunk #" << i 297 | << " computed in " 298 | << ms.count() << "ms"; 299 | 300 | std::cout << debug_msg.str() << ENDL; 301 | } 302 | } while (reader->Read(&request_)); 303 | 304 | if (DEBUG) start_time = std::chrono::system_clock::now(); 305 | 306 | utterance_results_t k_results_; 307 | decoder_->get_decoded_results(n_best, k_results_, config.word_level()); 308 | 309 | add_alternatives_to_response(k_results_, response, config); 310 | 311 | // Decoder Release :: 312 | // - Releases the lock on the decoder and pushes back into queue. 313 | // - Notifies another request handler thread of availability. 
314 | decoder_->free_decoder(); 315 | decoder_queue_map_[model_id]->release(decoder_); 316 | 317 | if (DEBUG) { 318 | std::chrono::system_clock::time_point end_time_req = std::chrono::system_clock::now(); 319 | auto ms = std::chrono::duration_cast(end_time_req - start_time); 320 | std::cout << "[" << timestamp_now() << "] uuid: " << uuid << " found best paths in " << ms.count() << "ms" << ENDL; 321 | 322 | // LOG REQUEST RESOLVE TIME --> END 323 | auto ms_req = std::chrono::duration_cast(end_time_req - start_time_req); 324 | std::cout << "[" << timestamp_now() << "] uuid: " << uuid << " request resolved in: " << ms_req.count() << "ms" << ENDL; 325 | } 326 | 327 | return grpc::Status::OK; 328 | } 329 | 330 | grpc::Status KaldiServeImpl::BidiStreamingRecognize(grpc::ServerContext *const context, 331 | grpc::ServerReaderWriter *stream) { 332 | kaldi_serve::RecognizeRequest request_; 333 | stream->Read(&request_); 334 | 335 | // We first read the request to see if we have the correct model and language to load 336 | // Assuming: config may change mid-way (only `raw` and `data_bytes` fields) 337 | kaldi_serve::RecognitionConfig config = request_.config(); 338 | std::string uuid = request_.uuid(); 339 | const int32 n_best = config.max_alternatives(); 340 | const int32 sample_rate_hertz = config.sample_rate_hertz(); 341 | const std::string model_name = config.model(); 342 | const std::string language_code = config.language_code(); 343 | const model_id_t model_id = std::make_pair(model_name, language_code); 344 | 345 | if (!is_model_present(model_id)) { 346 | return grpc::Status(grpc::StatusCode::NOT_FOUND, "Model " + model_name + " (" + language_code + ") not found"); 347 | } 348 | 349 | std::chrono::system_clock::time_point start_time, start_time_req; 350 | if (DEBUG) start_time = std::chrono::system_clock::now(); 351 | 352 | // Decoder Acquisition :: 353 | // - Tries to attain lock and obtain decoder from the queue. 354 | // - Waits here until lock on queue is attained. 355 | // - Each new audio stream gets separate decoder object. 
356 | Decoder *decoder_ = decoder_queue_map_[model_id]->acquire(); 357 | 358 | if (DEBUG) { 359 | std::chrono::system_clock::time_point end_time = std::chrono::system_clock::now(); 360 | auto ms = std::chrono::duration_cast(end_time - start_time); 361 | std::cout << "[" << timestamp_now() << "] uuid: " << uuid << " decoder acquired in: " << ms.count() << "ms" << ENDL; 362 | } 363 | 364 | int i = 0; 365 | int bytes = 0; 366 | 367 | if (DEBUG) start_time_req = std::chrono::system_clock::now(); 368 | decoder_->start_decoding(uuid); 369 | 370 | // read chunks until end of stream 371 | do { 372 | if (DEBUG) { 373 | start_time = std::chrono::system_clock::now(); 374 | 375 | i++; 376 | bytes += config.data_bytes(); 377 | 378 | std::stringstream debug_msg; 379 | debug_msg << "[" 380 | << timestamp_now() 381 | << "] uuid: " << uuid 382 | << " chunk #" << i 383 | << " received"; 384 | if (config.raw()) { 385 | debug_msg << " - " << config.data_bytes() 386 | << " bytes (total " << bytes 387 | << ")"; 388 | } 389 | 390 | std::cout << debug_msg.str() << ENDL; 391 | } 392 | config = request_.config(); 393 | kaldi_serve::RecognitionAudio audio = request_.audio(); 394 | std::stringstream input_stream_chunk(audio.content()); 395 | 396 | // decode intermediate speech signals 397 | // Assuming: audio stream has already been chunked into desired length 398 | try { 399 | if (config.raw()) { 400 | decoder_->decode_stream_raw_wav_chunk(input_stream_chunk, sample_rate_hertz, config.data_bytes()); 401 | } else { 402 | decoder_->decode_stream_wav_chunk(input_stream_chunk); 403 | } 404 | 405 | utterance_results_t k_results_; 406 | decoder_->get_decoded_results(n_best, k_results_, config.word_level(), true); 407 | 408 | kaldi_serve::RecognizeResponse response_; 409 | add_alternatives_to_response(k_results_, &response_, config); 410 | 411 | stream->Write(response_); 412 | 413 | } catch (kaldi::KaldiFatalError &e) { 414 | decoder_queue_map_[model_id]->release(decoder_); 415 | std::string message = std::string(e.what()) + " :: " + std::string(e.KaldiMessage()); 416 | return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, message); 417 | } catch (std::exception &e) { 418 | decoder_queue_map_[model_id]->release(decoder_); 419 | return grpc::Status(grpc::StatusCode::INTERNAL, e.what()); 420 | } 421 | 422 | if (DEBUG) { 423 | std::chrono::system_clock::time_point end_time = std::chrono::system_clock::now(); 424 | auto ms = std::chrono::duration_cast(end_time - start_time); 425 | 426 | std::stringstream debug_msg; 427 | debug_msg << "[" 428 | << timestamp_now() 429 | << "] uuid: " << uuid 430 | << " chunk #" << i 431 | << " computed in " 432 | << ms.count() << "ms"; 433 | 434 | std::cout << debug_msg.str() << ENDL; 435 | } 436 | } while (stream->Read(&request_)); 437 | 438 | if (DEBUG) start_time = std::chrono::system_clock::now(); 439 | 440 | utterance_results_t k_results_; 441 | decoder_->get_decoded_results(n_best, k_results_, config.word_level()); 442 | 443 | kaldi_serve::RecognizeResponse response_; 444 | add_alternatives_to_response(k_results_, &response_, config); 445 | 446 | stream->Write(response_); 447 | 448 | // Decoder Release :: 449 | // - Releases the lock on the decoder and pushes back into queue. 450 | // - Notifies another request handler thread of availability. 
451 | decoder_->free_decoder(); 452 | decoder_queue_map_[model_id]->release(decoder_); 453 | 454 | if (DEBUG) { 455 | std::chrono::system_clock::time_point end_time_req = std::chrono::system_clock::now(); 456 | auto ms = std::chrono::duration_cast(end_time_req - start_time); 457 | std::cout << "[" << timestamp_now() << "] uuid: " << uuid << " found best paths in " << ms.count() << "ms" << ENDL; 458 | 459 | // LOG REQUEST RESOLVE TIME --> END 460 | auto ms_req = std::chrono::duration_cast(end_time_req - start_time_req); 461 | std::cout << "[" << timestamp_now() << "] uuid: " << uuid << " request resolved in: " << ms_req.count() << "ms" << ENDL; 462 | } 463 | 464 | return grpc::Status::OK; 465 | } 466 | 467 | 468 | // Runs the Server with the Kaldi Service 469 | void run_server(const std::vector &model_specs) { 470 | KaldiServeImpl service(model_specs); 471 | 472 | std::string server_address("0.0.0.0:5016"); 473 | 474 | grpc::ServerBuilder builder; 475 | builder.AddListeningPort(server_address, grpc::InsecureServerCredentials()); 476 | builder.RegisterService(&service); 477 | 478 | std::unique_ptr server(builder.BuildAndStart()); 479 | 480 | std::cout << "kaldi-serve gRPC Streaming Server listening on " << server_address << ENDL; 481 | server->Wait(); 482 | } 483 | 484 | 485 | /** 486 | NOTES: 487 | ------ 488 | 489 | VARIABLES THAT INFLUENCE SERVER RELIABILITY :: 490 | 1. Length of Audio Stream (in secs) 491 | 2. No. of chunks in the Audio Stream 492 | 3. Time intervals between subsequent chunks of audio stream 493 | 4. No. of Decoders in Queue 494 | 5. Timeout for each request (chunk essentially) 495 | 6. No. of concurrent streams being handled by the server 496 | */ 497 | -------------------------------------------------------------------------------- /python/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | find_package(Python REQUIRED) 2 | 3 | # add pybind module 4 | file(GLOB_RECURSE PY_ALL_SOURCE_FILES "kaldiserve_pybind/*.cpp") 5 | pybind11_add_module(kaldiserve_pybind ${PY_ALL_SOURCE_FILES}) 6 | 7 | # suppress pybind11 warnings 8 | target_include_directories(kaldiserve_pybind SYSTEM PRIVATE ${PYBIND11_INCLUDE_DIR}) 9 | 10 | # kaldiserve_pybind, kaldiserve, kaldi, openfst includes 11 | target_include_directories(kaldiserve_pybind PRIVATE .) 
12 | target_include_directories(kaldiserve_pybind PRIVATE ../include)
13 | target_include_directories(kaldiserve_pybind PRIVATE ${KALDI_ROOT}/src ${KALDI_ROOT}/tools/openfst/include)
14 |
15 | # kaldiserve lib
16 | target_link_libraries(kaldiserve_pybind PRIVATE kaldiserve)
--------------------------------------------------------------------------------
/python/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM vernacularai/kaldi-serve:latest
2 |
3 | # install python3.6.5 through pyenv
4 | RUN apt-get update && \
5 | apt-get upgrade -y && \
6 | apt-get install -y \
7 | libtcmalloc-minimal4 \
8 | make build-essential libssl-dev zlib1g-dev \
9 | libbz2-dev libreadline-dev libsqlite3-dev \
10 | wget curl llvm libncurses5-dev libncursesw5-dev \
11 | xz-utils tk-dev libffi-dev liblzma-dev \
12 | python-openssl git
13 |
14 | RUN curl https://pyenv.run | bash
15 | RUN echo 'export PATH="~/.pyenv/bin:$PATH"' >> ~/.bashrc && \
16 | echo 'eval "$(pyenv init -)"' >> ~/.bashrc && \
17 | echo 'eval "$(pyenv virtualenv-init -)"' >> ~/.bashrc
18 | RUN bash -c "source ~/.bashrc && pyenv install 3.6.5 && pyenv global 3.6.5"
19 |
20 | # build python module using the C++ shared library
21 | WORKDIR /root/kaldi-serve
22 | COPY . .
23 |
24 | RUN bash -c "source ~/.bashrc && cd build && \
25 | cmake .. -DBUILD_SHARED_LIBS=OFF -DBUILD_PYTHON_MODULE=ON -DBUILD_PYBIND11=ON -DPYTHON_EXECUTABLE=\$(pyenv which python) && \
26 | make -j$(nproc) VERBOSE=1"
27 |
28 | RUN cp build/python/kaldiserve_pybind*.so python/kaldiserve/
29 | RUN bash -c "source ~/.bashrc && cd python && pip install . -U"
30 |
31 | ENV LD_PRELOAD="/opt/intel/mkl/lib/intel64/libmkl_rt.so:/usr/lib/libtcmalloc_minimal.so.4"
32 |
33 | WORKDIR /root
34 |
35 | # cleanup
36 | RUN rm -rf kaldi-serve
--------------------------------------------------------------------------------
/python/README.md:
--------------------------------------------------------------------------------
1 | # Kaldi-Serve Python Binding
2 |
3 | Python binding for the `kaldiserve` C++ library.
4 |
5 | ## Installation
6 |
7 | ### Build from source
8 |
9 | You will need [pybind11](https://github.com/pybind/pybind11) to be present and built; alternatively, you can pass the `-DBUILD_PYBIND11=ON` flag and cmake will take care of it. You can build the bindings by passing the `-DBUILD_PYTHON_MODULE=ON -DPYTHON_EXECUTABLE=$(which python)` options to the main cmake command:
10 |
11 | ```bash
12 | # build the python bindings (starting from current dir)
13 | cd ../build
14 | cmake .. -DBUILD_PYBIND11=ON -DBUILD_PYTHON_MODULE=ON -DPYTHON_EXECUTABLE=$(which python)
15 | make -j$(nproc)
16 |
17 | # copy over the built shared library to the python package
18 | cp python/kaldiserve_pybind*.so ../python/kaldiserve/
19 |
20 | # build the python package
21 | cd ../python
22 | pip install . -U
23 | ```
24 |
25 | Now you can import `kaldiserve` into your python project.
26 |
27 | ### Docker Image
28 |
29 | #### Using pre-built images
30 |
31 | You can pull pre-built docker images (we currently support python version 3.6) from our [Docker Hub repository](https://hub.docker.com/repository/docker/vernacularai/kaldi-serve):
32 |
33 | ```bash
34 | docker pull vernacularai/kaldi-serve:latest-py3.6
35 | docker run -it vernacularai/kaldi-serve:latest-py3.6
36 | ```
37 |
38 | You will find Python 3.6 pre-installed with the `kaldiserve` python package.
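As a quick smoke test (a sketch, not part of the official docs: the `/work` mount point, `model-spec.toml`, and `audio.wav` are placeholder names for your own model spec and audio file), you can mount this repo's `python/` directory and run one of the sample scripts inside the container:

```bash
# run from the python/ directory of this repo
docker run -it -v $(pwd):/work vernacularai/kaldi-serve:latest-py3.6 \
  bash -c "source ~/.bashrc && pip install docopt && \
           python /work/scripts/transcribe.py /work/model-spec.toml /work/audio.wav"
```

The `source ~/.bashrc` mirrors how the [Dockerfile](./Dockerfile) activates the pyenv-installed python.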
39 | 40 | #### Building the image 41 | 42 | You can also build the docker image using the [Dockerfile](./Dockerfile) provided (make sure to do it from the root dir): 43 | 44 | ```bash 45 | cd ../ 46 | docker build -t kaldi-serve:py${PY_VERSION} -f python/Dockerfile . 47 | ``` 48 | 49 | ## Getting Started 50 | 51 | ### Transcription Interface 52 | 53 | ```python 54 | from io import BytesIO 55 | from kaldiserve import ChainModel, Decoder, parse_model_specs, start_decoding 56 | 57 | # model specification 58 | model_spec = parse_model_specs("../resources/model-spec.toml")[0] 59 | # chain model contains all large const components that can be shared across decoders on multiple threads 60 | model = ChainModel(model_spec) 61 | 62 | # initialize a decoder that keeps a const reference to the model 63 | decoder = Decoder(model) 64 | 65 | audio_files = ["sample1.wav", "sample2.wav"] 66 | 67 | for audio_file in audio_files: 68 | # read audio file as bytes 69 | with open(audio_file, "rb") as f: 70 | audio_bytes = BytesIO(f.read()).getvalue() 71 | 72 | with start_decoding(decoder): 73 | # decode the audio 74 | decoder.decode_wav_audio(audio_bytes) 75 | # get the transcripts (10 alternatives) 76 | alts = decoder.get_decoded_results(10) 77 | 78 | print(alts) 79 | ``` 80 | 81 | ### Sample Scripts 82 | 83 | You will need `kaldiserve` python package and some other [dependencies](./scripts/requirements.txt) to be installed: 84 | 85 | ```bash 86 | pip install -r scripts/requirements.txt 87 | ``` 88 | 89 | There are some sample [scripts](./scripts) provided that can be referenced as examples: 90 | 1. [Transcribe](./scripts/transcribe.py) - transcribes a single audio file 91 | 2. [Batch Transcribe](./scripts/batch_transcribe.py) - transcribes a batch of audio files via multi-threading 92 | 93 | ## Known Issues 94 | 95 | 1. If you face `INTEL MKL ERROR` when instantiating `ChainModel|DecoderFactory|DecoderQueue`, try the following: 96 | 97 | ```bash 98 | export LD_PRELOAD="${MKL_ROOT}/lib/intel64/libmkl_rt.so" 99 | ``` 100 | 101 | 2. 
If you face `*** Error in python: double free or corruption (!prev): ...`, try the following: 102 | 103 | ```bash 104 | apt-get install libtcmalloc-minimal4 105 | export LD_PRELOAD="/usr/lib/libtcmalloc_minimal.so.4" 106 | ``` -------------------------------------------------------------------------------- /python/kaldiserve/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "1.0.0" 2 | 3 | from kaldiserve.kaldiserve_pybind import ModelSpec, Word, Alternative # types 4 | from kaldiserve.kaldiserve_pybind import _ModelSpecList, _WordList, _AlternativeList # type list aliases 5 | from kaldiserve.kaldiserve_pybind import ChainModel # models 6 | from kaldiserve.kaldiserve_pybind import Decoder, DecoderQueue, DecoderFactory # decoders 7 | from kaldiserve.kaldiserve_pybind import parse_model_specs # utils 8 | 9 | from contextlib import contextmanager 10 | 11 | 12 | @contextmanager 13 | def acquire_decoder(dq: DecoderQueue): 14 | decoder = dq.acquire() 15 | try: 16 | yield decoder 17 | finally: 18 | dq.release(decoder) 19 | 20 | 21 | @contextmanager 22 | def start_decoding(decoder: Decoder, uuid: str=""): 23 | decoder.start_decoding(uuid) 24 | try: 25 | yield None 26 | finally: 27 | decoder.free_decoder() -------------------------------------------------------------------------------- /python/kaldiserve_pybind/decoder.cpp: -------------------------------------------------------------------------------- 1 | // stl includes 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | // pybind includes 9 | #include 10 | 11 | // kaldiserve_pybind includes 12 | #include "kaldiserve_pybind/kaldiserve_pybind.h" 13 | 14 | // kaldiserve includes 15 | #include "kaldiserve/model.hpp" 16 | #include "kaldiserve/decoder.hpp" 17 | #include "kaldiserve/types.hpp" 18 | 19 | 20 | namespace kaldiserve { 21 | 22 | void pybind_decoder(py::module &m) { 23 | // kaldiserve.Decoder 24 | py::class_(m, "Decoder", "Decoder class.") 25 | .def(py::init()) 26 | .def("start_decoding", &Decoder::start_decoding) 27 | .def("free_decoder", &Decoder::free_decoder) 28 | // wav stream chunk 29 | .def("decode_stream_wav_chunk", [](Decoder &self, py::bytes &wav_bytes) { 30 | std::string wav_bytes_str(wav_bytes); 31 | { 32 | py::gil_scoped_release release; 33 | std::istringstream wav_stream(wav_bytes_str); 34 | self.decode_stream_wav_chunk(wav_stream); 35 | } 36 | 37 | }) 38 | // raw wav stream chunk 39 | .def("decode_stream_raw_wav_chunk", [](Decoder &self, py::bytes &wav_bytes, 40 | const float &samp_freq, const int &data_bytes) { 41 | std::string wav_bytes_str(wav_bytes); 42 | { 43 | py::gil_scoped_release release; 44 | std::istringstream wav_stream(wav_bytes_str); 45 | self.decode_stream_raw_wav_chunk(wav_stream, samp_freq, data_bytes); 46 | } 47 | }) 48 | // wav audio 49 | .def("decode_wav_audio", [](Decoder &self, py::bytes &wav_bytes, const float &chunk_size) { 50 | std::string wav_bytes_str(wav_bytes); 51 | { 52 | py::gil_scoped_release release; 53 | std::istringstream wav_stream(wav_bytes_str); 54 | self.decode_wav_audio(wav_stream, chunk_size); 55 | } 56 | }, py::arg("wav_bytes"), py::arg("chunk_size") = 1.0) 57 | // raw wav audio 58 | .def("decode_raw_wav_audio", [](Decoder &self, py::bytes &wav_bytes, const float &samp_freq, 59 | const int &data_bytes, const float &chunk_size) { 60 | std::string wav_bytes_str(wav_bytes); 61 | { 62 | py::gil_scoped_release release; 63 | std::istringstream wav_stream(wav_bytes_str); 64 | 
self.decode_raw_wav_audio(wav_stream, samp_freq, data_bytes, chunk_size); 65 | } 66 | }, py::arg("wav_bytes"), py::arg("samp_freq"), 67 | py::arg("data_bytes"), py::arg("chunk_size") = 1.0) 68 | // get decoding results -> list[Alternative] 69 | .def("get_decoded_results", [](Decoder &self, const int &n_best, 70 | const bool &word_level, const bool &bidi_streaming) { 71 | std::vector alts; 72 | { 73 | py::gil_scoped_release release; 74 | self.get_decoded_results(n_best, alts, word_level, bidi_streaming); 75 | } 76 | py::list py_alts = py::cast(alts); 77 | return py_alts; 78 | }, py::arg("n_best"), 79 | py::arg("word_level") = false, 80 | py::arg("bidi_streaming") = false); 81 | 82 | // kaldiserve.DecoderFactory 83 | py::class_(m, "DecoderFactory", "Decoder Factory class.") 84 | .def(py::init()) 85 | .def("produce", &DecoderFactory::produce, py::call_guard(), py::return_value_policy::reference); 86 | 87 | // kaldiserve.DecoderQueue 88 | py::class_(m, "DecoderQueue", "Decoder Queue class.") 89 | .def(py::init()) 90 | .def("acquire", &DecoderQueue::acquire, py::call_guard(), py::return_value_policy::reference) 91 | .def("release", &DecoderQueue::release);//, py::call_guard()); 92 | } 93 | 94 | } // namespace kaldiserve -------------------------------------------------------------------------------- /python/kaldiserve_pybind/kaldiserve_pybind.cpp: -------------------------------------------------------------------------------- 1 | #include "kaldiserve_pybind/kaldiserve_pybind.h" 2 | 3 | 4 | namespace kaldiserve { 5 | 6 | PYBIND11_MODULE(kaldiserve_pybind, m) { 7 | m.doc() = "Python binding of kaldiserve"; 8 | 9 | // types bindings 10 | pybind_types(m); 11 | // model bindings 12 | pybind_model(m); 13 | // decoder bindings 14 | pybind_decoder(m); 15 | // utils bindings 16 | pybind_utils(m); 17 | } 18 | 19 | } // namespace kaldiserve -------------------------------------------------------------------------------- /python/kaldiserve_pybind/kaldiserve_pybind.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | // stl includes 4 | #include 5 | 6 | // pybind includes 7 | #include 8 | #include 9 | 10 | // kaldiserve includes 11 | #include "kaldiserve/types.hpp" 12 | 13 | namespace py = pybind11; 14 | using namespace py::literals; 15 | 16 | 17 | PYBIND11_MAKE_OPAQUE(std::vector); 18 | PYBIND11_MAKE_OPAQUE(std::vector); 19 | PYBIND11_MAKE_OPAQUE(std::vector); 20 | 21 | namespace kaldiserve { 22 | 23 | // types 24 | void pybind_types(py::module &m); 25 | 26 | // model 27 | void pybind_model(py::module &m); 28 | 29 | // decoder 30 | void pybind_decoder(py::module &m); 31 | 32 | // utils 33 | void pybind_utils(py::module &m); 34 | } -------------------------------------------------------------------------------- /python/kaldiserve_pybind/model.cpp: -------------------------------------------------------------------------------- 1 | // kaldiserve_pybind includes 2 | #include "kaldiserve_pybind/kaldiserve_pybind.h" 3 | 4 | // kaldiserve includes 5 | #include "kaldiserve/model.hpp" 6 | #include "kaldiserve/types.hpp" 7 | 8 | 9 | namespace kaldiserve { 10 | 11 | void pybind_model(py::module &m) { 12 | // kaldiserve.ChainModel 13 | py::class_(m, "ChainModel", "Chain model class.") 14 | .def(py::init()); 15 | } 16 | 17 | } // namespace kaldiserve -------------------------------------------------------------------------------- /python/kaldiserve_pybind/types.cpp: -------------------------------------------------------------------------------- 1 | // stl 
includes 2 | #include 3 | #include 4 | 5 | // kaldiserve_pybind includes 6 | #include "kaldiserve_pybind/kaldiserve_pybind.h" 7 | 8 | // kaldiserve includes 9 | #include "kaldiserve/types.hpp" 10 | 11 | 12 | namespace kaldiserve { 13 | 14 | void pybind_types(py::module &m) { 15 | 16 | py::bind_vector>(m, "_ModelSpecList"); 17 | 18 | // kaldiserve.ModelSpec 19 | py::class_(m, "ModelSpec", "Model Specification struct.") 20 | .def(py::init<>()) 21 | .def_readonly("name", &ModelSpec::name) 22 | .def_readonly("language_code", &ModelSpec::language_code) 23 | .def_readonly("path", &ModelSpec::path) 24 | .def_readonly("n_decoders", &ModelSpec::n_decoders) 25 | .def_readonly("min_active", &ModelSpec::min_active) 26 | .def_readonly("max_active", &ModelSpec::max_active) 27 | .def_readonly("frame_subsampling_factor", &ModelSpec::frame_subsampling_factor) 28 | .def_readonly("beam", &ModelSpec::beam) 29 | .def_readonly("lattice_beam", &ModelSpec::lattice_beam) 30 | .def_readonly("acoustic_scale", &ModelSpec::acoustic_scale) 31 | .def_readonly("silence_weight", &ModelSpec::silence_weight) 32 | .def_readonly("max_ngram_order", &ModelSpec::max_ngram_order) 33 | .def_readonly("rnnlm_weight", &ModelSpec::rnnlm_weight) 34 | .def_readonly("bos_index", &ModelSpec::bos_index) 35 | .def_readonly("eos_index", &ModelSpec::eos_index) 36 | .def("__repr__", [](const ModelSpec &ms) { 37 | return ""; 40 | }); 41 | // .def(py::init(), 45 | // py::arg("name"), py::arg("language_code"), py::arg("path"), 46 | // py::arg("n_decoders") = 1, py::arg("min_active") = 200, 47 | // py::arg("max_active") = 7000, py::arg("frame_subsampling_factor") = 3, 48 | // py::arg("beam") = 16.0, py::arg("lattice_beam") = 6.0, py::arg("acoustic_scale") = 1.0, 49 | // py::arg("silence_weight") = 1.0, py::arg("max_ngram_order") = 3, 50 | // py::arg("rnnlm_weight") = 0.5, py::arg("bos_index") = "1", py::arg("eos_index") = "2"); 51 | 52 | py::bind_vector>(m, "_WordList"); 53 | 54 | // kaldiserve.Word 55 | py::class_(m, "Word", "Word struct.") 56 | .def(py::init<>()) 57 | .def_readonly("start_time", &Word::start_time) 58 | .def_readonly("end_time", &Word::end_time) 59 | .def_readonly("confidence", &Word::confidence) 60 | .def_readonly("word", &Word::word) 61 | .def("__repr__", [](const Word &w) { 62 | return ""; 66 | }); 67 | // .def(py::init(), 68 | // py::arg("start_time"), py::arg("end_time"), py::arg("confidence"), py::arg("word")) 69 | 70 | py::bind_vector>(m, "_AlternativeList"); 71 | 72 | // kaldiserve.Alternative 73 | py::class_(m, "Alternative", "Alternative struct.") 74 | .def(py::init<>()) 75 | .def_readonly("transcript", &Alternative::transcript) 76 | .def_readonly("confidence", &Alternative::confidence) 77 | .def_readonly("am_score", &Alternative::am_score) 78 | .def_readonly("lm_score", &Alternative::lm_score) 79 | .def_readonly("words", &Alternative::words) 80 | .def("__repr__", [](const Alternative &alt) { 81 | return ""; 85 | // "', words: '" + std::string(alt.words.begin(), alt.words.end()) + "'}>"; 86 | }); 87 | // .def(py::init>(), 88 | // py::arg("transcript"), py::arg("confidence"), py::arg("am_score"), py::arg("lm_score"), py::arg("words")) 89 | } 90 | 91 | } // namespace kaldiserve -------------------------------------------------------------------------------- /python/kaldiserve_pybind/utils.cpp: -------------------------------------------------------------------------------- 1 | // stl includes 2 | #include 3 | #include 4 | 5 | // pybind includes 6 | #include 7 | 8 | // kaldiserve_pybind includes 9 | #include 
"kaldiserve_pybind/kaldiserve_pybind.h" 10 | 11 | // kaldiserve includes 12 | #include "kaldiserve/utils.hpp" 13 | #include "kaldiserve/types.hpp" 14 | 15 | 16 | namespace kaldiserve { 17 | 18 | void pybind_utils(py::module &m) { 19 | m.def("parse_model_specs", [](const std::string &toml_path) { 20 | std::vector model_specs; 21 | parse_model_specs(toml_path, model_specs); 22 | py::list py_model_specs = py::cast(model_specs); 23 | return py_model_specs; 24 | }); 25 | } 26 | 27 | } // namespace kaldiserve -------------------------------------------------------------------------------- /python/scripts/batch_transcribe.py: -------------------------------------------------------------------------------- 1 | """ 2 | Batch Audio Transcription script using kalidserve. 3 | 4 | Usage: batch_transcribe.py 5 | """ 6 | import time 7 | import threading 8 | 9 | from io import BytesIO 10 | from typing import List, Text 11 | from docopt import docopt 12 | 13 | import kaldiserve as ks 14 | 15 | 16 | def transcribe(decoder: ks.Decoder, wav_stream: bytes) -> List[ks.Alternative]: 17 | with ks.start_decoding(decoder): 18 | # decode the audio 19 | decoder.decode_wav_audio(wav_stream) 20 | # get the transcripts 21 | alts = decoder.get_decoded_results(10, False, False) 22 | return alts 23 | 24 | 25 | def decode_thread(decoder_queue: ks.DecoderQueue, audio_file: Text, n: int): 26 | # read audio bytes 27 | with open(audio_file, "rb") as f: 28 | audio_bytes = BytesIO(f.read()).getvalue() 29 | 30 | start = time.time() 31 | with ks.acquire_decoder(decoder_queue) as decoder: 32 | end = time.time() 33 | print(f"{audio_file}: decoder acquired in {(end - start):.4f}s") 34 | # transcribe audio 35 | start = time.time() 36 | alts = transcribe(decoder, audio_bytes) 37 | end = time.time() 38 | print(f"{audio_file}: decoded audio in {(end - start):.4f}s") 39 | 40 | print(f"{audio_file}: Alternatives\n{alts}") 41 | 42 | 43 | if __name__ == "__main__": 44 | args = docopt(__doc__) 45 | 46 | model_spec_toml = args[""] 47 | audio_paths_file = args[""] 48 | 49 | # parse model spec 50 | model_spec = ks.parse_model_specs(model_spec_toml)[0] 51 | # create decoder queue 52 | decoder_queue = ks.DecoderQueue(model_spec) 53 | 54 | # read audio paths 55 | with open(audio_paths_file, "r", encoding="utf-8") as f: 56 | audio_paths = f.read().split("\n") 57 | 58 | audio_paths = list(filter(lambda x: x.endswith(".wav"), audio_paths)) 59 | 60 | # multithreaded decoding 61 | threads = [ 62 | threading.Thread(target=decode_thread, args=(decoder_queue, audio_path, i + 1,)) 63 | for i, audio_path in enumerate(audio_paths) 64 | ] 65 | 66 | for thread in threads: 67 | thread.start() 68 | 69 | for thread in threads: 70 | thread.join() -------------------------------------------------------------------------------- /python/scripts/requirements.txt: -------------------------------------------------------------------------------- 1 | kaldiserve==1.0.0 2 | docopt -------------------------------------------------------------------------------- /python/scripts/transcribe.py: -------------------------------------------------------------------------------- 1 | """ 2 | Audio Transcription script using kalidserve. 
-------------------------------------------------------------------------------- /python/scripts/requirements.txt: --------------------------------------------------------------------------------
1 | kaldiserve==1.0.0
2 | docopt
-------------------------------------------------------------------------------- /python/scripts/transcribe.py: --------------------------------------------------------------------------------
1 | """
2 | Audio Transcription script using kaldiserve.
3 | 
4 | Usage: transcribe.py <model_spec_toml> <audio_file>
5 | """
6 | import time
7 | 
8 | from io import BytesIO
9 | from typing import List, Text
10 | from docopt import docopt
11 | 
12 | import kaldiserve as ks
13 | 
14 | 
15 | def transcribe(decoder: ks.Decoder, wav_stream: bytes) -> List[ks.Alternative]:
16 |     with ks.start_decoding(decoder):
17 |         # decode the audio
18 |         decoder.decode_wav_audio(wav_stream)
19 |         # get the transcripts
20 |         alts = decoder.get_decoded_results(10)
21 |     return alts
22 | 
23 | 
24 | if __name__ == "__main__":
25 |     args = docopt(__doc__)
26 | 
27 |     model_spec_toml = args["<model_spec_toml>"]
28 |     audio_file = args["<audio_file>"]
29 | 
30 |     # parse model spec
31 |     model_spec = ks.parse_model_specs(model_spec_toml)[0]
32 |     # create chain model
33 |     model = ks.ChainModel(model_spec)
34 |     # create decoder instance
35 |     decoder = ks.Decoder(model)
36 | 
37 |     # read audio bytes
38 |     with open(audio_file, "rb") as f:
39 |         audio_bytes = BytesIO(f.read()).getvalue()
40 | 
41 |     # transcribe audio
42 |     start = time.time()
43 |     alts = transcribe(decoder, audio_bytes)
44 |     end = time.time()
45 |     print(f"{audio_file}: decoded audio in {(end - start):.4f}s")
46 | 
47 |     print(f"{audio_file}: Alternatives\n{alts}")
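48 | 
49 | # Example invocation (hypothetical paths):
50 | #   python transcribe.py resources/model-spec.toml audio.wav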
-------------------------------------------------------------------------------- /python/setup.py: --------------------------------------------------------------------------------
1 | import os
2 | import re
3 | import sys
4 | 
5 | from setuptools import setup
6 | 
7 | 
8 | def find_library(lib_regex, path="kaldiserve", required=True):
9 |     pattern = re.compile(lib_regex)
10 | 
11 |     files = os.listdir(path)
12 |     for f in files:
13 |         if pattern.match(f):
14 |             return f
15 |     if required:
16 |         raise FileNotFoundError(lib_regex)
17 |     return None
18 | 
19 | 
20 | setup(
21 |     name='kaldiserve',
22 |     version='1.0.0',
23 |     author='Vernacular.ai team',
24 |     author_email='hello@vernacular.ai',
25 |     description='A plug-and-play abstraction over Kaldi ASR toolkit.',
26 |     long_description='',
27 |     packages=["kaldiserve"],
28 |     package_dir={"kaldiserve": "kaldiserve"},
29 |     include_package_data=True,
30 |     package_data={
31 |         "kaldiserve": [find_library(r"kaldiserve_pybind.*\.so")]
32 |     },
33 |     classifiers=[
34 |         "Intended Audience :: Developers",
35 |         "Intended Audience :: Science/Research",
36 |         "Operating System :: POSIX :: Linux",
37 |         "Programming Language :: C",
38 |         "Programming Language :: C++",
39 |         "Programming Language :: Python :: 3",
40 |         "Programming Language :: Python :: 3.5",
41 |         "Programming Language :: Python :: 3.6",
42 |         "Programming Language :: Python :: 3.7",
43 |         "Programming Language :: Python :: 3.8",
44 |         "Topic :: Software Development :: Libraries :: Python Modules",
45 |     ],
46 |     keywords="asr speech-recognition kaldi grpc-server",
47 |     license="Apache",
48 |     url="https://github.com/Vernacular-ai/kaldi-serve",
49 |     project_urls={
50 |         'Documentation': 'https://github.com/Vernacular-ai/kaldi-serve',
51 |         'Source code': 'https://github.com/Vernacular-ai/kaldi-serve',
52 |         'Issues': 'https://github.com/Vernacular-ai/kaldi-serve/issues',
53 |     },
54 |     zip_safe=False,
55 | )
-------------------------------------------------------------------------------- /resources/model-spec.toml: --------------------------------------------------------------------------------
1 | # This is a sample model specification file which can be loaded in kaldi-serve.
2 | # Here we specify a list of models with extra properties like model name,
3 | # language code etc.
4 | 
5 | # Compulsory keys are `name', `language_code' (both used to identify a loaded model)
6 | # and `path'.
7 | [[model]]
8 | name = "general"
9 | language_code = "hi"
10 | path = "./path/to/model/dir"
11 | 
12 | [[model]]
13 | name = "general"
14 | language_code = "en"
15 | path = "./path/to/model/dir"
16 | # A few optional decoder-related parameters. Default values are listed in the
17 | # comment at the end of each line. Most of the Viterbi params can be tuned to
18 | # trade off speed vs accuracy.
19 | n_decoders = 20 # 1
20 | beam = 7.0 # 13.0
21 | min_active = 200 # 200
22 | max_active = 3000 # 7000
23 | lattice_beam = 3.0 # 6.0
24 | acoustic_scale = 1.0 # 1.0
25 | frame_subsampling_factor = 3 # 3
26 | silence_weight = 1.0 # 1.0
27 | 
28 | # A model `path` looks something like the following (for a minimal,
29 | # transcription-only use case):
30 | 
31 | # ├── conf
32 | # │   ├── ivector_extractor.conf
33 | # │   ├── mfcc.conf
34 | # │   ├── online_cmvn.conf
35 | # │   └── splice.conf
36 | # ├── final.mdl
37 | # ├── HCLG.fst
38 | # ├── ivector_extractor
39 | # │   ├── final.dubm
40 | # │   ├── final.ie
41 | # │   ├── final.mat
42 | # │   └── global_cmvn.stats
43 | # ├── word_boundary.int (optional; needed only for word level confidence and timing information)
44 | # └── words.txt
45 | 
46 | # The files above have the default kaldi chain model interpretation (with
47 | # ivector also as an input). A few things to note:
48 | # + `final.mdl` contains the neural net and transition model.
49 | # + `HCLG.fst` is the decoding FST.
50 | # + `words.txt` is a symbol table mapping decoder output ids to words.
51 | # + For the feature pipeline, mfcc config is picked from `conf/mfcc.conf`.
52 | # + For ivector, we read `conf/ivector_extractor.conf`, allowing two kinds of
53 | #   paths for params in the ivector config:
54 | #   - Absolute, like /mnt/model/ivector_extractor/final.mat
55 | #   - Relative to the model dir, something like ivector_extractor/final.mat
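56 | 
57 | # Optional RNNLM rescoring keys (read only when the model dir also has an
58 | # `rnnlm/` sub-directory containing `final.raw`, `word_embedding.mat` and
59 | # `G.fst`). The values below are illustrative; defaults in trailing comments:
60 | # max_ngram_order = 3 # 3
61 | # rnnlm_weight = 0.5  # 0.5
62 | # bos_index = "1"     # "1"
63 | # eos_index = "2"     # "2"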
-------------------------------------------------------------------------------- /src/CMakeLists.txt: --------------------------------------------------------------------------------
1 | include_directories(${KALDI_ROOT}/src ${KALDI_ROOT}/tools/openfst/include)
2 | include_directories(${CUDA_TK_ROOT}/include)
3 | include_directories(../include/kaldiserve)
4 | 
5 | file(GLOB_RECURSE ALL_SOURCE_FILES "*.cpp")
6 | 
7 | add_library(kaldiserve SHARED ${ALL_SOURCE_FILES})
8 | 
9 | target_link_options(kaldiserve PUBLIC "-L/usr/local/lib -Wl,--no-as-needed -Wl,--as-needed -ldl '-Wl,-rpath,$$ORIGIN/../lib' -rdynamic")
10 | target_link_directories(kaldiserve
11 |     PUBLIC "${KALDI_ROOT}/src/lib"
12 |     PUBLIC "${KALDI_ROOT}/tools/openfst/lib"
13 | )
14 | target_link_libraries(kaldiserve
15 |     # openfst
16 |     fst
17 |     # kaldi
18 |     kaldi-decoder
19 |     kaldi-lat
20 |     kaldi-fstext
21 |     kaldi-hmm
22 |     kaldi-feat
23 |     kaldi-transform
24 |     kaldi-gmm
25 |     kaldi-tree
26 |     kaldi-util
27 |     kaldi-matrix
28 |     kaldi-cudamatrix
29 |     kaldi-base
30 |     kaldi-nnet3
31 |     kaldi-chain
32 |     kaldi-online2
33 |     kaldi-ivector
34 |     kaldi-lm
35 |     kaldi-rnnlm
36 |     # boost
37 |     boost_filesystem
38 |     -static-libstdc++
39 | )
40 | 
41 | set_target_properties(kaldiserve PROPERTIES LINKER_LANGUAGE CXX)
-------------------------------------------------------------------------------- /src/config.cpp: --------------------------------------------------------------------------------
1 | // config.cpp - Configuration Implementation
2 | 
3 | // stl includes
4 | #include <chrono>
5 | #include <ctime>
6 | #include <iostream>
7 | #include <sstream>
8 | #include <string>
9 | 
10 | // local includes
11 | #include "config.hpp"
12 | 
13 | 
14 | namespace kaldiserve {
15 | 
16 | void print_version() {
17 |     std::cout << VERSION << std::endl;
18 |     exit(EXIT_SUCCESS);
19 | }
20 | 
21 | std::string timestamp_now() {
22 |     auto now = std::chrono::system_clock::now();
23 |     std::time_t now_time = std::chrono::system_clock::to_time_t(now);
24 |     std::string now_str = std::string(std::asctime(std::localtime(&now_time)));
25 |     auto millis = std::chrono::duration_cast<std::chrono::milliseconds>(now.time_since_epoch()).count();
26 |     std::stringstream time_string;
27 |     time_string << now_str.substr(0, now_str.size() - 6) << "." << (millis % 1000);
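28 |     // e.g. std::asctime yields "Mon Jun 15 14:03:05 2020\n"; dropping the last
29 |     // 6 chars (" 2020\n") and appending the (non-zero-padded) millis remainder
30 |     // gives something like "Mon Jun 15 14:03:05.123".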
31 |     return time_string.str();
32 | }
33 | 
34 | } // namespace kaldiserve
-------------------------------------------------------------------------------- /src/decoder/decoder-common.cpp: --------------------------------------------------------------------------------
1 | // decoder-common.cpp - Decoder Common methods Implementation
2 | 
3 | // local includes
4 | #include "config.hpp"
5 | #include "decoder.hpp"
6 | #include "types.hpp"
7 | 
8 | 
9 | namespace kaldiserve {
10 | 
11 | void find_alternatives(kaldi::CompactLattice &clat,
12 |                        const std::size_t &n_best,
13 |                        utterance_results_t &results,
14 |                        const bool &word_level,
15 |                        ChainModel *const model,
16 |                        const DecoderOptions &options) {
17 |     if (clat.NumStates() == 0) {
18 |         KALDI_LOG << "Empty lattice.";
19 |     }
20 | 
21 |     if (options.enable_rnnlm) {
22 |         // rnnlm.fst
23 |         std::unique_ptr<kaldi::rnnlm::KaldiRnnlmDeterministicFst> lm_to_add_orig =
24 |             make_uniq<kaldi::rnnlm::KaldiRnnlmDeterministicFst>(model->model_spec.max_ngram_order, *model->rnnlm_info);
25 |         std::unique_ptr<fst::ScaleDeterministicOnDemandFst> lm_to_add =
26 |             make_uniq<fst::ScaleDeterministicOnDemandFst>(model->rnnlm_weight, lm_to_add_orig.get());
27 | 
28 |         // G.fst
29 |         std::unique_ptr<fst::BackoffDeterministicOnDemandFst<fst::StdArc>> lm_to_subtract_det_backoff =
30 |             make_uniq<fst::BackoffDeterministicOnDemandFst<fst::StdArc>>(*model->lm_to_subtract_fst);
31 |         std::unique_ptr<fst::ScaleDeterministicOnDemandFst> lm_to_subtract_det_scale =
32 |             make_uniq<fst::ScaleDeterministicOnDemandFst>(-model->rnnlm_weight, lm_to_subtract_det_backoff.get());
33 | 
34 |         // combine both LM fsts
35 |         fst::ComposeDeterministicOnDemandFst<fst::StdArc> combined_lms(lm_to_subtract_det_scale.get(), lm_to_add.get());
36 | 
37 |         // Before composing with the LM FST, we scale the lattice weights
38 |         // by the inverse of "lm_scale". We'll later scale by "lm_scale".
39 |         // We do it this way so we can determinize and it will give the
40 |         // right effect (taking the "best path" through the LM) regardless
41 |         // of the sign of lm_scale.
42 |         if (model->decodable_opts.acoustic_scale != 1.0) {
43 |             fst::ScaleLattice(fst::AcousticLatticeScale(model->decodable_opts.acoustic_scale), &clat);
44 |         }
45 |         kaldi::TopSortCompactLatticeIfNeeded(&clat);
46 | 
47 |         // compose lattice with combined language model.
48 |         kaldi::CompactLattice composed_clat;
49 |         kaldi::ComposeCompactLatticePruned(model->compose_opts, clat,
50 |                                            &combined_lms, &composed_clat);
51 | 
52 |         if (composed_clat.NumStates() == 0) {
53 |             // Something went wrong. A warning will already have been printed.
54 |             KALDI_WARN << "Empty lattice after RNNLM rescoring.";
55 |         } else {
56 |             clat = composed_clat;
57 |         }
58 |     }
59 | 
60 |     auto lat = make_uniq<kaldi::Lattice>();
61 |     fst::ConvertLattice(clat, lat.get());
62 | 
63 |     kaldi::Lattice nbest_lat;
64 |     std::vector<kaldi::Lattice> nbest_lats;
65 | 
66 |     fst::ShortestPath(*lat, &nbest_lat, n_best);
67 |     fst::ConvertNbestToVector(nbest_lat, &nbest_lats);
68 | 
69 |     if (nbest_lats.empty()) {
70 |         KALDI_WARN << "no N-best entries";
71 |         return;
72 |     }
73 | 
74 |     for (auto const &l : nbest_lats) {
75 |         // NOTE: Check why int32s specifically are used here
76 |         std::vector<int32> input_ids;
77 |         std::vector<int32> word_ids;
78 |         std::vector<std::string> word_strings;
79 |         std::string sentence;
80 | 
81 |         kaldi::LatticeWeight weight;
82 |         fst::GetLinearSymbolSequence(l, &input_ids, &word_ids, &weight);
83 | 
84 |         for (auto const &wid : word_ids) {
85 |             word_strings.push_back(model->word_syms->Find(wid));
86 |         }
87 |         string_join(word_strings, " ", sentence);
88 | 
89 |         Alternative alt;
90 |         alt.transcript = sentence;
91 |         alt.lm_score = float(weight.Value1());
92 |         alt.am_score = float(weight.Value2());
93 |         alt.confidence = calculate_confidence(alt.lm_score, alt.am_score, word_ids.size());
94 | 
95 |         results.push_back(alt);
96 |     }
97 | 
98 |     if (!(options.enable_word_level && word_level))
99 |         return;
100 | 
101 |     kaldi::CompactLattice aligned_clat;
102 |     kaldi::BaseFloat max_expand = 0.0;
103 |     int32 max_states;
104 | 
105 |     if (max_expand > 0)
106 |         max_states = 1000 + max_expand * clat.NumStates();
107 |     else
108 |         max_states = 0;
109 | 
110 |     bool ok = kaldi::WordAlignLattice(clat, model->trans_model, *model->wb_info, max_states, &aligned_clat);
111 | 
112 |     if (!ok) {
113 |         if (aligned_clat.Start() != fst::kNoStateId) {
114 |             KALDI_WARN << "Outputting partial lattice";
115 |             kaldi::TopSortCompactLatticeIfNeeded(&aligned_clat);
116 |             ok = true;
117 |         } else {
118 |             KALDI_WARN << "Empty aligned lattice, producing no output.";
119 |         }
120 |     } else {
121 |         if (aligned_clat.Start() == fst::kNoStateId) {
122 |             KALDI_WARN << "Lattice was empty";
123 |             ok = false;
124 |         } else {
125 |             kaldi::TopSortCompactLatticeIfNeeded(&aligned_clat);
126 |         }
127 |     }
128 | 
129 |     std::vector<Word> words;
130 | 
131 |     // compute confidences and times only if alignment was ok
132 |     if (ok) {
133 |         kaldi::BaseFloat frame_shift = 0.01;
134 |         kaldi::BaseFloat lm_scale = 1.0;
135 |         kaldi::MinimumBayesRiskOptions mbr_opts;
136 |         mbr_opts.decode_mbr = false;
137 | 
138 |         fst::ScaleLattice(fst::LatticeScale(lm_scale, model->decodable_opts.acoustic_scale), &aligned_clat);
139 |         auto mbr = make_uniq<kaldi::MinimumBayesRisk>(aligned_clat, mbr_opts);
140 | 
141 |         const std::vector<kaldi::BaseFloat> &conf = mbr->GetOneBestConfidences();
142 |         const std::vector<int32> &best_words = mbr->GetOneBest();
143 |         const std::vector<std::pair<kaldi::BaseFloat, kaldi::BaseFloat>> &times = mbr->GetOneBestTimes();
144 | 
145 |         KALDI_ASSERT(conf.size() == best_words.size() && best_words.size() == times.size());
146 | 
147 |         for (size_t i = 0; i < best_words.size(); i++) {
148 |             KALDI_ASSERT(best_words[i] != 0 || mbr_opts.print_silence); // Should not have epsilons.
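149 |             // MBR one-best times are measured in decoder output frames, not
150 |             // seconds; the conversion below uses the 10ms feature frame shift.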
151 | 
152 |             Word word;
153 |             kaldi::BaseFloat time_unit = frame_shift * model->decodable_opts.frame_subsampling_factor;
154 |             word.start_time = times[i].first * time_unit;
155 |             word.end_time = times[i].second * time_unit;
156 |             word.word = model->word_syms->Find(best_words[i]); // lookup word in SymbolTable
157 |             word.confidence = conf[i];
158 | 
159 |             words.push_back(word);
160 |         }
161 |     }
162 | 
163 |     if (!results.empty() and !words.empty()) {
164 |         results[0].words = words;
165 |     }
166 | }
167 | 
168 | } // namespace kaldiserve
-------------------------------------------------------------------------------- /src/decoder/decoder-factory.cpp: --------------------------------------------------------------------------------
1 | // decoder-factory.cpp - Decoder Factory Implementation
2 | 
3 | // local includes
4 | #include "config.hpp"
5 | #include "types.hpp"
6 | #include "model.hpp"
7 | #include "decoder.hpp"
8 | 
9 | 
10 | namespace kaldiserve {
11 | 
12 | DecoderFactory::DecoderFactory(const ModelSpec &model_spec) : model_spec(model_spec) {
13 |     model_ = make_uniq<ChainModel>(model_spec);
14 | }
15 | 
16 | } // namespace kaldiserve
-------------------------------------------------------------------------------- /src/decoder/decoder-queue.cpp: --------------------------------------------------------------------------------
1 | // decoder-queue.cpp - Decoder Queue Implementation
2 | 
3 | // local includes
4 | #include "config.hpp"
5 | #include "decoder.hpp"
6 | #include "types.hpp"
7 | 
8 | 
9 | namespace kaldiserve {
10 | 
11 | DecoderQueue::DecoderQueue(const ModelSpec &model_spec) {
12 |     std::cout << ":: Loading model from " << model_spec.path << ENDL;
13 | 
14 |     decoder_factory_ = make_uniq<DecoderFactory>(model_spec);
15 |     for (size_t i = 0; i < model_spec.n_decoders; i++) {
16 |         queue_.push(decoder_factory_->produce());
17 |     }
18 | }
19 | 
20 | DecoderQueue::~DecoderQueue() {
21 |     while (!queue_.empty()) {
22 |         auto decoder = queue_.front();
23 |         queue_.pop();
24 |         delete decoder;
25 |     }
26 | }
27 | 
28 | void DecoderQueue::push_(Decoder *const item) {
29 |     std::unique_lock<std::mutex> mlock(mutex_);
30 |     queue_.push(item);
31 |     mlock.unlock();
32 |     cond_.notify_one(); // condition var notifies another suspended thread (held up in `pop`)
33 | }
34 | 
35 | Decoder *DecoderQueue::pop_() {
36 |     std::unique_lock<std::mutex> mlock(mutex_);
37 |     // waits until a decoder object is available
38 |     while (queue_.empty()) {
39 |         // suspends current thread execution and awaits condition notification
40 |         cond_.wait(mlock);
41 |     }
42 |     auto item = queue_.front();
43 |     queue_.pop();
44 |     return item;
45 | }
46 | 
47 | } // namespace kaldiserve
-------------------------------------------------------------------------------- /src/decoder/decoder.cpp: --------------------------------------------------------------------------------
1 | // decoder.cpp - CPU Decoder Implementation
2 | 
3 | // local includes
4 | #include "config.hpp"
5 | #include "decoder.hpp"
6 | #include "types.hpp"
7 | 
8 | 
9 | namespace kaldiserve {
10 | 
11 | Decoder::Decoder(ChainModel *const model) : model_(model) {
12 | 
13 |     if (model_->wb_info != nullptr) options.enable_word_level = true;
14 |     if (model_->rnnlm_info != nullptr) options.enable_rnnlm = true;
15 | 
16 |     // decoder vars initialization
17 |     decoder_ = NULL;
18 |     feature_pipeline_ = NULL;
19 |     silence_weighting_ = NULL;
20 |     adaptation_state_ = NULL;
21 | }
22 | 
23 | Decoder::~Decoder() noexcept {
24 |     free_decoder();
25 | }
26 | 
27 | void Decoder::start_decoding(const std::string &uuid) noexcept {
28 |     free_decoder();
29 | 
30 |     adaptation_state_ = new kaldi::OnlineIvectorExtractorAdaptationState(model_->feature_info->ivector_extractor_info);
31 | 
32 |     feature_pipeline_ = new kaldi::OnlineNnet2FeaturePipeline(*model_->feature_info);
33 |     feature_pipeline_->SetAdaptationState(*adaptation_state_);
34 | 
35 |     decoder_ = new kaldi::SingleUtteranceNnet3Decoder(model_->lattice_faster_decoder_config,
36 |                                                       model_->trans_model, *model_->decodable_info,
37 |                                                       *model_->decode_fst, feature_pipeline_);
38 |     decoder_->InitDecoding();
39 | 
40 |     silence_weighting_ = new kaldi::OnlineSilenceWeighting(model_->trans_model,
41 |                                                            model_->feature_info->silence_weighting_config,
42 |                                                            model_->decodable_opts.frame_subsampling_factor);
43 | 
44 |     uuid_ = uuid;
45 | }
46 | 
47 | void Decoder::free_decoder() noexcept {
48 |     if (decoder_) {
49 |         delete decoder_;
50 |         decoder_ = NULL;
51 |     }
52 |     if (adaptation_state_) {
53 |         delete adaptation_state_;
54 |         adaptation_state_ = NULL;
55 |     }
56 |     if (feature_pipeline_) {
57 |         delete feature_pipeline_;
58 |         feature_pipeline_ = NULL;
59 |     }
60 |     if (silence_weighting_) {
61 |         delete silence_weighting_;
62 |         silence_weighting_ = NULL;
63 |     }
64 |     uuid_ = "";
65 | }
66 | 
67 | void Decoder::decode_stream_wav_chunk(std::istream &wav_stream) {
68 |     kaldi::WaveData wave_data;
69 |     wave_data.Read(wav_stream);
70 | 
71 |     const kaldi::BaseFloat samp_freq = wave_data.SampFreq();
72 | 
73 |     // get the data for channel zero (if the signal is not mono, we only
74 |     // take the first channel).
75 |     kaldi::SubVector<kaldi::BaseFloat> wave_part(wave_data.Data(), 0);
76 |     std::vector<std::pair<int32, kaldi::BaseFloat>> delta_weights;
77 |     _decode_wave(wave_part, delta_weights, samp_freq);
78 | }
79 | 
80 | void Decoder::decode_stream_raw_wav_chunk(std::istream &wav_stream,
81 |                                           const float &samp_freq,
82 |                                           const int &data_bytes) {
83 |     kaldi::Matrix<kaldi::BaseFloat> wave_matrix;
84 |     read_raw_wav_stream(wav_stream, data_bytes, wave_matrix);
85 | 
86 |     // get the data for channel zero (if the signal is not mono, we only
87 |     // take the first channel).
88 |     kaldi::SubVector<kaldi::BaseFloat> wave_part(wave_matrix, 0);
89 |     std::vector<std::pair<int32, kaldi::BaseFloat>> delta_weights;
90 | 
91 |     _decode_wave(wave_part, delta_weights, samp_freq);
92 | }
93 | 
94 | void Decoder::decode_wav_audio(std::istream &wav_stream,
95 |                                const float &chunk_size) {
96 |     kaldi::WaveData wave_data;
97 |     wave_data.Read(wav_stream);
98 | 
99 |     // get the data for channel zero (if the signal is not mono, we only
100 |     // take the first channel).
101 |     kaldi::SubVector<kaldi::BaseFloat> data(wave_data.Data(), 0);
102 |     const kaldi::BaseFloat samp_freq = wave_data.SampFreq();
103 | 
104 |     int32 chunk_length;
105 |     if (chunk_size > 0) {
106 |         chunk_length = int32(samp_freq * chunk_size);
107 |         if (chunk_length == 0)
108 |             chunk_length = 1;
109 |     } else {
110 |         chunk_length = std::numeric_limits<int32>::max();
111 |     }
112 | 
113 |     int32 samp_offset = 0;
114 |     std::vector<std::pair<int32, kaldi::BaseFloat>> delta_weights;
115 | 
116 |     while (samp_offset < data.Dim()) {
117 |         int32 samp_remaining = data.Dim() - samp_offset;
118 |         int32 num_samp = chunk_length < samp_remaining ? chunk_length : samp_remaining;
119 | 
120 |         kaldi::SubVector<kaldi::BaseFloat> wave_part(data, samp_offset, num_samp);
121 |         _decode_wave(wave_part, delta_weights, samp_freq);
122 | 
123 |         samp_offset += num_samp;
124 |     }
125 | }
126 | 
127 | void Decoder::decode_raw_wav_audio(std::istream &wav_stream,
128 |                                    const float &samp_freq,
129 |                                    const int &data_bytes,
130 |                                    const float &chunk_size) {
131 |     kaldi::Matrix<kaldi::BaseFloat> wave_matrix;
132 |     read_raw_wav_stream(wav_stream, data_bytes, wave_matrix);
133 | 
134 |     // get the data for channel zero (if the signal is not mono, we only
135 |     // take the first channel).
136 |     kaldi::SubVector<kaldi::BaseFloat> data(wave_matrix, 0);
137 | 
138 |     int32 chunk_length;
139 |     if (chunk_size > 0) {
140 |         chunk_length = int32(samp_freq * chunk_size);
141 |         if (chunk_length == 0)
142 |             chunk_length = 1;
143 |     } else {
144 |         chunk_length = std::numeric_limits<int32>::max();
145 |     }
146 | 
147 |     int32 samp_offset = 0;
148 |     std::vector<std::pair<int32, kaldi::BaseFloat>> delta_weights;
149 | 
150 |     while (samp_offset < data.Dim()) {
151 |         int32 samp_remaining = data.Dim() - samp_offset;
152 |         int32 num_samp = chunk_length < samp_remaining ? chunk_length : samp_remaining;
153 | 
154 |         kaldi::SubVector<kaldi::BaseFloat> wave_part(data, samp_offset, num_samp);
155 |         _decode_wave(wave_part, delta_weights, samp_freq);
156 | 
157 |         samp_offset += num_samp;
158 |     }
159 | }
160 | 
161 | void Decoder::get_decoded_results(const int &n_best,
162 |                                   utterance_results_t &results,
163 |                                   const bool &word_level,
164 |                                   const bool &bidi_streaming) {
165 |     if (!bidi_streaming) {
166 |         feature_pipeline_->InputFinished();
167 |         decoder_->AdvanceDecoding();
168 |         decoder_->FinalizeDecoding();
169 |     }
170 | 
171 |     if (decoder_->NumFramesDecoded() == 0) {
172 |         KALDI_WARN << "audio may be empty :: decoded no frames";
173 |         return;
174 |     }
175 | 
176 |     kaldi::CompactLattice clat;
177 |     try {
178 |         decoder_->GetLattice(true, &clat);
179 |         find_alternatives(clat, n_best, results, word_level, model_, options);
180 |     } catch (const std::exception &e) {
181 |         KALDI_ERR << "unexpected error during lattice decoding :: " << e.what();
182 |     }
183 | }
184 | 
185 | void Decoder::_decode_wave(kaldi::SubVector<kaldi::BaseFloat> &wave_part,
186 |                            std::vector<std::pair<int32, kaldi::BaseFloat>> &delta_weights,
187 |                            const kaldi::BaseFloat &samp_freq) {
188 |     feature_pipeline_->AcceptWaveform(samp_freq, wave_part);
189 | 
190 |     if (silence_weighting_->Active() && feature_pipeline_->IvectorFeature() != NULL) {
191 |         silence_weighting_->ComputeCurrentTraceback(decoder_->Decoder());
192 |         silence_weighting_->GetDeltaWeights(feature_pipeline_->NumFramesReady(),
193 |                                             &delta_weights);
194 |         feature_pipeline_->IvectorFeature()->UpdateFrameWeights(delta_weights);
195 |     }
196 | 
197 |     decoder_->AdvanceDecoding();
198 | }
199 | 
200 | } // namespace kaldiserve
-------------------------------------------------------------------------------- /src/model/model-chain.cpp: --------------------------------------------------------------------------------
1 | // model-chain.cpp - Chain Model Implementation
2 | 
3 | // stl includes
4 | #include <memory>
5 | #include <string>
6 | 
7 | // local includes
8 | #include "model.hpp"
9 | #include "utils.hpp"
10 | #include "types.hpp"
11 | 
12 | 
13 | namespace kaldiserve {
14 | 
15 | ChainModel::ChainModel(const ModelSpec &model_spec) : model_spec(model_spec) {
16 |     std::string model_dir = model_spec.path;
17 | 
18 |     try {
19 |         std::string hclg_filepath = join_path(model_dir, "HCLG.fst");
20 |         std::string model_filepath = join_path(model_dir, "final.mdl");
21 |         std::string word_syms_filepath = join_path(model_dir, "words.txt");
22 |         std::string word_boundary_filepath = join_path(model_dir, "word_boundary.int");
23 | 
24 |         std::string conf_dir = join_path(model_dir, "conf");
25 |         std::string mfcc_conf_filepath = join_path(conf_dir, "mfcc.conf");
26 |         std::string ivector_conf_filepath = join_path(conf_dir, "ivector_extractor.conf");
27 | 
28 |         std::string rnnlm_dir = join_path(model_dir, "rnnlm");
29 | 
30 |         decode_fst = std::unique_ptr<fst::Fst<fst::StdArc>>(fst::ReadFstKaldiGeneric(hclg_filepath));
31 | 
32 |         {
33 |             bool binary;
34 |             kaldi::Input ki(model_filepath, &binary);
35 | 
36 |             trans_model.Read(ki.Stream(), binary);
37 |             am_nnet.Read(ki.Stream(), binary);
38 | 
39 |             kaldi::nnet3::SetBatchnormTestMode(true, &(am_nnet.GetNnet()));
40 |             kaldi::nnet3::SetDropoutTestMode(true, &(am_nnet.GetNnet()));
41 |             kaldi::nnet3::CollapseModel(kaldi::nnet3::CollapseModelConfig(), &(am_nnet.GetNnet()));
42 |         }
43 | 
44 |         if (word_syms_filepath != "" && !(word_syms = std::unique_ptr<fst::SymbolTable>(fst::SymbolTable::ReadText(word_syms_filepath)))) {
45 |             KALDI_ERR << "Could not read symbol table from file " << word_syms_filepath;
46 |         }
47 | 
48 |         if (exists(word_boundary_filepath)) {
49 |             kaldi::WordBoundaryInfoNewOpts word_boundary_opts;
50 |             wb_info = make_uniq<kaldi::WordBoundaryInfo>(word_boundary_opts, word_boundary_filepath);
51 |         } else {
52 |             KALDI_WARN << "Word boundary file " << word_boundary_filepath
53 |                        << " not found. Disabling word level features.";
54 |         }
55 | 
56 |         if (exists(rnnlm_dir) &&
57 |             exists(join_path(rnnlm_dir, "final.raw")) &&
58 |             exists(join_path(rnnlm_dir, "word_embedding.mat")) &&
59 |             exists(join_path(rnnlm_dir, "G.fst"))) {
60 | 
61 |             rnnlm_opts.bos_index = std::stoi(model_spec.bos_index);
62 |             rnnlm_opts.eos_index = std::stoi(model_spec.eos_index);
63 | 
64 |             lm_to_subtract_fst =
65 |                 std::unique_ptr<fst::VectorFst<fst::StdArc>>(fst::ReadAndPrepareLmFst(join_path(rnnlm_dir, "G.fst")));
66 |             rnnlm_weight = model_spec.rnnlm_weight;
67 | 
68 |             kaldi::ReadKaldiObject(join_path(rnnlm_dir, "final.raw"), &rnnlm);
69 |             KALDI_ASSERT(IsSimpleNnet(rnnlm));
70 |             kaldi::ReadKaldiObject(join_path(rnnlm_dir, "word_embedding.mat"), &word_embedding_mat);
71 | 
72 |             std::cout << "# Word Embeddings (RNNLM): " << word_embedding_mat.NumRows() << ENDL;
73 | 
74 |             rnnlm_info =
75 |                 make_uniq<kaldi::rnnlm::RnnlmComputeStateInfo>(rnnlm_opts, rnnlm, word_embedding_mat);
76 |         } else {
77 |             KALDI_WARN << "RNNLM artefacts not found. Disabling RNNLM rescoring feature.";
78 |         }
79 | 
80 |         feature_info = make_uniq<kaldi::OnlineNnet2FeaturePipelineInfo>();
81 |         feature_info->feature_type = "mfcc";
82 |         kaldi::ReadConfigFromFile(mfcc_conf_filepath, &(feature_info->mfcc_opts));
83 | 
84 |         feature_info->use_ivectors = true;
85 |         kaldi::OnlineIvectorExtractionConfig ivector_extraction_opts;
86 |         kaldi::ReadConfigFromFile(ivector_conf_filepath, &ivector_extraction_opts);
87 | 
88 |         // Expand paths if provided relative. We use model_dir as the base in
89 |         // such cases.
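90 |         // e.g. a relative "ivector_extractor/final.mat" in the config is read as
91 |         // "<model_dir>/ivector_extractor/final.mat"; absolute paths pass through unchanged.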
92 |         ivector_extraction_opts.lda_mat_rxfilename = expand_relative_path(ivector_extraction_opts.lda_mat_rxfilename, model_dir);
93 |         ivector_extraction_opts.global_cmvn_stats_rxfilename = expand_relative_path(ivector_extraction_opts.global_cmvn_stats_rxfilename, model_dir);
94 |         ivector_extraction_opts.diag_ubm_rxfilename = expand_relative_path(ivector_extraction_opts.diag_ubm_rxfilename, model_dir);
95 |         ivector_extraction_opts.ivector_extractor_rxfilename = expand_relative_path(ivector_extraction_opts.ivector_extractor_rxfilename, model_dir);
96 |         ivector_extraction_opts.cmvn_config_rxfilename = expand_relative_path(ivector_extraction_opts.cmvn_config_rxfilename, model_dir);
97 |         ivector_extraction_opts.splice_config_rxfilename = expand_relative_path(ivector_extraction_opts.splice_config_rxfilename, model_dir);
98 | 
99 |         feature_info->ivector_extractor_info.Init(ivector_extraction_opts);
100 |         feature_info->silence_weighting_config.silence_weight = model_spec.silence_weight;
101 | 
102 |         lattice_faster_decoder_config.min_active = model_spec.min_active;
103 |         lattice_faster_decoder_config.max_active = model_spec.max_active;
104 |         lattice_faster_decoder_config.beam = model_spec.beam;
105 |         lattice_faster_decoder_config.lattice_beam = model_spec.lattice_beam;
106 | 
107 |         decodable_opts.acoustic_scale = model_spec.acoustic_scale;
108 |         decodable_opts.frame_subsampling_factor = model_spec.frame_subsampling_factor;
109 |         decodable_info = make_uniq<kaldi::nnet3::DecodableNnetSimpleLoopedInfo>(decodable_opts, &am_nnet);
110 | 
111 |     } catch (const std::exception &e) {
112 |         KALDI_ERR << e.what();
113 |     }
114 | }
115 | 
116 | } // namespace kaldiserve
-------------------------------------------------------------------------------- /src/utils/utils-io.cpp: --------------------------------------------------------------------------------
1 | // utils-io.cpp - I/O Utilities Implementation
2 | 
3 | // lib includes
4 | #include <boost/filesystem.hpp>
5 | 
6 | // vendor includes
7 | #include "vendor/cpptoml.h"
8 | 
9 | // local includes
10 | #include "utils.hpp"
11 | #include "types.hpp"
12 | 
13 | 
14 | namespace kaldiserve {
15 | 
16 | std::string expand_relative_path(std::string path, std::string root_path) {
17 |     boost::filesystem::path fs_path(path);
18 |     if (fs_path.is_absolute()) {
19 |         return path;
20 |     } else {
21 |         boost::filesystem::path fs_root_path(root_path);
22 |         return (fs_root_path / fs_path).string();
23 |     }
24 | }
25 | 
26 | std::string join_path(std::string a, std::string b) {
27 |     boost::filesystem::path fs_a(a);
28 |     boost::filesystem::path fs_b(b);
29 |     return (fs_a / fs_b).string();
30 | }
31 | 
32 | bool exists(std::string path) {
33 |     boost::filesystem::path fs_path(path);
34 |     return boost::filesystem::exists(fs_path);
35 | }
36 | 
37 | void parse_model_specs(const std::string &toml_path, std::vector<ModelSpec> &model_specs) {
38 |     auto config = cpptoml::parse_file(toml_path);
39 |     auto models = config->get_table_array("model");
40 | 
41 |     ModelSpec spec;
42 |     for (const auto &model : *models) {
43 |         auto maybe_path = model->get_as<std::string>("path");
44 |         auto maybe_name = model->get_as<std::string>("name");
45 |         auto maybe_language_code = model->get_as<std::string>("language_code");
46 |         auto maybe_n_decoders = model->get_as<int>("n_decoders");
47 | 
48 |         auto maybe_min_active = model->get_as<int>("min_active");
49 |         auto maybe_max_active = model->get_as<int>("max_active");
50 |         auto maybe_frame_subsampling_factor = model->get_as<int>("frame_subsampling_factor");
51 |         auto maybe_beam = model->get_as<double>("beam");
52 |         auto maybe_lattice_beam = model->get_as<double>("lattice_beam");
53 |         auto maybe_acoustic_scale = model->get_as<double>("acoustic_scale");
model->get_as("acoustic_scale"); 54 | auto maybe_silence_weight = model->get_as("silence_weight"); 55 | auto maybe_max_ngram_order = model->get_as("max_ngram_order"); 56 | auto maybe_rnnlm_weight = model->get_as("rnnlm_weight"); 57 | auto maybe_bos_index = model->get_as("bos_index"); 58 | auto maybe_eos_index = model->get_as("eos_index"); 59 | 60 | // TODO: Throw error in case of invalid toml 61 | spec.path = *maybe_path; 62 | spec.name = *maybe_name; 63 | spec.language_code = *maybe_language_code; 64 | 65 | if (maybe_n_decoders) spec.n_decoders = *maybe_n_decoders; 66 | if (maybe_beam) spec.beam = *maybe_beam; 67 | if (maybe_min_active) spec.min_active = *maybe_min_active; 68 | if (maybe_max_active) spec.max_active = *maybe_max_active; 69 | if (maybe_lattice_beam) spec.lattice_beam = *maybe_lattice_beam; 70 | if (maybe_acoustic_scale) spec.acoustic_scale = *maybe_acoustic_scale; 71 | if (maybe_frame_subsampling_factor) spec.frame_subsampling_factor = *maybe_frame_subsampling_factor; 72 | if (maybe_silence_weight) spec.silence_weight = *maybe_silence_weight; 73 | if (maybe_max_ngram_order) spec.max_ngram_order = *maybe_max_ngram_order; 74 | if (maybe_rnnlm_weight) spec.rnnlm_weight = *maybe_rnnlm_weight; 75 | if (maybe_bos_index) spec.bos_index = *maybe_bos_index; 76 | if (maybe_eos_index) spec.eos_index = *maybe_eos_index; 77 | 78 | model_specs.push_back(spec); 79 | } 80 | } 81 | 82 | void string_join(const std::vector &strings, std::string separator, std::string &output) { 83 | output.clear(); 84 | 85 | for (auto i = 0; i < strings.size(); i++) { 86 | output += strings[i]; 87 | 88 | if (i != strings.size() - 1) { 89 | output += separator; 90 | } 91 | } 92 | } 93 | 94 | } // namespace kaldiserve --------------------------------------------------------------------------------