├── .github └── workflows │ └── build_and_publish_to_pypi.yml ├── .gitignore ├── LICENSE ├── README.md ├── openspeechtointent ├── __init__.py ├── forced_alignment.cpp ├── model.py ├── resources │ └── models │ │ ├── citrinet_spectrogram_filterbank.npy │ │ ├── citrinet_tokenizer.pkl │ │ └── citrinet_vocab.json └── utils.py ├── pyproject.toml ├── setup.py └── test ├── intents └── test_intents.json └── resources └── sample_1.wav /.github/workflows/build_and_publish_to_pypi.yml: -------------------------------------------------------------------------------- 1 | name: Publish Python distributions to PyPI 2 | 3 | on: 4 | workflow_dispatch: 5 | create: 6 | tags: 7 | - "*" 8 | 9 | jobs: 10 | build-n-publish: 11 | name: Build and publish Python distributions to PyPI 12 | runs-on: ubuntu-latest 13 | steps: 14 | - uses: actions/checkout@master 15 | - name: Set up Python 3.8 16 | uses: actions/setup-python@v3 17 | with: 18 | python-version: "3.8" 19 | - name: Install pypa/build 20 | run: >- 21 | python -m 22 | pip install 23 | build 24 | --user 25 | - name: Build a source tarball only 26 | run: >- 27 | python -m build --sdist 28 | - name: Publish distribution to PyPI 29 | if: startsWith(github.ref, 'refs/tags') 30 | uses: pypa/gh-action-pypi-publish@release/v1 31 | with: 32 | password: ${{ secrets.PYPI_API_TOKEN }} -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | .aider* 131 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # openSpeechToIntent 2 | 3 | openSpeechToIntent is a library that maps audio containing speech to pre-specified lists of short texts. It can be used to directly map speech to these texts, which can represent categories or intents for various voice automation applications. 
4 | 
5 | openSpeechToIntent works by using small but robust speech-to-text models (the default is [Citrinet 256 by NVIDIA](https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/stt_en_citrinet_256)) to generate predictions for characters or tokens, and then using [forced alignment](https://research.nvidia.com/labs/conv-ai/blogs/2023/2023-08-forced-alignment/) to determine if the audio matches different text transcriptions. The score of the alignment is used to determine which (if any) of the transcriptions is a good match to the audio. 
6 | 
7 | This approach has several useful features: 
8 | - No training is required, and any text can be matched to any audio 
9 | - It has strong performance while still being relatively fast on a wide range of hardware (see the [Performance](#performance) section for more details), as the underlying models are small and efficient 
10 | - It can be easily combined with other libraries (e.g., [openWakeWord](https://github.com/dscripka/openwakeword)) and tools to create more complex audio processing pipelines 
11 | 
12 | ## Installation 
13 | 
14 | To install openSpeechToIntent, you can simply use pip: 
15 | 
16 | ```bash 
17 | pip install openSpeechToIntent 
18 | ``` 
19 | 
20 | This should work on nearly all operating systems (Windows, macOS, Linux), as there are only four requirements: `numpy`, `onnxruntime`, `pybind11`, and `sentencepiece`. 
21 | 
22 | ## Usage 
23 | 
24 | openSpeechToIntent is designed to be simple to use. Just provide a file/array of audio data and a list of target intents, and the library will return information about potential intent matches. 
25 | 
26 | ```python 
27 | 
28 | from openspeechtointent.model import CitrinetModel 
29 | 
30 | # Load model (this will also download the model if it is not already present) 
31 | mdl = CitrinetModel() 
32 | 
33 | # Define some simple intents 
34 | intents = ["turn on the lights", "pause the music", "set a 5 minute timer", "remind me to buy apples tomorrow"] 
35 | 
36 | # Load a sample audio file (from the test/resources directory in this repo) 
37 | # Can also directly provide a numpy array of 16-bit PCM audio data 
38 | audio_file = "test/resources/sample_1.wav"  # contains the speech "turn on the lights" 
39 | 
40 | # Match the audio to the provided intents 
41 | matched_intents, scores, durations = mdl.match_intents(audio_file, intents) 
42 | 
43 | # View the results 
44 | for intent, score, duration in zip(matched_intents, scores, durations): 
45 |     print(f"Intent: {intent}, Score: {score}, Duration: {duration}") 
46 | 
47 | # Output: 
48 | # Intent: "turn on the lights", Score: 0.578, Duration: 0.800 seconds 
49 | # Intent: "pause the music", Score: 0.270, Duration: 0.560 seconds 
50 | # Intent: "set a 5 minute timer", Score: 0.119, Duration: 0.560 seconds 
51 | # Intent: "remind me to buy apples tomorrow", Score: 0.032, Duration: 1.840 seconds 
52 | ``` 
53 | 
54 | Scores are computed by applying a softmax to the logits from the Citrinet model, scaling them between 0 and 1. These scores can be used to select possible matching intents (or to reject all of them) with appropriate thresholds. 
55 | 
56 | The durations are the approximate length of each intent as aligned to the audio. They provide another way to filter and select possible intent matches, by keeping only those with a plausible duration. 
57 | 
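For example, both signals can be combined with simple thresholds (a minimal sketch continuing from the example above; the score threshold and duration bounds are illustrative assumptions, not tuned values):

```python
# Hypothetical thresholds -- tune for your audio, intent list, and environment
SCORE_THRESHOLD = 0.5
MIN_DURATION, MAX_DURATION = 0.3, 3.0  # seconds

accepted = [
    (intent, score, duration)
    for intent, score, duration in zip(matched_intents, scores, durations)
    if score >= SCORE_THRESHOLD and MIN_DURATION <= duration <= MAX_DURATION
]
print(accepted if accepted else "No intent matched")
```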
58 | ## Performance 
59 | 
60 | For many use-cases, the performance of openSpeechToIntent can be surprisingly good. This is a testament both to the high quality of Nvidia's pre-trained models and to the way that constraining the speech-to-text decoding to a fixed set of intents greatly reduces the search space of the problem. While real-world performance numbers always depend on the deployment environment, here are some example use-cases that illustrate the type of performance possible with openSpeechToIntent: 
61 | 
62 | ### Ignoring false wake word activations 
63 | 
64 | Wake word detection frameworks like [openWakeWord](https://github.com/dscripka/openWakeWord), [microWakeWord](https://github.com/kahrendt/microWakeWord), etc. are designed to efficiently listen for target activation words while continuously processing input audio. The challenge with these types of systems is maintaining high recall of the target activation words, while not activating on other, unrelated audio. In practice, this is a difficult balance that requires careful training and tuning, and performance can vary widely depending on the environment and the specific wake word. 
65 | 
66 | One approach to improving the effective performance of these types of systems is to tune the wake word model to be very sensitive, and then filter out any false activations through other means. openSpeechToIntent can be used in this way, assuming that there is a known list of intents that would normally be expected after a wake word activation. As an example, the table below shows the performance of the pre-trained `Alexa` wake word model from openWakeWord on the 24-hour [Picovoice wake word benchmark dataset](https://github.com/Picovoice/wake-word-benchmark), where the model's threshold is set very low (0.1) to ensure very high recall but low precision. With this configuration, the false positive rate on the Picovoice dataset is ~3.58 false activations per hour, which is unacceptably high. However, by using openSpeechToIntent to verify that the speech after the activation matches a list of ~400 expected intents (see the list [here](test/intents/test_intents.json)), the false positive rate can be reduced to <0.04 false activations per hour. A sketch of this two-stage approach is shown after the table below. 
67 | 
68 | | openWakeWord Model | openWakeWord Score Threshold | openSpeechToIntent Score Threshold | False Positives per Hour | 
69 | |---------------------|------------------------------|------------------------------------|--------------------------| 
70 | | Alexa | 0.1 | NA | ~3.58 | 
71 | | Alexa | 0.1 | 0.1 | <0.04 | 
72 | 
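As a rough illustration of the pipeline (hedged: the openWakeWord calls follow that library's documented usage, the `capture_audio` helper is hypothetical, and the thresholds mirror the table above but should be tuned for your deployment):

```python
import numpy as np
from openwakeword.model import Model
from openspeechtointent.model import CitrinetModel

oww = Model(wakeword_models=["alexa"])  # deliberately sensitive wake word detector
osti = CitrinetModel()
intents = ["turn on the lights", "pause the music", "set a 5 minute timer"]  # expected commands

def verify_activation(command_audio: np.ndarray) -> bool:
    """Accept a wake word activation only if the following speech matches a known intent."""
    # Raw logit scores are often easier to threshold globally than softmax scores
    # (see "Using raw logit scores" below)
    _, scores, _ = osti.match_intents(command_audio, intents, softmax_scores=False)
    return scores.max() >= 0.1

# Streaming loop (pseudocode): for each 80 ms frame (1280 samples at 16 kHz),
# check the wake word score, then verify the speech that follows:
#
#     if max(oww.predict(frame).values()) > 0.1:
#         command_audio = capture_audio(seconds=3)  # hypothetical helper
#         if verify_activation(command_audio):
#             ...  # handle the recognized intent
```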
73 | 
74 | ### Precisely matching intents 
75 | 
76 | openSpeechToIntent can also be used to perform more fine-grained classification across a reasonable number of intents. As an example, on the [Fluent Speech Commands](https://fluent.ai/fluent-speech-commands-a-dataset-for-spoken-language-understanding-research/) test set, an average accuracy of ~98% across 31 intents is possible. 
77 | 
78 | | Model | Accuracy | 
79 | |-------|----------| 
80 | | openSpeechToIntent | 98.2% | 
81 | | [SOTA](https://paperswithcode.com/paper/finstreder-simple-and-fast-spoken-language) | 99.7% | 
82 | 
83 | 
84 | ### Efficiency 
85 | 
86 | openSpeechToIntent is designed to be reasonably efficient, and can run on a wide range of hardware, including normal desktop CPUs and moderately powerful SBCs. The table below shows the performance of the default Nvidia Citrinet model on several different systems, using a 3 second audio clip as input and matching against 400 intents. 
87 | 
88 | | CPU | Number of Threads | Time to Process 3s Audio Clip (ms) | Time to Match Against 400 Intents (ms) | 
89 | |-----|-------------------|-------------------------------|---------------------------| 
90 | | Intel Xeon W-2123 | 1 | 103 | 21 | 
91 | | AMD Ryzen 1600 | 1 | 98 | 17 | 
92 | | Raspberry Pi 4 | 1 | 320 | 56 | 
93 | | Raspberry Pi 4 | 2 | 262 | 56 | 
94 | 
95 | Note that further optimizations are possible, and in general the length of the audio clip and the number of intents will have the largest impact on the efficiency of the system. By using heuristics to limit the number of intents to match and precisely controlling the length of the audio clip, performance can be further improved. 
96 | 
97 | ## Advanced Usage 
98 | 
99 | ### Using raw logit scores 
100 | 
101 | In some cases, instead of applying the softmax transform to the scores, it may be useful to use the raw logit scores directly. For example, if you are primarily using openSpeechToIntent to filter out false wake word activations from another system, the raw logit scores can make it easier to set a global threshold that works well across all intents. 
102 | 
103 | As an example, suppose you have a set of 5 intents with raw logit scores of `[-6, -7, -8, -9, -10]`. In absolute terms, these scores are quite low, and none of the intents have good alignments with the audio. However, applying softmax to these scores gives `[0.6364, 0.2341, 0.0861, 0.0317, 0.0117]`, which falsely implies that the first intent is a good match. By carefully setting a threshold on the raw logit scores, tuned against your specific use-case and deployment environment, you can often achieve better performance compared to using the softmax scores. 
104 | 
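The softmax values above are easy to verify directly (a small standalone check, not part of the library's API):

```python
import numpy as np

raw_scores = np.array([-6.0, -7.0, -8.0, -9.0, -10.0])
softmax_scores = np.exp(raw_scores) / np.sum(np.exp(raw_scores))
print(np.round(softmax_scores, 4))  # [0.6364 0.2341 0.0861 0.0317 0.0117]

# A threshold on raw_scores would reject all five intents here, while the
# softmax scores alone would suggest the first intent is a confident match.
```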
105 | ## Limitations 
106 | 
107 | Currently, the library only supports matching English speech to English intents. Future work may involve expanding to other languages and supporting other speech-to-text frameworks like [whisper.cpp](https://github.com/ggerganov/whisper.cpp). 
108 | 
109 | ## Testing 
110 | 
111 | To run the tests for openSpeechToIntent, clone the repo and install with the test requirements: 
112 | 
113 | ```bash 
114 | git clone https://github.com/dscripka/openSpeechToIntent.git 
115 | cd openSpeechToIntent 
116 | pip install -e ./[test] 
117 | ``` 
118 | 
119 | Then run the tests with pytest: 
120 | 
121 | ```bash 
122 | pytest 
123 | ``` 
124 | 
125 | ## Acknowledgements 
126 | 
127 | Many thanks to Nvidia for the excellent Citrinet speech-to-text models, as well as many other highly performant speech and audio models. 
128 | 
129 | Also, credit to @MahmoudAshraf97 for the excellent modification of the [torch forced alignment cpp functions](https://github.com/MahmoudAshraf97/ctc-forced-aligner/blob/main/ctc_forced_aligner/forced_align_impl.cpp) to simplify dependencies and enable easy usage with `pybind`. 
130 | 
131 | ## License 
132 | 
133 | The code in this project is licensed under the Apache 2.0 License - see the [LICENSE](LICENSE) file for details. Portions of the code adapted in whole or in part from other repositories are licensed under their respective licenses, as appropriate. 
134 | 
135 | The Nvidia Citrinet model is licensed under the [CC-BY-4.0 license](https://creativecommons.org/licenses/by/4.0/) and the NGC [Terms of Use](https://ngc.nvidia.com/legal/terms). 
--------------------------------------------------------------------------------
/openspeechtointent/__init__.py: 
--------------------------------------------------------------------------------
1 | # Copyright 2024 David Scripka. All rights reserved. 
2 | # 
3 | # Licensed under the Apache License, Version 2.0 (the "License"); 
4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 
6 | # 
7 | #     http://www.apache.org/licenses/LICENSE-2.0 
8 | # 
9 | # Unless required by applicable law or agreed to in writing, software 
10 | # distributed under the License is distributed on an "AS IS" BASIS, 
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 
13 | # limitations under the License. 
14 | 
15 | 
16 | import os 
17 | 
18 | MODELS = { 
19 |     "stft": { 
20 |         "model_path": os.path.join(os.path.dirname(os.path.abspath(__file__)), "resources/models/torchlibrosa_stft.onnx"), 
21 |         "download_url": "https://github.com/dscripka/openSpeechtoIntent/releases/download/v0.1.0.alpha/torchlibrosa_stft.onnx" 
22 |     }, 
23 |     "citrinet_256": { 
24 |         "model_path": os.path.join(os.path.dirname(os.path.abspath(__file__)), "resources/models/stt_en_citrinet_256.onnx"), 
25 |         "download_url": "https://github.com/dscripka/openSpeechtoIntent/releases/download/v0.1.0.alpha/stt_en_citrinet_256.onnx" 
26 |     } 
27 | } 
--------------------------------------------------------------------------------
/openspeechtointent/forced_alignment.cpp: 
--------------------------------------------------------------------------------
1 | #include <pybind11/pybind11.h> 
2 | #include <pybind11/numpy.h> 
3 | #include <pybind11/stl.h> // For handling STL containers 
4 | #include <algorithm> 
5 | #include <limits> 
6 | #include <vector> 
7 | 
8 | namespace py = pybind11; 
9 | 
10 | template <typename scalar_t> 
11 | void forced_align_impl( 
12 |     const py::array_t<scalar_t>& logProbs, 
13 |     const py::array_t<int64_t>& targets, 
14 |     const int64_t blank, 
15 |     py::array_t<int64_t>& paths) { 
16 |   const scalar_t kNegInfinity = -std::numeric_limits<scalar_t>::infinity(); 
17 |   const auto batchIndex = 0; // TODO: support batch version and use the real batch index 
18 |   const auto T = logProbs.shape(1); 
19 |   const auto L = targets.shape(1); 
20 |   const auto S = 2 * L + 1; 
21 |   std::vector<scalar_t> alphas(2 * S, kNegInfinity); 
22 |   std::vector<bool> backPtrBit0((S + 1) * (T - L), false); 
23 |   std::vector<bool> backPtrBit1((S + 1) * (T - L), false); 
24 |   std::vector<unsigned long long> backPtr_offset(T - 1); 
25 |   std::vector<unsigned long long> backPtr_seek(T - 1); 
26 |   auto logProbs_data = logProbs.template unchecked<3>(); 
27 |   auto targets_data = targets.template unchecked<2>(); 
28 |   auto paths_data = paths.template mutable_unchecked<2>(); 
29 |   auto R = 0; 
30 | 
31 |   for (auto i = 1; i < L; i++) { 
32 |     if (targets_data(batchIndex, i) == targets_data(batchIndex, i - 1)) { 
33 |       ++R; 
34 |     } 
35 |   } 
36 |   if (T < L + R) { 
37 |     throw std::runtime_error("targets length is too long for CTC."); 
38 |   } 
39 |   auto start = T - (L + R) > 0 ? 0 : 1; 
40 |   auto end = (S == 1) ? 1 : 2; 
41 |   for (auto i = start; i < end; i++) { 
42 |     auto labelIdx = (i % 2 == 0) ?
blank : targets_data(batchIndex, i / 2); 43 | alphas[i] = logProbs_data(batchIndex, 0, labelIdx); 44 | } 45 | unsigned long long seek = 0; 46 | for (auto t = 1; t < T; t++) { 47 | if (T - t <= L + R) { 48 | if ((start % 2 == 1) && 49 | targets_data(batchIndex, start / 2) != targets_data(batchIndex, start / 2 + 1)) { 50 | start = start + 1; 51 | } 52 | start = start + 1; 53 | } 54 | if (t <= L + R) { 55 | if (end % 2 == 0 && end < 2 * L && 56 | targets_data(batchIndex, end / 2 - 1) != targets_data(batchIndex, end / 2)) { 57 | end = end + 1; 58 | } 59 | end = end + 1; 60 | } 61 | auto startloop = start; 62 | auto curIdxOffset = t % 2; 63 | auto prevIdxOffset = (t - 1) % 2; 64 | std::fill(alphas.begin() + curIdxOffset * S, alphas.begin() + (curIdxOffset + 1) * S, kNegInfinity); 65 | backPtr_seek[t - 1] = seek; 66 | backPtr_offset[t - 1] = start; 67 | if (start == 0) { 68 | alphas[curIdxOffset * S] = alphas[prevIdxOffset * S] + logProbs_data(batchIndex, t, blank); 69 | startloop += 1; 70 | seek += 1; 71 | } 72 | for (auto i = startloop; i < end; i++) { 73 | auto x0 = alphas[prevIdxOffset * S + i]; 74 | auto x1 = alphas[prevIdxOffset * S + i - 1]; 75 | auto x2 = kNegInfinity; 76 | auto labelIdx = (i % 2 == 0) ? blank : targets_data(batchIndex, i / 2); 77 | if (i % 2 != 0 && i != 1 && 78 | targets_data(batchIndex, i / 2) != targets_data(batchIndex, i / 2 - 1)) { 79 | x2 = alphas[prevIdxOffset * S + i - 2]; 80 | } 81 | scalar_t result = 0.0; 82 | if (x2 > x1 && x2 > x0) { 83 | result = x2; 84 | backPtrBit1[seek + i - startloop] = true; 85 | } else if (x1 > x0 && x1 > x2) { 86 | result = x1; 87 | backPtrBit0[seek + i - startloop] = true; 88 | } else { 89 | result = x0; 90 | } 91 | alphas[curIdxOffset * S + i] = result + logProbs_data(batchIndex, t, labelIdx); 92 | } 93 | seek += (end - startloop); 94 | } 95 | auto idx1 = (T - 1) % 2; 96 | auto ltrIdx = alphas[idx1 * S + S - 1] > alphas[idx1 * S + S - 2] ? S - 1 : S - 2; 97 | for (auto t = T - 1; t > -1; t--) { 98 | auto lbl_idx = ltrIdx % 2 == 0 ? blank : targets_data(batchIndex, ltrIdx / 2); 99 | paths_data(batchIndex, t) = lbl_idx; 100 | auto t_minus_one = t - 1 >= 0 ? 
t - 1 : 0; 
101 |     auto backPtr_idx = backPtr_seek[t_minus_one] + 
102 |         ltrIdx - backPtr_offset[t_minus_one]; 
103 |     ltrIdx -= (backPtrBit1[backPtr_idx] << 1) | backPtrBit0[backPtr_idx]; 
104 |   } 
105 | } 
106 | 
107 | std::tuple<py::array_t<int64_t>, py::array_t<float>> compute( 
108 |     const py::array_t<float>& logProbs, 
109 |     const py::array_t<int64_t>& targets, 
110 |     const int64_t blank) { 
111 |   if (logProbs.ndim() != 3) throw std::runtime_error("log_probs must be a 3-D array."); 
112 |   if (targets.ndim() != 2) throw std::runtime_error("targets must be a 2-D array."); 
113 |   if (logProbs.shape(0) != 1) throw std::runtime_error("Batch size must be 1."); 
114 |   const auto B = logProbs.shape(0); 
115 |   const auto T = logProbs.shape(1); 
116 |   auto paths = py::array_t<int64_t>({B, T}); 
117 |   forced_align_impl(logProbs, targets, blank, paths); 
118 |   auto aligned_paths = paths.unchecked<2>(); 
119 |   auto scores = py::array_t<float>({T}); 
120 |   auto scores_data = scores.mutable_data(); 
121 |   auto logProbs_data = logProbs.unchecked<3>(); 
122 |   for (auto t = 0; t < T; ++t) { 
123 |     scores_data[t] = logProbs_data(0, t, aligned_paths(0, t)); 
124 |   } 
125 |   return std::make_tuple(paths, scores); 
126 | } 
127 | 
128 | std::vector<std::tuple<py::array_t<int64_t>, py::array_t<float>>> compute_all( 
129 |     const py::array_t<float>& logProbs, 
130 |     const std::vector<py::array_t<int64_t>>& targets_list, 
131 |     const int64_t blank) { 
132 |   std::vector<std::tuple<py::array_t<int64_t>, py::array_t<float>>> results; 
133 |   for (const auto& targets : targets_list) { 
134 |     auto result = compute(logProbs, targets, blank); 
135 |     results.push_back(result); 
136 |   } 
137 |   return results; 
138 | } 
139 | 
140 | PYBIND11_MODULE(forced_alignment, m) { 
141 |   m.def("forced_align_single_sequence", &compute, "Compute forced alignment for a single target list."); 
142 |   m.def("forced_align_multiple_sequence", &compute_all, "Compute forced alignment for a list of target lists."); 
143 | }
--------------------------------------------------------------------------------
/openspeechtointent/model.py: 
--------------------------------------------------------------------------------
1 | # Copyright 2024 David Scripka. All rights reserved. 
2 | # 
3 | # Licensed under the Apache License, Version 2.0 (the "License"); 
4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 
6 | # 
7 | #     http://www.apache.org/licenses/LICENSE-2.0 
8 | # 
9 | # Unless required by applicable law or agreed to in writing, software 
10 | # distributed under the License is distributed on an "AS IS" BASIS, 
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 
13 | # limitations under the License.
14 | # 
15 | ################## 
16 | # Several functions and methods in this file were adapted, in whole or in part, from several other libraries, 
17 | # including forced alignment related functions from torchaudio 
18 | # (https://pytorch.org/audio/main/tutorials/ctc_forced_alignment_api_tutorial.html) 
19 | # and the excellent pure cpp implementation of the torch forced alignment code by 
20 | # @MahmoudAshraf97 (https://github.com/MahmoudAshraf97/ctc-forced-aligner/blob/main/ctc_forced_aligner/forced_align_impl.cpp) 
21 | 
22 | import os 
23 | import numpy as np 
24 | from typing import List, NamedTuple, Tuple, Union 
25 | import json 
26 | import pickle 
27 | import onnxruntime as ort 
28 | import wave 
29 | import difflib 
30 | from openspeechtointent.forced_alignment import forced_align_multiple_sequence 
31 | from openspeechtointent.utils import download_file 
32 | from openspeechtointent import MODELS 
33 | 
34 | 
35 | class TokenSpan(NamedTuple): 
36 |     """ 
37 |     A basic class to represent a token span with a score. 
38 |     """ 
39 |     token: str 
40 |     start: float 
41 |     end: float 
42 |     score: float 
43 | 
44 | 
45 | class CitrinetModel: 
46 |     def __init__(self, 
47 |                  model_path: str = os.path.join(os.path.dirname(os.path.abspath(__file__)), "resources/models/stt_en_citrinet_256.onnx"), 
48 |                  ncpu: int = 1 
49 |                  ): 
50 |         """ 
51 |         Initialize the Citrinet model, including pre-processing functions. 
52 |         Model is obtained from https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/stt_en_citrinet_256 
53 |         and then converted to the ONNX format using the standard Nvidia NeMo tools. 
54 | 
55 |         Args: 
56 |             model_path (str): Path to the Citrinet model 
57 |             ncpu (int): Number of threads to use for inference of the Citrinet model 
58 |         """ 
59 |         # Download models from the GitHub release if they don't already exist 
60 |         for model in MODELS.keys(): 
61 |             if not os.path.exists(MODELS[model]["model_path"]): 
62 |                 download_file(MODELS[model]["download_url"], os.path.dirname(MODELS[model]["model_path"])) 
63 | 
64 |         # limit to specified number of threads 
65 |         sess_options = ort.SessionOptions() 
66 |         sess_options.intra_op_num_threads = ncpu 
67 |         sess_options.inter_op_num_threads = ncpu 
68 | 
69 |         # Load ASR model 
70 |         self.asr_model = ort.InferenceSession(model_path, sess_options=sess_options) 
71 | 
72 |         # Load stft model 
73 |         location = os.path.dirname(os.path.abspath(__file__)) 
74 |         self.stft = ort.InferenceSession(os.path.join(location, "resources/models/torchlibrosa_stft.onnx"), sess_options=sess_options) 
75 | 
76 |         # Load filterbank 
77 |         filterbank_path = os.path.join(location, "resources/models/citrinet_spectrogram_filterbank.npy") 
78 |         self.filterbank = np.load(filterbank_path) 
79 | 
80 |         # Load tokenizer and vocab 
81 |         tokenizer_path = os.path.join(location, "resources/models/citrinet_tokenizer.pkl") 
82 |         self.tokenizer = pickle.load(open(tokenizer_path, "rb")) 
83 |         vocab_path = os.path.join(location, "resources/models/citrinet_vocab.json") 
84 |         self.vocab = json.load(open(vocab_path, 'r')) 
85 | 
86 |         # Initialize similarity matrix attribute 
87 |         self.similarity_matrix = None 
88 | 
89 |     def build_intent_similarity_matrix(self, intents: List[str]) -> np.ndarray: 
90 |         """Builds a similarity matrix between intents using the longest common substring algorithm.
91 | 
92 |         Args: 
93 |             intents (List[str]): List of intents 
94 | 
95 |         Returns: 
96 |             np.ndarray: Similarity matrix 
97 |         """ 
98 |         n = len(intents) 
99 |         matrix = np.ones((n, n), dtype=float) 
100 | 
101 |         for i in range(n): 
102 |             for j in range(i+1, n): 
103 |                 # Calculate similarity using the longest common substring 
104 |                 similarity = difflib.SequenceMatcher(None, intents[i], intents[j]).find_longest_match( 
105 |                     0, len(intents[i]), 0, len(intents[j]) 
106 |                 ).size/len(intents[i]) 
107 | 
108 |                 # Fill the matrix symmetrically 
109 |                 matrix[i][j] = similarity 
110 |                 matrix[j][i] = similarity 
111 | 
112 |         return matrix 
113 | 
114 |     def rerank_intents(self, 
115 |                        logits: np.ndarray, 
116 |                        intents: List[str], 
117 |                        scores: List[float], 
118 |                        method: str = "longer_match", 
119 |                        partial_match_penalty: float = 0.1 
120 |                        ) -> Tuple[List[str], List[float]]: 
121 |         """Rerank intents using various heuristics, which can improve accuracy in some cases. 
122 | 
123 |         Args: 
124 |             logits (np.ndarray): Logits from the ASR model 
125 |             intents (List[str]): List of intents 
126 |             scores (List[float]): List of scores for the intents 
127 |             method (str): Method to use for reranking. Options are "longer_match" and "partial_match". 
128 |                 "partial_match" will rerank intents by penalizing intents that are fully contained within other intents. 
129 |                 "longer_match" will rerank intents by preferring longer intents over shorter ones when they have similar scores. 
130 |             partial_match_penalty (float): Score penalty for intents that are fully contained within other intents when using the 
131 |                 "partial_match" method. 
132 | 
133 |         Returns: 
134 |             tuple: Reranked intents and scores (if applicable) 
135 |         """ 
136 | 
137 |         if method == "longer_match": 
138 |             # Reranking intents that are similar in score by the length of the intent 
139 |             # This prefers longer matches over shorter ones when there are several very similar options 
140 | 
141 |             buckets = [] 
142 |             for i, j in zip(intents, scores): 
143 |                 if buckets == [] or abs(j - buckets[-1][0][1]) >= 0.10: 
144 |                     buckets.append([(i, j)]) 
145 |                 else: 
146 |                     buckets[-1].append((i, j)) 
147 | 
148 |             # Sort buckets by length of intents 
149 |             reranked_buckets = [sorted(i, key=lambda x: len(x[0]), reverse=True) if len(i) > 1 else i for i in buckets] 
150 |             reranked_intents = [j[0] for i in reranked_buckets for j in i] 
151 | 
152 |             return reranked_intents, [] 
153 | 
154 |         if method == "partial_match": 
155 |             # See if any intents are completely contained within other longer intents, and if so prefer longer intents 
156 |             # by penalizing the score of the shorter intents, but only if the score of the unique portion in the 
157 |             # longer intents is above a threshold (that is, is likely also present in the logits) 
158 |             new_scores = [i for i in scores] 
159 |             for ndx, intent in enumerate(intents): 
160 |                 if any([intent in j for j in intents if intent != j]): 
161 |                     # Get the unique sequence from the longer intents 
162 |                     unique_sequences = [j.replace(intent, "").strip() for j in intents if intent != j and intent in j] 
163 |                     unique_scores = self.get_forced_alignment_score(logits, unique_sequences + [intent], softmax_scores=True)[1] 
164 |                     contained_intent_score = unique_scores[-1] 
165 |                     if any([abs(i - contained_intent_score) < 0.1*contained_intent_score for i in unique_scores]): 
166 |                         # Penalize the score of the content contained within the other intents 
167 |                         new_scores[ndx] -= partial_match_penalty 
168 | 
169 |             # Reorder the intents by the updated scores 
170 |             reranked_intents = [intents[i] for i in np.argsort(new_scores)[::-1]]
171 |             reranked_scores = np.sort(new_scores)[::-1].tolist() 
172 | 
173 |             return reranked_intents, reranked_scores 
174 | 
175 |     def match_intents_by_similarity(self, 
176 |                                     logits: np.ndarray, 
177 |                                     s: np.ndarray, 
178 |                                     intents: List[str], 
179 |                                     sim_threshold: float = 0.6, 
180 |                                     topk: int = 5, 
181 |                                     **kwargs 
182 |                                     ): 
183 |         """ 
184 |         Searches the similarity matrix for intents that are similar, 
185 |         and have a score above the threshold. Can reduce the number of calls to the forced alignment models by 30-50% in most cases, 
186 |         which reduces total latency. 
187 | 
188 |         Args: 
189 |             logits (np.ndarray): Logits from the ASR model used for forced alignment 
190 |             s (np.ndarray): Similarity matrix for the intents 
191 |             intents (List[str]): List of intents to search 
192 |             sim_threshold (float): Similarity threshold to group intents. Lower values will group more intents, which increases efficiency 
193 |                 at the cost of recall. Scores approaching 1 are essentially the same as exhaustive search. 
194 |             topk (int): Number of top intents to return 
195 |             kwargs: Additional keyword arguments to pass to the `get_forced_alignment_score` function 
196 | 
197 |         Returns: 
198 |             tuple: intents, scores, durations (respectively) that meet the thresholds 
199 |         """ 
200 |         # Sort the rows by sum of similarities 
201 |         sums = np.sum(s, axis=1) 
202 |         sorted_row_indices = np.argsort(sums) 
203 | 
204 |         # Get the score of the intents 
205 |         top_intents = [] 
206 |         top_scores = [] 
207 |         top_durations = [] 
208 |         excluded_indices = [] 
209 |         for ndx in sorted_row_indices: 
210 |             if ndx in excluded_indices: 
211 |                 continue 
212 | 
213 |             # Get score of the intent 
214 |             _, score, duration = self.get_forced_alignment_score(logits, [intents[ndx]], softmax_scores=False) 
215 | 
216 |             top_intents.append(intents[ndx]) 
217 |             top_scores.append(score[0]) 
218 |             top_durations.append(duration[0]) 
219 | 
220 |             # Exclude indices by similarity 
221 |             intent_ndcs = np.where(s[ndx, :] > sim_threshold)[0] 
222 |             excluded_indices.extend(intent_ndcs) 
223 | 
224 |         # Get the topk results 
225 |         topk_score_ndcs = np.array(top_scores).argsort()[::-1][0:topk] 
226 |         topk_intents = np.array(top_intents)[topk_score_ndcs] 
227 |         topk_scores = np.array(top_scores)[topk_score_ndcs] 
228 |         topk_durations = np.array(top_durations)[topk_score_ndcs] 
229 | 
230 |         # Get topk scores and apply softmax to scores 
231 |         if kwargs.get("softmax_scores") is True: 
232 |             topk_scores = np.round(np.exp(topk_scores)/np.sum(np.exp(topk_scores)), 4) 
233 | 
234 |         return topk_intents, topk_scores, topk_durations 
235 | 
236 |     def get_seq_len(self, seq_len: np.ndarray) -> np.ndarray: 
237 |         """ 
238 |         Get the sequence length for the given input length. 
239 |         Note! This has hard-coded values for the default Citrinet 256 model from Nvidia. 
240 | 
241 |         Args: 
242 |             seq_len (np.ndarray): Input sequence length 
243 | 
244 |         Returns: 
245 |             np.ndarray: Sequence length for the model 
246 |         """ 
247 |         pad_amount = 512 // 2 * 2 
248 |         seq_len = np.floor_divide((seq_len + pad_amount - 512), 160) + 1 
249 |         return seq_len.astype(np.int64) 
250 | 
251 |     def normalize_batch(self, x: np.ndarray, seq_len: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: 
252 |         """ 
253 |         Normalize the input batch of features (per-feature, over the valid timesteps of each sequence). 
254 | 
255 |         Args: 
256 |             x (np.ndarray): Input features 
257 |             seq_len (np.ndarray): Sequence length
259 | 
260 |         Returns: 
261 |             tuple: Normalized features, mean, and standard deviation 
262 |         """ 
263 |         x_mean = None 
264 |         x_std = None 
265 |         batch_size, _, max_time = x.shape 
266 | 
267 |         time_steps = np.tile(np.arange(max_time)[np.newaxis, :], (batch_size, 1)) 
268 |         valid_mask = time_steps < seq_len[:, np.newaxis] 
269 | 
270 |         x_mean_numerator = np.where(valid_mask[:, np.newaxis, :], x, 0.0).sum(axis=2) 
271 |         x_mean_denominator = valid_mask.sum(axis=1) 
272 |         x_mean = x_mean_numerator / x_mean_denominator[:, np.newaxis] 
273 | 
274 |         # Subtract 1 in the denominator to correct for the bias. 
275 |         x_std = np.sqrt( 
276 |             np.sum(np.where(valid_mask[:, np.newaxis, :], x - x_mean[:, :, np.newaxis], 0.0) ** 2, axis=2) 
277 |             / (x_mean_denominator[:, np.newaxis] - 1.0) 
278 |         ) 
279 |         # make sure x_std is not zero 
280 |         x_std += 1e-5 
281 |         return (x - x_mean[:, :, np.newaxis]) / x_std[:, :, np.newaxis], x_mean, x_std 
282 | 
283 |     def get_features(self, x: np.ndarray, length: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: 
284 |         """ 
285 |         Get the melspectrogram audio features for the raw input audio. 
286 | 
287 |         Args: 
288 |             x (np.ndarray): Input audio 
289 |             length (np.ndarray): Length of the audio 
290 | 
291 |         Returns: 
292 |             tuple: Features and sequence length (both as np.ndarrays) 
293 | 
294 |         """ 
295 |         # get sequence length 
296 |         seq_len = self.get_seq_len(length) 
297 | 
298 |         # do preemphasis 
299 |         preemph = 0.97 
300 |         x = np.concatenate((x[:, 0:1], x[:, 1:] - preemph * x[:, :-1]), axis=1) 
301 | 
302 |         # do stft 
303 |         x = np.vstack(self.stft.run( 
304 |             [self.stft.get_outputs()[0].name, self.stft.get_outputs()[1].name], 
305 |             {self.stft.get_inputs()[0].name: x}) 
306 |         ) 
307 | 
308 |         # convert to magnitude 
309 |         guard = 0 
310 |         x = np.sqrt((x**2).sum(axis=0) + guard).T.squeeze() 
311 | 
312 |         # get power spectrum 
313 |         x = x**2 
314 | 
315 |         # dot with filterbank energies 
316 |         x = np.matmul(self.filterbank, x) 
317 | 
318 |         # log features if required 
319 |         x = np.log(x + 5.960464477539063e-08) 
320 | 
321 |         # normalize if required 
322 |         x, _, _ = self.normalize_batch(x, seq_len) 
323 | 
324 |         # mask to zero any values beyond seq_len in batch, pad to multiple of `pad_to` (for efficiency) 
325 |         max_len = x.shape[-1] 
326 |         mask = np.arange(max_len).reshape(1, -1) >= seq_len.reshape(-1, 1) 
327 |         x = np.where(mask[:, np.newaxis, :], 0, x) 
328 | 
329 |         pad_to = 16 
330 |         pad_amt = x.shape[-1] % pad_to 
331 |         if pad_amt != 0: 
332 |             x = np.pad(x, ((0, 0), (0, 0), (0, pad_to - pad_amt)), mode='constant', constant_values=0) 
333 | 
334 |         return x, seq_len 
335 | 
336 |     def merge_tokens(self, tokens: np.ndarray, scores: np.ndarray, blank: int = 0, sr=16000) -> List[TokenSpan]: 
337 |         """Removes repeated tokens and blank tokens from the given CTC token sequence. 
338 | 
339 |         Args: 
340 |             tokens (np.ndarray): Alignment tokens (unbatched) returned from forced_align. 
341 |                 Shape: (time,). 
342 |             scores (np.ndarray): Alignment scores (unbatched) returned from forced_align. 
343 |                 Shape: (time,). When computing the token-level score, the given score is averaged 
344 |                 across the corresponding time span.
345 | 
346 |         Returns: 
347 |             list of TokenSpan objects 
348 |         """ 
349 |         if tokens.ndim != 1 or scores.ndim != 1: 
350 |             raise ValueError("`tokens` and `scores` must be 1D numpy arrays.") 
351 |         if len(tokens) != len(scores): 
352 |             raise ValueError("`tokens` and `scores` must be the same length.") 
353 | 
354 |         diff = np.diff(np.concatenate(([-1], tokens, [-1]))) 
355 |         changes_wo_blank = np.nonzero(diff != 0)[0] 
356 |         tokens = tokens.tolist() 
357 | 
358 |         # Get the spans, and calculate times (adjusting for padding) and scores 
359 |         spans = [ 
360 |             TokenSpan(token=self.vocab[token], start=max(0, (start*1280-4300)/sr), end=max(0, (end*1280-4300)/sr), score=np.mean(scores[start:end])) 
361 |             for start, end in zip(changes_wo_blank[:-1], changes_wo_blank[1:]) 
362 |             if (token := tokens[start]) != blank 
363 |         ] 
364 |         return spans 
365 | 
366 |     def get_detailed_alignment(self, logits: np.ndarray, texts: List[str]) -> List[TokenSpan]: 
367 |         """Get detailed alignment of the texts to the logits, with all timing information. 
368 | 
369 |         Args: 
370 |             logits (np.ndarray): Logits from the ASR model 
371 |             texts (List[str]): List of texts to align 
372 | 
373 |         Returns: 
374 |             list of TokenSpan objects 
375 |         """ 
376 |         # Get tokens for texts 
377 |         new_ids = self.tokenizer.encode(texts) 
378 | 
379 |         # filter out sequences with no tokens 
380 |         texts = [text for text, i in zip(texts, new_ids) if i != []] 
381 |         new_ids = [i for i in new_ids if i != []] 
382 | 
383 |         # Ensure that tokens are not longer than the time steps in the logits, otherwise truncate 
384 |         new_ids = [i if len(i) < logits.shape[0] else i[:logits.shape[0]-1] for i in new_ids] 
385 | 
386 |         # Convert token sequences to numpy arrays with the right shape for forced alignment 
387 |         new_ids = [np.array(i)[None, ] for i in new_ids] 
388 | 
389 |         # Get forced alignments for all sequences 
390 |         alignment = forced_align_multiple_sequence( 
391 |             logits[None, ], 
392 |             new_ids, 
393 |             len(self.vocab)-1 
394 |         ) 
395 | 
396 |         # Get token labels for the sequence 
397 |         t_labels = alignment[0][0].flatten() 
398 |         t_scores = alignment[0][1].flatten() 
399 | 
400 |         # Merge the tokens 
401 |         spans = self.merge_tokens(t_labels, t_scores) 
402 | 
403 |         return spans 
404 | 
405 |     def get_forced_alignment_score(self, 
406 |                                    logits: np.ndarray, 
407 |                                    texts: List[str], 
408 |                                    topk: int = 5, 
409 |                                    softmax_scores: bool = True, 
410 |                                    sr: int = 16000 
411 |                                    ) -> Tuple[List[float], List[float]]: 
412 |         """ 
413 |         Get the forced alignment score for the given logits and text. Scores are optionally softmaxed so that the 
414 |         scores across the topk texts sum to 1.
415 | 
416 |         Args: 
417 |             logits (np.ndarray): Logits from the ASR model 
418 |             texts (List[str]): List of texts to align 
419 |             topk (int): Number of highest-scoring texts to return 
420 |             softmax_scores (bool): If True, will apply softmax to the scores 
421 |             sr (int): Sample rate of the audio 
422 | 
423 |         Returns: 
424 |             tuple: List of texts, scores, and durations for the best alignment of each text to the logits 
425 |         """ 
426 |         # Get tokens for texts 
427 |         new_ids = self.tokenizer.encode(texts) 
428 | 
429 |         # filter out sequences with no tokens 
430 |         texts = [text for text, i in zip(texts, new_ids) if i != []] 
431 |         new_ids = [i for i in new_ids if i != []] 
432 | 
433 |         # Ensure that tokens are not longer than the time steps in the logits, otherwise truncate 
434 |         new_ids = [i if len(i) < logits.shape[0] else i[:logits.shape[0]-1] for i in new_ids] 
435 | 
436 |         # Convert token sequences to numpy arrays with the right shape for forced alignment 
437 |         new_ids = [np.array(i)[None, ] for i in new_ids] 
438 | 
439 |         # Get forced alignments for all sequences 
440 |         alignments = forced_align_multiple_sequence( 
441 |             logits[None, ], 
442 |             new_ids, 
443 |             len(self.vocab)-1 
444 |         ) 
445 | 
446 |         # Get forced alignments 
447 |         scores, durations = [], [] 
448 |         for alignment in alignments: 
449 |             # Get token labels for the sequence 
450 |             t_labels = alignment[0].flatten() 
451 | 
452 |             # Get the average score of the unmerged sequence of tokens (empirically works better than mean after merging) 
453 |             score = round(alignment[1].mean(), 3) 
454 | 
455 |             # Get the duration of the aligned tokens (don't merge CTC labels as this is slow 
456 |             # and we only need the total duration) 
457 |             non_space_tokens = [ndx for ndx, i in enumerate(t_labels) if i != 1024] 
458 |             start = non_space_tokens[0]*1280/sr 
459 |             end = (non_space_tokens[-1] + 1)*1280/sr 
460 |             duration = round(end - start, 3) 
461 | 
462 |             durations.append(duration) 
463 |             scores.append(score) 
464 | 
465 |         # Get topk texts 
466 |         sorted_scores_ndcs = np.array(scores).argsort()[::-1][0:topk] 
467 |         topk_texts = np.array(texts)[sorted_scores_ndcs] 
468 |         topk_scores = np.array(scores)[sorted_scores_ndcs] 
469 |         durations = np.array(durations)[sorted_scores_ndcs] 
470 | 
471 |         # Get topk scores and apply softmax to scores 
472 |         if softmax_scores is True: 
473 |             topk_scores_sm = np.round(np.exp(topk_scores)/np.sum(np.exp(topk_scores)), 4) 
474 |             return topk_texts, topk_scores_sm, durations 
475 | 
476 |         return topk_texts, topk_scores, durations 
477 | 
478 |     def get_audio_features(self, audio: Union[str, np.ndarray]) -> Tuple[np.ndarray, np.ndarray]: 
479 |         """ 
480 |         Get the audio features for the given audio file or numpy array. Must contain 16-bit, 16 kHz PCM audio. 
481 | 
482 |         Args: 
483 |             audio (Union[str, np.ndarray]): Audio file or numpy array of audio 
484 | 
485 |         Returns: 
486 |             tuple: Features and sequence length (both as np.ndarrays) 
487 |         """ 
488 |         if isinstance(audio, str): 
489 |             with wave.open(audio, 'rb') as wav_file: 
490 |                 n_frames = wav_file.getnframes() 
491 |                 wav_dat = np.frombuffer(wav_file.readframes(n_frames), dtype=np.int16) 
492 |         else: 
493 |             wav_dat = audio 
494 | 
495 |         # Check the data type 
496 |         if wav_dat.dtype != np.int16 or np.abs(wav_dat.max()) < 1.1: 
497 |             raise ValueError("Audio data must be 16-bit PCM!") 
498 | 
499 |         # Convert to float32 from 16-bit PCM 
500 |         wav_dat = (wav_dat.astype(np.float32) / 32767) 
501 |         # Note: forced alignment scores seem sensitive to this pad value, and segfaults can occur if it is too small
502 |         wav_dat = np.pad(wav_dat, (4300, 4300), mode='constant') 
503 |         all_features, lengths = self.get_features(wav_dat[None, ], np.array([wav_dat.shape[0]])) 
504 | 
505 |         return all_features, lengths 
506 | 
507 |     def get_logits(self, audio: Union[str, np.ndarray]) -> np.ndarray: 
508 |         """ 
509 |         Get the logits for the given audio file or numpy array using the Citrinet model. 
510 | 
511 |         Args: 
512 |             audio (Union[str, np.ndarray]): Audio file or numpy array of audio 
513 | 
514 |         Returns: 
515 |             np.ndarray: Logits from the ASR model 
516 |         """ 
517 |         # Preprocess audio 
518 |         all_features, lengths = self.get_audio_features(audio) 
519 | 
520 |         # Transcribe processed audio with the onnx model 
521 |         logits = self.asr_model.run(None, {self.asr_model.get_inputs()[0].name: all_features.astype(np.float32), "length": lengths}) 
522 |         logits = logits[0][0] 
523 | 
524 |         return logits 
525 | 
526 |     def match_intents(self, 
527 |                       audio: Union[str, np.ndarray], 
528 |                       intents: List[str] = [], 
529 |                       topk: int = 5, 
530 |                       approximate: bool = False, 
531 |                       softmax_scores: bool = True, 
532 |                       ) -> Tuple[List[str], List[float], List[float]]: 
533 |         """ 
534 |         Match the intents for the given audio file or numpy array. 
535 | 
536 |         Args: 
537 |             audio (Union[str, np.ndarray]): Audio file or numpy array of audio 
538 |             intents (List[str]): List of intents to search 
539 |             topk (int): Top k intents to return, by score 
540 |             approximate (bool): If True, will use approximate intent similarities to more efficiently search for matching intents 
541 |             softmax_scores (bool): If True, will apply softmax to the scores. This will make the scores in the topk intents sum to 1. 
542 |                 If False, the scores will be the raw logit values for the forced aligned sequence. 
543 | 
544 |         Returns: 
545 |             tuple: List of intents, scores, and durations 
546 |         """ 
547 |         # Get the logits 
548 |         logits = self.get_logits(audio) 
549 | 
550 |         # Get the best matching intents 
551 |         if approximate is True and intents != []: 
552 |             # Build intent similarity matrix and cache for reuse 
553 |             if self.similarity_matrix is None: 
554 |                 self.similarity_matrix = self.build_intent_similarity_matrix(intents) 
555 | 
556 |             top_intents, scores, durations = self.match_intents_by_similarity( 
557 |                 logits, 
558 |                 self.similarity_matrix, intents, topk=topk, softmax_scores=softmax_scores 
559 |             ) 
560 | 
561 |         elif approximate is False and intents != []: 
562 |             top_intents, scores, durations = self.get_forced_alignment_score(logits, intents, topk=topk, softmax_scores=softmax_scores) 
563 | 
564 |         else: 
565 |             # At least one intent is required 
566 |             raise ValueError("At least one intent must be provided") 
567 | 
568 |         return top_intents, scores, durations 
569 | 
--------------------------------------------------------------------------------
/openspeechtointent/resources/models/citrinet_spectrogram_filterbank.npy: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/dscripka/openSpeechToIntent/664ad58e7c7f8cd0babf4612720101966446eb1c/openspeechtointent/resources/models/citrinet_spectrogram_filterbank.npy 
--------------------------------------------------------------------------------
/openspeechtointent/resources/models/citrinet_tokenizer.pkl: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/dscripka/openSpeechToIntent/664ad58e7c7f8cd0babf4612720101966446eb1c/openspeechtointent/resources/models/citrinet_tokenizer.pkl 
--------------------------------------------------------------------------------
/openspeechtointent/resources/models/citrinet_vocab.json:
-------------------------------------------------------------------------------- 1 | ["", "s", "\u2581the", "t", "\u2581a", "\u2581i", "'", "\u2581and", "\u2581to", "ed", "d", "\u2581of", "e", "\u2581in", "ing", ".", "\u2581it", "\u2581you", "n", "\u2581that", "m", "y", "er", "\u2581he", "re", "r", "\u2581was", "\u2581is", "\u2581for", "\u2581know", "a", "p", "c", ",", "\u2581be", "o", "\u2581but", "\u2581they", "g", "\u2581so", "ly", "b", "\u2581s", "\u2581yeah", "\u2581we", "\u2581have", "\u2581re", "\u2581like", "l", "\u2581on", "ll", "u", "\u2581with", "\u2581do", "al", "\u2581not", "\u2581are", "or", "ar", "le", "\u2581this", "\u2581as", "es", "\u2581c", "\u2581de", "f", "in", "i", "ve", "\u2581uh", "ent", "\u2581or", "\u2581what", "\u2581me", "\u2581t", "\u2581at", "\u2581my", "\u2581his", "\u2581there", "w", "\u2581all", "\u2581just", "h", "\u2581can", "ri", "il", "k", "ic", "\u2581e", "\u2581", "\u2581um", "\u2581don", "\u2581b", "\u2581had", "ch", "ation", "en", "th", "\u2581no", "\u2581she", "it", "\u2581one", "\u2581think", "\u2581st", "\u2581if", "\u2581from", "ter", "\u2581an", "an", "ur", "\u2581out", "on", "\u2581go", "ck", "\u2581would", "\u2581were", "\u2581w", "\u2581will", "\u2581about", "\u2581right", "ment", "\u2581her", "te", "ion", "\u2581well", "\u2581by", "ce", "\u2581g", "\u2581oh", "\u2581up", "ro", "ra", "\u2581when", "\u2581some", "\u2581also", "\u2581their", "ers", "ow", "\u2581more", "\u2581time", "ate", "\u2581has", "\u2581people", "\u2581see", "\u2581pa", "el", "\u2581get", "\u2581ex", "\u2581mean", "li", "\u2581really", "v", "\u2581ra", "\u2581been", "\u2581said", "-", "la", "ge", "\u2581how", "\u2581po", "ir", "\u2581mo", "\u2581who", "\u2581because", "\u2581co", "\u2581other", "\u2581f", "id", "ol", "\u2581un", "\u2581now", "\u2581work", "ist", "us", "\u2581your", "\u2581them", "ver", "as", "ne", "\u2581ca", "lo", "\u2581fa", "\u2581him", "ng", "\u2581good", "\u2581could", "\u2581pro", "ive", "\u2581con", "de", "un", "age", "\u2581ma", "?", "at", "\u2581ro", "\u2581ba", "\u2581then", "\u2581com", "est", "vi", "\u2581dis", "ies", "ance", "\u2581su", "\u2581even", "\u2581any", "ut", "ad", "ul", "\u2581se", "\u2581two", "\u2581bu", "\u2581lo", "\u2581say", "\u2581la", "\u2581fi", "is", "\u2581li", "\u2581over", "\u2581new", "\u2581man", "\u2581sp", "ity", "\u2581did", "\u2581bo", "\u2581very", "x", "end", "\u2581which", "\u2581our", "\u2581after", "\u2581o", "ke", "\u2581p", "im", "\u2581want", "\u2581ha", "\u2581v", "z", "\u2581where", "ard", "um", "\u2581into", "ru", "\u2581di", "\u2581lot", "\u2581dr", "mp", "\u2581day", "ated", "ci", "\u2581these", "\u2581than", "\u2581take", "\u2581kind", "\u2581got", "ight", "\u2581make", "ence", "\u2581pre", "\u2581going", "ish", "\u2581k", "able", "\u2581look", "ti", "per", "\u2581here", "\u2581en", "\u2581ah", "ry", "\u2581too", "\u2581part", "ant", "one", "\u2581ho", "\u2581much", "\u2581way", "\u2581sa", "\u2581something", "mo", "\u2581us", "\u2581th", "\u2581mhm", "\u2581mi", "\u2581off", "pe", "\u2581back", "les", "\u2581cr", "\u2581ri", "\u2581fe", "und", "\u2581fl", "port", "\u2581school", "\u2581ch", "\u2581should", "\u2581first", "\u2581only", "\u2581le", "ot", "tion", "\u2581little", "\u2581da", "\u2581hu", "\u2581d", "me", "ta", "\u2581down", "\u2581okay", "\u2581come", "ain", "ff", "\u2581car", "co", "\u2581need", "ture", "\u2581many", "\u2581things", "\u2581ta", "qu", "man", "ty", "iv", "\u2581year", "he", "\u2581thing", "ho", "\u2581singapore", "po", "\u2581vi", "\u2581sc", "\u2581still", "der", 
"\u2581hi", "\u2581never", "\u2581qu", "ia", "\u2581fr", "\u2581min", "\u2581most", "om", "ful", "\u2581bi", "\u2581long", "ig", "\u2581years", "ous", "\u2581three", "\u2581play", "\u2581before", "\u2581pi", "ical", "\u2581those", "\u2581comp", "huh", "\u2581live", "tor", "ise", "\u2581old", "am", "rr", "\u2581sta", "\u2581n", "ick", "di", "ma", "ary", "ction", "\u2581friend", "ition", "\u2581gu", "\u2581through", "pp", "for", "ie", "ious", "\u2581sh", "\u2581home", "lu", "\u2581high", "ian", "cu", "\u2581help", "\u2581give", "\u2581talk", "\u2581sha", "\u2581such", "\u2581didn", "em", "\u2581may", "\u2581ga", "\u2581'", "\u2581gra", "\u2581guess", "\u2581every", "\u2581app", "tic", "\u2581tra", "\u2581\"", "op", "\u2581made", "\"", "\u2581op", "\u2581own", "\u2581mar", "no", "\u2581ph", "\u2581life", "\u2581y", "ak", "ine", "\u2581pu", "\u2581place", "\u2581always", "\u2581start", "\u2581jo", "\u2581pe", "\u2581let", "\u2581name", "ni", "\u2581same", "\u2581last", "\u2581cl", "ph", "\u2581both", "\u2581pri", "ities", "\u2581another", "and", "\u2581al", "\u2581boy", "ving", "\u2581actually", "\u2581person", "\u2581went", "\u2581yes", "ca", "ally", "\u2581h", "\u2581great", "\u2581thought", "\u2581used", "act", "\u2581feel", "ward", "\u2581different", "\u2581cons", "\u2581show", "\u2581watch", "\u2581being", "\u2581money", "ay", "\u2581try", "\u2581why", "\u2581big", "ens", "\u2581cha", "\u2581find", "\u2581hand", "\u2581real", "\u2581four", "ial", "\u2581ne", "\u2581che", "\u2581read", "\u2581five", "\u2581family", "ag", "\u2581change", "\u2581add", "ha", "\u2581put", "par", "lic", "side", "\u2581came", "\u2581under", "ness", "\u2581per", "j", "\u2581around", "\u2581end", "\u2581house", "if", "\u2581while", "vo", "\u2581act", "\u2581happen", "\u2581plan", "mit", "\u2581far", "\u2581tri", "\u2581ten", "\u2581du", "\u2581win", "\u2581tea", "ze", "\u2581better", "\u2581sure", "\u2581mu", "\u2581use", "\u2581anything", "\u2581love", "\u2581world", "\u2581hard", "ure", "\u2581does", "\u2581war", "\u2581stuff", "\u2581ja", "\u2581must", "min", "gg", "\u2581ru", "\u2581care", "\u2581tell", "\u2581pl", "\u2581doing", "\u2581probably", "\u2581found", "ative", "\u2581point", "ach", "\u2581ju", "ip", "\u2581again", "\u2581interest", "\u2581state", "\u2581week", "na", "\u2581might", "\u2581pretty", "\u2581ki", "\u2581fo", "ber", "\u2581am", "line", "led", "\u2581six", "\u2581acc", "\u2581bri", "\u2581call", "\u2581sw", "\u2581each", "\u2581business", "\u2581keep", "\u2581away", "cause", "\u2581pass", "\u2581va", "\u2581children", "\u2581pay", "\u2581count", "\u2581public", "\u2581everything", "land", "\u2581though", "\u2581men", "bo", "\u2581young", "\u2581na", "\u2581move", "ough", "ating", "com", "\u2581month", "ton", "\u2581close", "\u2581few", "!", "\u2581maybe", "\u2581imp", "son", "\u2581grow", "\u2581u", "\u2581turn", "ible", "\u2581em", "\u2581air", "\u2581ever", "our", "\u2581sea", "\u2581fun", "\u2581government", "\u2581miss", "\u2581done", "\u2581next", "\u2581kids", "\u2581cor", "\u2581set", "\u2581run", "way", "\u2581wa", "\u2581getting", "\u2581eight", "\u2581open", "\u2581job", "\u2581problem", "ook", "\u2581night", "\u2581learn", "\u2581book", "ual", "\u2581ti", "\u2581best", "cept", "\u2581during", "\u2581small", "ex", "\u2581without", "\u2581water", "\u2581trans", "\u2581course", "\u2581once", "\u2581sit", "\u2581area", "\u2581country", "\u2581mister", "\u2581nothing", "\u2581whole", "\u2581believe", "\u2581service", "\u2581took", "\u2581face", "\u2581bad", "\u2581later", 
"\u2581head", "\u2581called", "\u2581seven", "\u2581art", "\u2581since", "\u2581er", "\u2581fact", "\u2581city", "\u2581market", "\u2581hour", "\u2581continue", "ship", "\u2581invest", "\u2581exactly", "\u2581large", "\u2581true", "\u2581nine", "\u2581sub", "\u2581having", "\u2581game", "va", "\u2581lu", "\u2581conf", "\u2581case", "\u2581doesn", "\u2581certain", "\u2581wi", "\u2581law", "\u2581else", "fi", "\u2581left", "\u2581enough", "\u2581second", "\u2581gonna", "\u2581food", "\u2581hope", "\u2581saw", "\u2581between", "\u2581je", "bi", "\u2581girl", "\u2581company", "\u2581able", "\u2581expect", "\u2581told", "\u2581stand", "\u2581group", "\u2581main", "\u2581walk", "\u2581cause", "\u2581however", "\u2581number", "\u2581follow", "\u2581near", "\u2581yet", "\u2581sometimes", "\u2581train", "\u2581lead", "\u2581system", "\u2581remain", "\u2581develop", "gra", "\u2581word", "\u2581exc", "\u2581together", "\u2581consider", "\u2581town", "\u2581less", "ator", "\u2581important", "\u2581remember", "\u2581free", "\u2581quite", "\u2581understand", "\u2581bra", "\u2581support", "\u2581idea", "\u2581stop", "\u2581reason", "\u2581nice", "\u2581mm", "\u2581agree", "\u2581low", "\u2581against", "\u2581issue", "\u2581become", "\u2581today", "\u2581side", "\u2581student", "\u2581matter", "\u2581question", "\u2581mother", "\u2581father", "\u2581hundred", "\u2581sort", "\u2581eat", "\u2581already", "\u2581rest", "\u2581line", "\u2581asked", "\u2581include", "\u2581upon", "\u2581office", "\u2581won", "\u2581class", "\u2581wait", "\u2581twenty", "\u2581half", "\u2581light", "\u2581price", "\u2581almost", "ash", "\u2581child", "\u2581sign", "\u2581least", "\u2581several", "press", "\u2581either", "\u2581minute", "\u2581himself", "\u2581parents", "\u2581room", "\u2581whatever", "\u2581general", "\u2581cost", "\u2581among", "\u2581direct", "\u2581computer", "\u2581appear", "\u2581meet", "\u2581ski", "\u2581return", "\u2581couple", "\u2581product", "\u2581suppose", "\u2581definitely", "\u2581america", "\u2581term", "\u2581usually", "\u2581strong", "\u2581current", "\u2581arm", "\u2581speak", "\u2581local", "\u2581south", "\u2581experience", "\u2581full", "\u2581north", "\u2581elect", "\u2581leave", "\u2581provide", "qui", "\u2581power", "\u2581movie", "\u2581everyone", "\u2581making", "\u2581member", "\u2581woman", "\u2581somebody", "\u2581wonder", "\u2581short", "\u2581health", "\u2581police", "\u2581bank", "\u2581until", "\u2581companies", "\u2581everybody", "\u2581knew", "\u2581program", "\u2581music", "\u2581york", "\u2581land", "\u2581doctor", "\u2581answer", "\u2581building", "\u2581employ", "\u2581travel", "\u2581major", "\u2581seems", "\u2581safe", "gue", "\u2581college", "\u2581along", "\u2581clear", "\u2581especially", "\u2581umhu", "\u2581result", "\u2581type", "\u2581court", "\u2581black", "\u2581hold", "\u2581myself", "\u2581education", "\u2581social", "\u2581enjoy", "\u2581became", "\u2581whether", "\u2581morning", "\u2581difficult", "\u2581shi", "\u2581felt", "\u2581husband", "\u2581white", "\u2581taking", "\u2581million", "\u2581require", "\u2581early", "ency", "\u2581visit", "\u2581level", "\u2581brother", "\u2581married", "\u2581further", "\u2581affect", "\u2581serve", "\u2581present", "\u2581park", "\u2581effect", "\u2581wife", "\u2581teacher", "\u2581cannot", "\u2581community", "\u2581street", "\u2581period", "\u2581national", "\u2581view", "\u2581future", "\u2581daughter", "\u2581situation", "\u2581grand", "\u2581success", "\u2581perform", "\u2581concern", "\u2581complete", 
"\u2581example", "ized", "\u2581thousand", "\u2581increase", "\u2581began", "\u2581final", "\u2581east", "\u2581sense", "\u2581charge", "\u2581record", "\u2581born", "\u2581instead", "\u2581receive", "\u2581women", "\u2581across", "\u2581information", "\u2581although", "\u2581process", "\u2581condition", "\u2581security", "\u2581treat", "\u2581funny", "\u2581custom", "\u2581cold", "\u2581behind", "ified", "\u2581ground", "cycl", "\u2581depend", "\u2581themselves", "\u2581design", "\u2581slow", "\u2581third", "\u2581smoke", "\u2581wrong", "\u2581project", "\u2581space", "\u2581drink", "\u2581particular", "\u2581listen", "\u2581thirty", "\u2581special", "ability", "\u2581improve", "\u2581attack", "\u2581happy", "\u2581strange", "\u2581english", "\u2581value", "\u2581brought", "\u2581private", "\u2581account", "\u2581china", "\u2581spoke", "\u2581foreign", "\u2581possible", "\u2581author", "\u2581circ", "\u2581voice", "\u2581figure", "\u2581control", "\u2581according", "\u2581green", "\u2581university", "\u2581language", "\u2581please", "\u2581animal", "\u2581church", "\u2581society", "\u2581dream", "\u2019", "q", ":", ";", "\u2014", "\u2018", "\u201d", "_", "3", "8", "<", ">", "1", "\u2013", "7", "(", ")", "0", "2", "4", "+", "&", "5", "9", "\u00fc", "\u00e9", "/", "\u00e1", "\u00f3", "\u014d", "\u00fa", "]", "\u00e2", "\u00ed", "\u00e3", "\u00f0", "\u0101", "\u0107", "\u010d", "\u0161", "\u00e8", "\u00eb", "`", "\u00e7", "\u016b", "\u1ea1", "\u00f8", "=", "\u00e0", "\u0142", "\u03b1", "\u00f4", "\u043a", "}", "\u00e5", "\u0103", "\u0438", "\u012b", "\u03c0", "\u0153", "\\", "[", "\u00f1", "\u00df", "\u00f6", "\u00e4", "6", "\u0437", "\u043d", "\u00fb", "%", "{", "\u00a1", "\u00e6", "\u00ea", "\u00fe", "\u0119", "\u011b", "\u011f", "\u0144", "\u0151", "\u0159", "\u017e", "\u02bb", "\u0432", "\u0435", "\u0439", "\u043b", "\u044c", "\u03c7", "\u201c", ""] -------------------------------------------------------------------------------- /openspeechtointent/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 David Scripka. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 
15 | # imports
16 | import os
17 | import urllib.request
18 | import sys
19 | 
20 | 
21 | def download_file(url, target_directory, file_size=None):
22 |     """A simple function to download a file from a URL, with a progress bar, using only standard libraries."""
23 |     local_filename = url.split('/')[-1]
24 |     file_path = os.path.join(target_directory, local_filename)
25 | 
26 |     # Open the URL
27 |     with urllib.request.urlopen(url) as response:
28 |         if file_size is None:
29 |             file_size = int(response.getheader('Content-Length', 0))
30 | 
31 |         # Create a progress bar
32 |         print(f"\nDownloading {local_filename} ({file_size} bytes)")
33 |         downloaded = 0
34 | 
35 |         with open(file_path, 'wb') as f:
36 |             while True:
37 |                 chunk = response.read(8192)
38 |                 if not chunk:
39 |                     break
40 |                 f.write(chunk)
41 |                 downloaded += len(chunk)
42 | 
43 |                 # Update progress
44 |                 progress = downloaded / file_size * 100 if file_size else 0
45 |                 sys.stdout.write(f"\rProgress: {progress:.2f}%")
46 |                 sys.stdout.flush()
47 | 
48 |     # End the progress line so that subsequent output starts on a new line
49 |     sys.stdout.write("\n")
50 | 
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools>=45", "setuptools_scm[toml]>=6.2", "wheel", "pybind11>=2.5.0"]
3 | build-backend = "setuptools.build_meta"
4 | 
5 | [tool.pytest.ini_options]
6 | addopts = "--cov=openspeechtointent --cov-report term-missing --flake8"
7 | flake8-max-line-length = "140"
8 | testpaths = [
9 |     "test",
10 |     "openspeechtointent"
11 | ]
12 | 
13 | [project]
14 | name = "openspeechtointent"
15 | description = "A simple but performant framework for mapping speech directly to categories and intents."
16 | readme = "README.md"
17 | authors = [{name = "David Scripka", email = "david.scripka@gmail.com"}]
18 | license = {file = "LICENSE"}
19 | classifiers = [
20 |     "Development Status :: 3 - Alpha",
21 |     "Intended Audience :: Developers",
22 |     "License :: OSI Approved :: Apache Software License",
23 |     "Operating System :: OS Independent",
24 |     "Programming Language :: Python :: 3",
25 |     "Programming Language :: Python :: 3.8",
26 |     "Programming Language :: Python :: 3.9",
27 |     "Programming Language :: Python :: 3.10",
28 | ]
29 | requires-python = ">=3.8"
30 | dependencies = [
31 |     "numpy",
32 |     "onnxruntime",
33 |     "sentencepiece>=0.2.0"
34 | ]
35 | dynamic = ["version"]
36 | 
37 | [project.optional-dependencies]
38 | test = [
39 |     'pytest',
40 |     'pytest-cov',
41 |     'pytest-flake8',
42 |     'flake8',
43 |     'pytest-mypy'
44 | ]
45 | 
46 | [project.urls]
47 | Homepage = "https://github.com/dscripka/openspeechtointent"
48 | 
49 | [tool.setuptools_scm]
50 | version_scheme = "post-release"
51 | local_scheme = "dirty-tag"
52 | 
53 | [tool.setuptools]
54 | packages = ["openspeechtointent"]
55 | include-package-data = true
56 | 
57 | [tool.setuptools.package-data]
58 | openspeechtointent = [
59 |     "resources/models/citrinet_spectrogram_filterbank.npy",
60 |     "resources/models/citrinet_tokenizer.pkl",
61 |     "resources/models/citrinet_vocab.json"
62 | ]
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 | from pybind11.setup_helpers import Pybind11Extension, build_ext
3 | 
4 | # Build configuration for the pybind11 extension (compiles the C++ forced-alignment module)
5 | ext_modules = [
6 |     Pybind11Extension(
7 |         "openspeechtointent.forced_alignment",
8 |         ["openspeechtointent/forced_alignment.cpp"],
9 |         extra_compile_args=["-O3"],
10 | ) 11 | ] 12 | 13 | setup( 14 | ext_modules=ext_modules, 15 | cmdclass={"build_ext": build_ext}, 16 | ) 17 | -------------------------------------------------------------------------------- /test/intents/test_intents.json: -------------------------------------------------------------------------------- 1 | { 2 | "test_intents": [ 3 | "Set a timer for 30 minutes", 4 | "Turn off all lights", 5 | "turn on the lights", 6 | "Set thermostat to 72 degrees", 7 | "Add milk to the shopping list", 8 | "Remind me to call Mom at 5 PM", 9 | "What's the weather like today?", 10 | "How's the traffic to work?", 11 | "Pause the music", 12 | "Turn on the living room TV", 13 | "Show me the front door camera", 14 | "Lock the front door", 15 | "Set an alarm for 7 AM", 16 | "Turn on kitchen lights", 17 | "Lower the temperature by 2 degrees", 18 | "Turn off the coffee maker", 19 | "Remove eggs from the shopping list", 20 | "Remind me to take out the trash tomorrow morning", 21 | "Will it rain this weekend?", 22 | "Is there a traffic jam on Highway 101?", 23 | "Skip to the next song", 24 | "Mute the TV", 25 | "Check the backyard camera", 26 | "Unlock the garage door", 27 | "Set a 15-minute meditation timer", 28 | "Dim the bedroom lights", 29 | "Increase the temperature to 75 degrees", 30 | "Turn on the dishwasher", 31 | "Add paper towels to the shopping list", 32 | "Remind me to water the plants on Wednesday", 33 | "What's the temperature outside?", 34 | "How long will it take to get to the airport?", 35 | "Play my favorite playlist", 36 | "Turn off all TVs", 37 | "Show me the baby's room camera", 38 | "Open the blinds", 39 | "Wake me up at 6:30 AM", 40 | "Turn on porch lights", 41 | "Set the AC to 68 degrees", 42 | "Turn off the oven", 43 | "Add bananas to the grocery list", 44 | "Remind me to pay the electric bill", 45 | "What's the forecast for tomorrow?", 46 | "Is there construction on my route to work?", 47 | "Increase the volume by 20%", 48 | "Change the TV channel to NBC", 49 | "Check the driveway camera", 50 | "Lock all doors", 51 | "Set a timer for 45 minutes", 52 | "Turn on hallway lights", 53 | "Lower the blinds", 54 | "Turn on the fan", 55 | "Remove coffee from the shopping list", 56 | "Remind me to pick up dry cleaning", 57 | "What's the humidity level today?", 58 | "How's the traffic to downtown?", 59 | "Shuffle my workout playlist", 60 | "Pause the movie", 61 | "Show me the side gate camera", 62 | "Unlock the back door", 63 | "Set an alarm for 8:15 AM", 64 | "Turn off bathroom lights", 65 | "Set the heat to 70 degrees", 66 | "Turn on the humidifier", 67 | "Add toothpaste to the shopping list", 68 | "Remind me of my dentist appointment tomorrow", 69 | "Will it be windy today?", 70 | "What's the fastest route to the mall?", 71 | "Play the next episode", 72 | "Lower TV volume", 73 | "Check the garage camera", 74 | "Open the front door", 75 | "Set a 5-minute snooze", 76 | "Turn on reading light", 77 | "Increase thermostat by 3 degrees", 78 | "Turn off the vacuum", 79 | "Add bread to the grocery list", 80 | "Remind me to call the doctor at 2 PM", 81 | "What's the UV index for today?", 82 | "Is there an accident on I-95?", 83 | "Rewind 30 seconds", 84 | "Switch TV input to HDMI 2", 85 | "Show me all cameras", 86 | "Lock the patio door", 87 | "Set a countdown for New Year's", 88 | "Turn on closet light", 89 | "Set temperature to eco mode", 90 | "Turn on the air purifier", 91 | "Remove apples from the shopping list", 92 | "Remind me to buy a birthday gift", 93 | "What time will the sun set today?", 94 | 
"How long is my commute with current traffic?", 95 | "Play white noise", 96 | "Turn off the living room TV in 30 minutes", 97 | "Check the front yard camera", 98 | "Unlock all doors", 99 | "Wake me up with nature sounds at 7:30 AM", 100 | "Turn off all downstairs lights", 101 | "Set the thermostat to night mode", 102 | "Turn on the slow cooker", 103 | "Add laundry detergent to the shopping list", 104 | "Remind me to start the laundry after work", 105 | "What's the air quality index today?", 106 | "Is there a delay on the subway?", 107 | "Fast forward 2 minutes", 108 | "Adjust TV picture settings", 109 | "Show me the pool camera", 110 | "Close the garage door", 111 | "Set a timer for 2 hours", 112 | "Turn on accent lighting", 113 | "Set fan speed to medium", 114 | "Turn off the printer", 115 | "Add dog food to the grocery list", 116 | "Remind me to charge my phone before bed", 117 | "Will it snow this week?", 118 | "What's the best time to leave for my appointment?", 119 | "Play my sleep sounds playlist", 120 | "Record the current TV show", 121 | "Check the basement camera", 122 | "Lock the side door", 123 | "Set an alarm for weekdays at 6:45 AM", 124 | "Turn on outdoor lights at sunset", 125 | "Set the thermostat to vacation mode", 126 | "Turn on the space heater", 127 | "Remove milk from the shopping list", 128 | "Remind me to take my medication at 9 AM", 129 | "What's the pollen count today?", 130 | "How's the traffic on the bridge?", 131 | "Stop playing music in 1 hour", 132 | "Turn on closed captions", 133 | "Show me the doorbell camera", 134 | "Unlock the front door for 5 minutes", 135 | "Set a wake-up routine for 7 AM", 136 | "Turn off lights in empty rooms", 137 | "Set the AC to energy-saving mode", 138 | "Turn on the electric blanket", 139 | "Add batteries to the shopping list", 140 | "Remind me to send a birthday card", 141 | "What's the chance of rain tomorrow?", 142 | "Is there a better route to avoid traffic?", 143 | "Play my morning news briefing", 144 | "Set TV sleep timer for 1 hour", 145 | "Check all door locks", 146 | "Open the garage door halfway", 147 | "Set a gentle alarm for 8 AM", 148 | "Sync all light colors", 149 | "Set thermostat schedule for the week", 150 | "Turn on the dehumidifier", 151 | "Add trash bags to the grocery list", 152 | "Remind me to update my calendar", 153 | "What time will it stop raining?", 154 | "How long will it take to walk to the park?", 155 | "Play relaxing music", 156 | "Turn off the TV after this show", 157 | "Show me security camera history", 158 | "Lock doors at 10 PM", 159 | "Set a power nap timer for 20 minutes", 160 | "Turn on party mode lighting", 161 | "Set temperature based on occupancy", 162 | "Turn on the towel warmer", 163 | "Remove soap from the shopping list", 164 | "Remind me to check the mailbox", 165 | "What's the wind chill factor?", 166 | "Is there a traffic alert on my route?", 167 | "Resume podcast where I left off", 168 | "Enable motion detection on all cameras", 169 | "Unlock the door for the dog walker", 170 | "Set a recurring alarm for Mondays at 7 AM", 171 | "Turn on night light", 172 | "Set thermostat to 68 when I leave", 173 | "Turn on the ceiling fan", 174 | "Add lightbulbs to the shopping list", 175 | "Remind me to water the garden on Saturday", 176 | "What's the heat index today?", 177 | "How's the parking situation downtown?", 178 | "Create a playlist of my most played songs", 179 | "Switch TV to gaming mode", 180 | "Show me a live view of all cameras", 181 | "Set the front door to auto-lock", 182 | 
"Schedule vacuum to run at 2 PM", 183 | "Sync lights with sunrise", 184 | "Set bedroom temperature to 65 at night", 185 | "Turn on the oil diffuser", 186 | "Remove pasta from the grocery list", 187 | "Remind me to defrost the chicken", 188 | "What time does the sun rise tomorrow?", 189 | "What's the estimated drive time to the beach?", 190 | "Play ambient sounds in the bedroom", 191 | "Turn off all devices downstairs", 192 | "Rotate outdoor cameras", 193 | "Temporarily disable door alerts", 194 | "Set a recurring reminder for trash day", 195 | "Create a romantic lighting scene", 196 | "Adjust thermostat when windows are open", 197 | "Turn on the heated floor", 198 | "Add paper to the shopping list", 199 | "Remind me to return library books", 200 | "What's the visibility for driving?", 201 | "Are there any road closures nearby?", 202 | "Start my workout playlist in 10 minutes", 203 | "Enable parental controls on the TV", 204 | "Show me camera alerts from today", 205 | "Set guest access code for the front door", 206 | "Dim living room to movie mode", 207 | "Crank up the AC, it's boiling in here", 208 | "Add avocados to my Whole Foods order", 209 | "Ping me in an hour about the laundry", 210 | "What's the weather looking like for my weekend hike?", 211 | "Any accidents on my commute route?", 212 | "Pump up the jams in the kitchen", 213 | "Flip to the news on the bedroom TV", 214 | "Give me a peek at the backyard", 215 | "Seal up the house, we're leaving", 216 | "Wake me with my favorite playlist at 6:45", 217 | "Make the bathroom cozy for my shower", 218 | "Kick the heat up a notch in the basement", 219 | "Fire up the coffee maker", 220 | "Scratch oranges off the grocery list", 221 | "Nudge me about Mom's birthday next week", 222 | "Is it umbrella weather today?", 223 | "How's the rush hour looking?", 224 | "Queue up my 'Chill Vibes' playlist", 225 | "Kill the sound on all TVs", 226 | "Let's see who's at the front door", 227 | "Pop open the garage, I'm almost home", 228 | "I need a 25-minute power nap timer", 229 | "Create a romantic ambiance in the bedroom", 230 | "Cool things down for bedtime", 231 | "Get the robot vacuum going in the living room", 232 | "We need more milk, add it to the list", 233 | "Buzz me when it's time for my meeting", 234 | "What's the forecast for my camping trip?", 235 | "Is the highway backed up?", 236 | "Let's rock the house with my party mix", 237 | "Zap the TV off in 30 minutes", 238 | "I want to check on Fluffy in the kitchen", 239 | "Time to let some fresh air in", 240 | "Ease me awake at 7:15 tomorrow", 241 | "Light up the path to the bathroom", 242 | "Make it toasty in here", 243 | "Kick off the sprinklers for 20 minutes", 244 | "Jot down 'birthday candles' on the shopping list", 245 | "Don't let me forget my dentist appointment", 246 | "Any chance of a white Christmas?", 247 | "What's the quickest way to the airport right now?", 248 | "Blast my workout playlist in 5 minutes", 249 | "Boost the bass on the living room speakers", 250 | "Give me a 360 view of the house", 251 | "Batten down the hatches for the night", 252 | "I could use a 10-minute breather timer", 253 | "Make the kitchen bright and cheery", 254 | "Let's get tropical in here", 255 | "Rev up the fan in the office", 256 | "Cross off 'shampoo' from the shopping list", 257 | "Prod me about the school bake sale tomorrow", 258 | "How's the air quality for my run?", 259 | "Are the trains running on time?", 260 | "Can you find my 'Focus' playlist?", 261 | "Hush all the devices downstairs", 262 | "Show 
me what's happening in the driveway", 263 | "Unlock the side gate for the gardener", 264 | "Give me a gentle wake-up call at sunrise", 265 | "Kill the lights downstairs", 266 | "Warm up the house for when I'm back", 267 | "Let's get the slow cooker bubbling", 268 | "We're out of cereal, add it to the list", 269 | "Flag me to call Grandma this evening", 270 | "What's the UV situation for the beach today?", 271 | "How long will it take to get to the concert?", 272 | "Spin my 'Throwback Thursday' playlist", 273 | "Pipe down the volume on all speakers", 274 | "Let's see who's been at the door today", 275 | "Time to open up shop", 276 | "Hit me with a 5-minute warning before my show", 277 | "Make it disco time in the game room", 278 | "Chill the bedroom for sleeping", 279 | "Juice up the electric car", 280 | "We need more dog treats on the list", 281 | "Remind me about the parent-teacher conference", 282 | "Is it a good day for hang-drying laundry?", 283 | "What's traffic like around the stadium?", 284 | "Let's get some classical music flowing", 285 | "Freeze the TV for a sec", 286 | "I want to check on the baby", 287 | "Secure the fort, we're hitting the hay", 288 | "How about a 15-minute yoga timer?", 289 | "Give the living room a sunset glow", 290 | "Make it feel like autumn in here", 291 | "Get the robo-mop dancing in the kitchen", 292 | "Jot down 'light bulbs' on the to-buy list", 293 | "Buzz me to water the plants in the morning", 294 | "What's the wind situation for sailing?", 295 | "Is there a parade route I should avoid?", 296 | "Spin up my 'Cooking with Jazz' playlist", 297 | "Fade out the TV audio in 20 minutes", 298 | "Let's peek at the pet cam", 299 | "Time to roll up the garage door", 300 | "I need a power hour timer", 301 | "Make the patio perfect for evening drinks", 302 | "Prep the house for my arrival", 303 | "Fire up the white noise machine", 304 | "Add sunscreen to the beach day list", 305 | "Ping me about the cake in the oven", 306 | "How's the visibility for stargazing tonight?", 307 | "What's my ETA to the gym with current traffic?", 308 | "Queue up my bedtime story podcast", 309 | "Mute all notifications for an hour", 310 | "Give me a house-wide security sweep", 311 | "Grant access to the cleaning service", 312 | "Ease me into Monday with a gradual wake-up", 313 | "Create a productivity atmosphere in the office", 314 | "Adjust for the incoming cold front", 315 | "Kick on the bathroom fan", 316 | "We're low on coffee, add it to the list", 317 | "Flag me to pick up the dry cleaning", 318 | "What's the allergy forecast for today?", 319 | "Are there any detours on my usual route?", 320 | "Let's get some nature sounds in the yoga room", 321 | "Prep the TV for movie night", 322 | "I want to see all entry points", 323 | "Shut it down, we're going on vacation", 324 | "Give me a 30-second countdown timer", 325 | "Make it feel like a spring day inside", 326 | "Optimize the AC for my workout", 327 | "Let's get the air purifier humming", 328 | "Pencil in 'birthday card' on the shopping list", 329 | "Nudge me about the car's oil change", 330 | "How's the forecast looking for the outdoor wedding?", 331 | "What's the least congested route to downtown?", 332 | "Time for my 'Motivation Monday' playlist", 333 | "Dial down the brightness on all screens", 334 | "Show me today's security footage", 335 | "Grant temporary access to the dog walker", 336 | "I need a progressive alarm starting at 7", 337 | "Sync up all the clocks in the house", 338 | "Prepare the house for winter", 339 | 
"Activate the aromatherapy diffuser", 340 | "We're out of toothpaste, add it to the list", 341 | "Remind me to send that important email", 342 | "What's the precipitation chance for my garden party?", 343 | "How's the parking situation at the mall?", 344 | "Cue up my 'Road Trip' playlist", 345 | "Enable subtitles on all TVs", 346 | "Let's do a perimeter check", 347 | "Time to close up the pool for the night", 348 | "Set a Pomodoro timer for 25 minutes", 349 | "Create a zen atmosphere in the meditation room", 350 | "Adjust the house for the heatwave", 351 | "Get the white noise machine going in the nursery", 352 | "Add 'thank you cards' to my to-do list", 353 | "Buzz me when it's time to leave for the airport", 354 | "What's the fire risk for camping this weekend?", 355 | "Are there any road works on my route to the client?", 356 | "Let's get some lo-fi beats in the study", 357 | "Prep the gaming room for tonight's session", 358 | "Give me a full house status report", 359 | "Initiate nighttime security protocol", 360 | "I need a 2-minute toothbrushing timer", 361 | "Create a focus-friendly environment in the office", 362 | "Optimize the house for energy savings", 363 | "Freshen up the air in here", 364 | "We need more recycling bags on the list", 365 | "Ping me about the parent-teacher meeting", 366 | "How's the marine layer for my beach day?", 367 | "What's the best time to avoid gym crowds?", 368 | "Shuffle my 'Indie Discoveries' playlist", 369 | "Optimize TV settings for sports", 370 | "Let's see who's been in the backyard", 371 | "Prepare the house for the dog sitter", 372 | "Wake me up when the sun rises", 373 | "Synchronize all smart bulbs", 374 | "Get the house ready for the cold snap", 375 | "Activate the towel warmer in 20 minutes", 376 | "Add 'batteries' to the hardware store list", 377 | "Remind me to schedule a haircut", 378 | "What's the forecast for the farmers market?", 379 | "How's the traffic to the concert venue?", 380 | "Play my 'Sunday Morning' playlist", 381 | "Optimize the TV for colorblind viewing", 382 | "Show me all motion alerts from today", 383 | "Revoke access for the old cleaning service", 384 | "Give me a staged alarm from 6:30 to 7", 385 | "Make the house feel like a beach resort", 386 | "Prep the sunroom for my yoga session", 387 | "Get the humidifier going in the bedroom", 388 | "Add 'light bulbs' to the home improvement list", 389 | "Nudge me to call the vet tomorrow", 390 | "How's the smog situation today?", 391 | "What's the best route for scenic driving?", 392 | "Queue up my 'Dinner Party' playlist", 393 | "Adjust all screens for night mode", 394 | "Let's see who's been at the back door", 395 | "Secure all windows, a storm's coming", 396 | "Set a timer for six thirty AM", 397 | "stop the music", 398 | "check the status of refrigerator", 399 | "how much coffee is left", 400 | "boil water for tea", 401 | "clean the floors in the dining room", 402 | "let me know when the washer is done" 403 | ] 404 | } -------------------------------------------------------------------------------- /test/resources/sample_1.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dscripka/openSpeechToIntent/664ad58e7c7f8cd0babf4612720101966446eb1c/test/resources/sample_1.wav --------------------------------------------------------------------------------