├── .github ├── PULL_REQUEST_TEMPLATE.md ├── actions │ └── moshi_build │ │ └── action.yml └── workflows │ └── precommit.yml ├── .gitignore ├── .pre-commit-config.yaml ├── .vscode └── settings.json ├── CONTRIBUTING.md ├── LICENSE-APACHE ├── LICENSE-MIT ├── README.md ├── audio ├── bria.mp3 ├── loona.mp3 └── sample_fr_hibiki_crepes.mp3 ├── configs ├── config-stt-en-hf.toml ├── config-stt-en_fr-hf.toml └── config-tts.toml ├── scripts ├── stt_evaluate_on_dataset.py ├── stt_from_file_mlx.py ├── stt_from_file_pytorch.py ├── stt_from_file_rust_server.py ├── stt_from_file_with_prompt_pytorch.py ├── stt_from_mic_mlx.py ├── stt_from_mic_rust_server.py ├── tts_mlx.py ├── tts_pytorch.py └── tts_rust_server.py ├── stt-rs ├── Cargo.lock ├── Cargo.toml └── src │ └── main.rs ├── stt_pytorch.ipynb └── tts_pytorch.ipynb /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | ## Checklist 2 | 3 | - [ ] Read CONTRIBUTING.md, and accept the CLA by including the provided snippet. We will not accept PR without this. 4 | - [ ] Run pre-commit hook. 5 | - [ ] If you changed Rust code, run `cargo check`, `cargo clippy`, `cargo test`. 6 | 7 | ## PR Description 8 | 9 | <!-- Description for the PR --> 10 | -------------------------------------------------------------------------------- /.github/actions/moshi_build/action.yml: -------------------------------------------------------------------------------- 1 | name: moshi_build 2 | description: 'Build env.' 3 | runs: 4 | using: "composite" 5 | steps: 6 | - uses: actions/setup-python@v2 7 | with: 8 | python-version: '3.10.14' 9 | - uses: actions/cache@v3 10 | id: cache 11 | with: 12 | path: env 13 | key: env-${{ hashFiles('moshi/pyproject.toml') }} 14 | - name: Install dependencies 15 | if: steps.cache.outputs.cache-hit != 'true' 16 | shell: bash 17 | run: | 18 | python3 -m venv env 19 | . env/bin/activate 20 | python -m pip install --upgrade pip 21 | pip install torch==2.4.0 --index-url https://download.pytorch.org/whl/cpu 22 | pip install moshi==0.2.7 23 | pip install pre-commit 24 | - name: Setup env 25 | shell: bash 26 | run: | 27 | source env/bin/activate 28 | pre-commit install 29 | -------------------------------------------------------------------------------- /.github/workflows/precommit.yml: -------------------------------------------------------------------------------- 1 | name: precommit 2 | on: 3 | push: 4 | branches: [ main ] 5 | pull_request: 6 | branches: [ main ] 7 | 8 | jobs: 9 | run_precommit: 10 | name: Run precommit 11 | runs-on: ubuntu-latest 12 | steps: 13 | - uses: actions/checkout@v2 14 | - uses: ./.github/actions/moshi_build 15 | - run: | 16 | source env/bin/activate 17 | pre-commit run --all-files 18 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # UV 98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | #uv.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | 110 | # pdm 111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 112 | #pdm.lock 113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 114 | # in version control. 115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 116 | .pdm.toml 117 | .pdm-python 118 | .pdm-build/ 119 | 120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 121 | __pypackages__/ 122 | 123 | # Celery stuff 124 | celerybeat-schedule 125 | celerybeat.pid 126 | 127 | # SageMath parsed files 128 | *.sage.py 129 | 130 | # Environments 131 | .env 132 | .venv 133 | env/ 134 | venv/ 135 | ENV/ 136 | env.bak/ 137 | venv.bak/ 138 | 139 | # Spyder project settings 140 | .spyderproject 141 | .spyproject 142 | 143 | # Rope project settings 144 | .ropeproject 145 | 146 | # mkdocs documentation 147 | /site 148 | 149 | # mypy 150 | .mypy_cache/ 151 | .dmypy.json 152 | dmypy.json 153 | 154 | # Pyre type checker 155 | .pyre/ 156 | 157 | # pytype static type analyzer 158 | .pytype/ 159 | 160 | # Cython debug symbols 161 | cython_debug/ 162 | 163 | # PyCharm 164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 166 | # and can be added to the global gitignore or merged into this file. 
For a more nuclear 167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 168 | #.idea/ 169 | 170 | # Abstra 171 | # Abstra is an AI-powered process automation framework. 172 | # Ignore directories containing user credentials, local state, and settings. 173 | # Learn more at https://abstra.io/docs 174 | .abstra/ 175 | 176 | # Visual Studio Code 177 | # Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore 178 | # that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore 179 | # and can be added to the global gitignore or merged into this file. However, if you prefer, 180 | # you could uncomment the following to ignore the enitre vscode folder 181 | # .vscode/ 182 | 183 | # Ruff stuff: 184 | .ruff_cache/ 185 | 186 | # PyPI configuration file 187 | .pypirc 188 | 189 | # Cursor 190 | # Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to 191 | # exclude from AI features like autocomplete and code analysis. Recommended for sensitive data 192 | # refer to https://docs.cursor.com/context/ignore-files 193 | .cursorignore 194 | .cursorindexingignore 195 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | # Get rid of Jupyter Notebook output because we don't want to keep it in Git 3 | - repo: https://github.com/kynan/nbstripout 4 | rev: 0.8.1 5 | hooks: 6 | - id: nbstripout 7 | - repo: https://github.com/pre-commit/pre-commit-hooks 8 | rev: v5.0.0 9 | hooks: 10 | - id: check-added-large-files 11 | args: ["--maxkb=2048"] 12 | - repo: https://github.com/astral-sh/ruff-pre-commit 13 | # Ruff version. 14 | rev: v0.11.7 15 | hooks: 16 | # Run the linter. 17 | - id: ruff 18 | types_or: [python, pyi] # Don't run on `jupyter` files 19 | args: [--fix] 20 | # Run the formatter. 21 | - id: ruff-format 22 | types_or: [python, pyi] # Don't run on `jupyter` files 23 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "python.analysis.typeCheckingMode": "standard" 3 | } 4 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to Delayed-Streams-Modeling 2 | 3 | ## Pull Requests 4 | 5 | Delayed-Streams-Modeling is the implementation of a research paper. 6 | Therefore, we do not plan on accepting many pull requests for new features. 7 | However, we certainly welcome them for bug fixes. 8 | 9 | 1. Fork the repo and create your branch from `main`. 10 | 2. If you have changed APIs, update the documentation accordingly. 11 | 3. Ensure pre-commit hooks pass properly, in particular the linting and typing. 12 | 4. When changing the Rust code, run `cargo check`, `cargo clippy`, `cargo test`. 13 | 5. Accept the Contributor License Agreement (see after). 14 | 15 | Note that in general, we will not accept refactoring of the code. 16 | 17 | 18 | ## Contributor License Agreement ("CLA") 19 | 20 | In order to accept your pull request, we need you to submit a Contributor License Agreement. 
21 | 22 | If you agree with the full CLA provided in the next paragraph, copy the following statement in your PR, changing your Github Handle: 23 | 24 | > I, {your GitHub handle}, confirm that I have read and understood the terms of the CLA of Kyutai-labs, as outlined in the repository's CONTRIBUTING.md, and I agree to be bound by these terms. 25 | 26 | The full CLA is provided as follows: 27 | 28 | > I, {your GitHub handle}, hereby grant to Kyutai-labs a perpetual, worldwide, non-exclusive, royalty-free, 29 | > irrevocable license to use, modify, distribute, and sublicense my Contributions. 30 | 31 | > I understand and accept that Contributions are limited to modifications, improvements, or changes 32 | > to the project’s source code submitted via pull requests. I accept that Kyutai-labs has full discretion to 33 | > review, accept, reject, or request changes to any Contributions I submit, and that submitting 34 | > a pull request does not guarantee its inclusion in the project. 35 | 36 | > By submitting a Contribution, I grant Kyutai-labs a perpetual, worldwide license to use, modify, 37 | > reproduce, distribute, and create derivative works based on my Contributions. 38 | > I also agree to assign all patent rights for any inventions or improvements that arise from my Contributions, 39 | > giving the Kyutai-labs full rights to file for and enforce patents. 40 | > I understand that the Kyutai-labs may commercialize, relicense, or exploit the project and my Contributions without further notice or obligation to me. 41 | > I confirm that my Contributions are original and that I have the legal right to grant this license. 42 | > If my Contributions include third-party materials, I will ensure that I have the necessary permissions 43 | > and will disclose this information. I accept that once my Contributions are integrated, they may be altered or removed at the Kyutai-labs’s discretion. 44 | 45 | > I acknowledge that I am making these Contributions voluntarily and will not receive any compensation. 46 | > Furthermore, I understand that all Contributions, including mine, are provided on an "as-is" basis, with no warranties. 47 | > By submitting a pull request, I agree to be bound by these terms. 48 | 49 | ## Issues 50 | 51 | Please submit issues on our Github repository. 52 | 53 | ## License 54 | 55 | By contributing to Delayed-Streams-Modeling, you agree that your contributions 56 | will be licensed under the LICENSE-* files in the root directory of this source 57 | tree. In particular, the rust code is licensed under APACHE, and the python code 58 | under MIT. 59 | -------------------------------------------------------------------------------- /LICENSE-APACHE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /LICENSE-MIT: -------------------------------------------------------------------------------- 1 | Permission is hereby granted, free of charge, to any 2 | person obtaining a copy of this software and associated 3 | documentation files (the "Software"), to deal in the 4 | Software without restriction, including without 5 | limitation the rights to use, copy, modify, merge, 6 | publish, distribute, sublicense, and/or sell copies of 7 | the Software, and to permit persons to whom the Software 8 | is furnished to do so, subject to the following 9 | conditions: 10 | 11 | The above copyright notice and this permission notice 12 | shall be included in all copies or substantial portions 13 | of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF 16 | ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED 17 | TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A 18 | PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT 19 | SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY 20 | CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 21 | OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR 22 | IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 23 | DEALINGS IN THE SOFTWARE. 24 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Delayed Streams Modeling: Kyutai STT & TTS 2 | 3 | This repo contains instructions and examples of how to run 4 | [Kyutai Speech-To-Text](#kyutai-speech-to-text) 5 | and [Kyutai Text-To-Speech](#kyutai-text-to-speech) models. 6 | These models are powered by delayed streams modeling (DSM), 7 | a flexible formulation for streaming, multimodal sequence-to-sequence learning. 8 | See also [Unmute](https://github.com/kyutai-labs/unmute), a voice AI system built using Kyutai STT and Kyutai TTS. 9 | 10 | But wait, what is "Delayed Streams Modeling"? It is a technique for solving many streaming X-to-Y tasks (with X, Y in `{speech, text}`) 11 | that formalizes the approach we had with Moshi and Hibiki. A pre-print paper is coming soon! 12 | 13 | ## Kyutai Speech-To-Text 14 | 15 | <a href="https://huggingface.co/collections/kyutai/speech-to-text-685403682cf8a23ab9466886" target="_blank" style="margin: 2px;"> 16 | <img alt="Hugging Face" src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-KyutaiSTT-blue" style="display: inline-block; vertical-align: middle;"/> 17 | </a> 18 | <a target="_blank" href="https://colab.research.google.com/github/kyutai-labs/delayed-streams-modeling/blob/main/stt_pytorch.ipynb"> 19 | <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/> 20 | </a> 21 | 22 | **More details can be found on the [project page](https://kyutai.org/next/stt).** 23 | 24 | Kyutai STT models are optimized for real-time usage, can be batched for efficiency, and return word-level timestamps.
25 | We provide two models: 26 | - `kyutai/stt-1b-en_fr`, an English and French model with ~1B parameters, a 0.5-second delay, and a [semantic VAD](https://kyutai.org/next/stt#semantic-vad). 27 | - `kyutai/stt-2.6b-en`, an English-only model with ~2.6B parameters and a 2.5-second delay. 28 | 29 | These speech-to-text models have several advantages: 30 | - Streaming inference: the models can process audio in chunks, which allows 31 | for real-time transcription, and is great for interactive applications. 32 | - Easy batching for maximum efficiency: an H100 can process 400 streams in 33 | real time. 34 | - They return word-level timestamps. 35 | - The 1B model has a semantic Voice Activity Detection (VAD) component that 36 | can be used to detect when the user is speaking. This is especially useful 37 | for building voice agents. 38 | 39 | ### Implementations overview 40 | 41 | We provide different implementations of Kyutai STT for different use cases. 42 | Here is how to choose which one to use: 43 | 44 | - **PyTorch: for research and tinkering.** 45 | If you want to call the model from Python for research or experimentation, use our PyTorch implementation. 46 | - **Rust: for production.** 47 | If you want to serve Kyutai STT in a production setting, use our Rust server. 48 | Our robust Rust server provides streaming access to the model over websockets. 49 | We use this server to run [Unmute](https://unmute.sh/); on an L40S GPU, we can serve 64 simultaneous connections at a real-time factor of 3x. 50 | - **MLX: for on-device inference on iPhone and Mac.** 51 | MLX is Apple's ML framework that allows you to use hardware acceleration on Apple silicon. 52 | If you want to run the model on a Mac or an iPhone, choose the MLX implementation. 53 | 54 | <details> 55 | <summary>PyTorch implementation</summary> 56 | <a href="https://huggingface.co/kyutai/stt-2.6b-en" target="_blank" style="margin: 2px;"> 57 | <img alt="Hugging Face" src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-blue" style="display: inline-block; vertical-align: middle;"/> 58 | </a> 59 | <a target="_blank" href="https://colab.research.google.com/github/kyutai-labs/delayed-streams-modeling/blob/main/stt_pytorch.ipynb"> 60 | <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/> 61 | </a> 62 | 63 | For an example of how to use the model in a way where you can directly stream in PyTorch tensors, 64 | [see our Colab notebook](https://colab.research.google.com/github/kyutai-labs/delayed-streams-modeling/blob/main/stt_pytorch.ipynb). 65 | 66 | This requires the [moshi package](https://pypi.org/project/moshi/) 67 | with version 0.2.6 or later, which can be installed via pip. 68 | 69 | If you just want to run the model on a file, you can use `moshi.run_inference`. 70 | 71 | ```bash 72 | python -m moshi.run_inference --hf-repo kyutai/stt-2.6b-en audio/bria.mp3 73 | ``` 74 | 75 | If you have [uv](https://docs.astral.sh/uv/) installed, you can skip the installation step 76 | and just prefix the command above with `uvx --with moshi`. 77 | 78 | Additionally, we provide two scripts that highlight different usage scenarios.
The first script illustrates how to extract word-level timestamps from the model's outputs: 79 | 80 | ```bash 81 | uv run \ 82 | scripts/stt_from_file_pytorch.py \ 83 | --hf-repo kyutai/stt-2.6b-en \ 84 | audio/bria.mp3 85 | ``` 86 | 87 | The second script can be used to run a model on an existing Hugging Face dataset and calculate its performance metrics: 88 | ```bash 89 | uv run scripts/stt_evaluate_on_dataset.py \ 90 | --dataset meanwhile \ 91 | --hf-repo kyutai/stt-2.6b-en 92 | ``` 93 | 94 | Another example shows how one can provide a text-, audio-, or text-audio prompt to our STT model: 95 | ```bash 96 | uv run scripts/stt_from_file_with_prompt_pytorch.py \ 97 | --hf-repo kyutai/stt-2.6b-en \ 98 | --file audio/bria.mp3 \ 99 | --prompt_file ./audio/loona.mp3 \ 100 | --prompt_text "Loonah" \ 101 | --cut-prompt-transcript 102 | ``` 103 | This produces the transcript of `bria.mp3` using the `Loonah` spelling for the name, instead of the `Luna` used without any prompt: 104 | ``` 105 | In the heart of an ancient forest, where the trees whispered secrets of the past, there lived a peculiar rabbit named Loonah (...) 106 | ``` 107 | 108 | Apart from nudging the model towards a specific spelling of a word, other potential use cases include speaker adaptation and steering the model towards a specific formatting style or even a language. 109 | However, please bear in mind that this is an experimental feature and its behavior is very sensitive to the prompt provided. 110 | </details> 111 | 112 | <details> 113 | <summary>Rust server</summary> 114 | 115 | <a href="https://huggingface.co/kyutai/stt-2.6b-en-candle" target="_blank" style="margin: 2px;"> 116 | <img alt="Hugging Face" src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-blue" style="display: inline-block; vertical-align: middle;"/> 117 | </a> 118 | 119 | The Rust implementation provides a server that can process multiple streaming 120 | queries in parallel. Depending on the amount of memory on your GPU, you may 121 | have to adjust the batch size from the config file. For an L40S GPU, a batch size 122 | of 64 works well and requests can be processed at 3x real-time speed. 123 | 124 | In order to run the server, install the [moshi-server 125 | crate](https://crates.io/crates/moshi-server) via the following command. The 126 | server code can be found in the 127 | [kyutai-labs/moshi](https://github.com/kyutai-labs/moshi/tree/main/rust/moshi-server) 128 | repository. 129 | ```bash 130 | cargo install --features cuda moshi-server 131 | ``` 132 | 133 | Then the server can be started via the following command using the config file 134 | from this repository. 135 | For `kyutai/stt-1b-en_fr`, use `configs/config-stt-en_fr-hf.toml`, 136 | and for `kyutai/stt-2.6b-en`, use `configs/config-stt-en-hf.toml`. 137 | 138 | ```bash 139 | moshi-server worker --config configs/config-stt-en_fr-hf.toml 140 | ``` 141 | 142 | Once the server has started, you can transcribe audio from your microphone with the following script. 143 | ```bash 144 | uv run scripts/stt_from_mic_rust_server.py 145 | ``` 146 | 147 | We also provide a script for transcribing from an audio file. 148 | ```bash 149 | uv run scripts/stt_from_file_rust_server.py audio/bria.mp3 150 | ``` 151 | 152 | The script limits the decoding speed to simulate real-time processing of the audio. 153 | Faster processing can be triggered by setting 154 | the real-time factor, e.g. `--rtf 1000` will process 155 | the data as fast as possible.
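As noted above, the batch size is read from the config file. For reference, here is an abridged excerpt of `configs/config-stt-en_fr-hf.toml` from this repository (the English-only `configs/config-stt-en-hf.toml` ships with `batch_size = 16`); lower `batch_size` if the server runs out of GPU memory:

```toml
[modules.asr]
path = "/api/asr-streaming"
type = "BatchedAsr"
lm_model_file = "hf://kyutai/stt-1b-en_fr-candle/model.safetensors"
asr_delay_in_tokens = 6
batch_size = 64  # reduce this value to fit smaller GPUs
```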
156 | </details> 157 | 158 | <details> 159 | <summary>Rust standalone</summary> 160 | <a href="https://huggingface.co/kyutai/stt-2.6b-en-candle" target="_blank" style="margin: 2px;"> 161 | <img alt="Hugging Face" src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-blue" style="display: inline-block; vertical-align: middle;"/> 162 | </a> 163 | 164 | A standalone Rust example script is provided in the `stt-rs` directory in this repo. 165 | This can be used as follows: 166 | ```bash 167 | cd stt-rs 168 | cargo run --features cuda -r -- ../audio/bria.mp3 169 | ``` 170 | You can get the timestamps by adding the `--timestamps` flag, and see the output 171 | of the semantic VAD by adding the `--vad` flag. 172 | </details> 173 | 174 | <details> 175 | <summary>MLX implementation</summary> 176 | <a href="https://huggingface.co/kyutai/stt-2.6b-en-mlx" target="_blank" style="margin: 2px;"> 177 | <img alt="Hugging Face" src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-blue" style="display: inline-block; vertical-align: middle;"/> 178 | </a> 179 | 180 | [MLX](https://ml-explore.github.io/mlx/build/html/index.html) is Apple's ML framework that allows you to use 181 | hardware acceleration on Apple silicon. 182 | 183 | This requires the [moshi-mlx package](https://pypi.org/project/moshi-mlx/) 184 | with version 0.2.6 or later, which can be installed via pip. 185 | 186 | If you just want to run the model on a file, you can use `moshi_mlx.run_inference`: 187 | 188 | ```bash 189 | python -m moshi_mlx.run_inference --hf-repo kyutai/stt-2.6b-en-mlx audio/bria.mp3 --temp 0 190 | ``` 191 | 192 | If you have [uv](https://docs.astral.sh/uv/) installed, you can skip the installation step 193 | and just prefix the command above with `uvx --with moshi-mlx`. 194 | 195 | If you want to transcribe audio from your microphone, use: 196 | 197 | ```bash 198 | python scripts/stt_from_mic_mlx.py 199 | ``` 200 | 201 | The MLX models can also be used in Swift using the [moshi-swift 202 | codebase](https://github.com/kyutai-labs/moshi-swift); the 1B model has been 203 | tested to work fine on an iPhone 16 Pro. 204 | </details> 205 | 206 | ## Kyutai Text-to-Speech 207 | 208 | <a href="https://huggingface.co/collections/kyutai/text-to-speech-6866192e7e004ed04fd39e29" target="_blank" style="margin: 2px;"> 209 | <img alt="Hugging Face" src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-KyutaiTTS-blue" style="display: inline-block; vertical-align: middle;"/> 210 | </a> 211 | <a target="_blank" href="https://colab.research.google.com/github/kyutai-labs/delayed-streams-modeling/blob/main/tts_pytorch.ipynb"> 212 | <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/> 213 | </a> 214 | 215 | **More details can be found on the [project page](https://kyutai.org/next/tts).** 216 | 217 | We provide different implementations of Kyutai TTS for different use cases. Here is how to choose which one to use: 218 | 219 | - PyTorch: for research and tinkering. If you want to call the model from Python for research or experimentation, use our PyTorch implementation. 220 | - Rust: for production. If you want to serve Kyutai TTS in a production setting, use our Rust server. Our robust Rust server provides streaming access to the model over websockets. We use this server to run Unmute. 221 | - MLX: for on-device inference on iPhone and Mac. MLX is Apple's ML framework that allows you to use hardware acceleration on Apple silicon.
If you want to run the model on a Mac or an iPhone, choose the MLX implementation. 222 | 223 | <details> 224 | <summary>PyTorch implementation</summary> 225 | 226 | <a target="_blank" href="https://colab.research.google.com/github/kyutai-labs/delayed-streams-modeling/blob/main/tts_pytorch.ipynb"> 227 | <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/> 228 | </a> 229 | 230 | Check out our [Colab notebook](https://colab.research.google.com/github/kyutai-labs/delayed-streams-modeling/blob/main/tts_pytorch.ipynb) or use the script: 231 | 232 | ```bash 233 | # From stdin, plays audio immediately 234 | echo "Hey, how are you?" | python scripts/tts_pytorch.py - - 235 | 236 | # From text file to audio file 237 | python scripts/tts_pytorch.py text_to_say.txt audio_output.wav 238 | ``` 239 | 240 | This requires the [moshi package](https://pypi.org/project/moshi/), which can be installed via pip. 241 | If you have [uv](https://docs.astral.sh/uv/) installed, you can skip the installation step 242 | and just prefix the command above with `uvx --with moshi`. 243 | </details> 244 | 245 | <details> 246 | <summary>Rust server</summary> 247 | 248 | 249 | The Rust implementation provides a server that can process multiple streaming 250 | queries in parallel. 251 | 252 | Installing the Rust server is a bit tricky because it uses our Python implementation under the hood, 253 | which also requires installing the Python dependencies. 254 | Use the [start_tts.sh](https://github.com/kyutai-labs/unmute/blob/main/dockerless/start_tts.sh) script to properly install the Rust server. 255 | If you already installed the `moshi-server` crate before and it's not working, you might need to force a reinstall by running `cargo uninstall moshi-server` first. 256 | Feel free to open an issue if the installation is still broken. 257 | 258 | Once installed, the server can be started via the following command using the config file 259 | from this repository. 260 | 261 | ```bash 262 | moshi-server worker --config configs/config-tts.toml 263 | ``` 264 | 265 | Once the server has started, you can connect to it using our script as follows: 266 | ```bash 267 | # From stdin, plays audio immediately 268 | echo "Hey, how are you?" | python scripts/tts_rust_server.py - - 269 | 270 | # From text file to audio file 271 | python scripts/tts_rust_server.py text_to_say.txt audio_output.wav 272 | ``` 273 | </details> 274 | 275 | <details> 276 | <summary>MLX implementation</summary> 277 | 278 | [MLX](https://ml-explore.github.io/mlx/build/html/index.html) is Apple's ML framework that allows you to use 279 | hardware acceleration on Apple silicon. 280 | 281 | Use our example script to run Kyutai TTS on MLX. 282 | The script takes text from stdin or a file and can output to a file or stream the resulting audio. 283 | When streaming the output, if the model is not fast enough to keep up with 284 | real time, you can use the `--quantize 8` or `--quantize 4` flags to quantize 285 | the model, resulting in faster inference. 286 | 287 | ```bash 288 | # From stdin, plays audio immediately 289 | echo "Hey, how are you?" | python scripts/tts_mlx.py - - --quantize 8 290 | 291 | # From text file to audio file 292 | python scripts/tts_mlx.py text_to_say.txt audio_output.wav 293 | ``` 294 | 295 | This requires the [moshi-mlx package](https://pypi.org/project/moshi-mlx/), which can be installed via pip.
296 | If you have [uv](https://docs.astral.sh/uv/) installed, you can skip the installation step 297 | and just prefix the command above with `uvx --with moshi-mlx`. 298 | </details> 299 | 300 | ## License 301 | 302 | The present code is provided under the MIT license for the Python parts, and the Apache license for the Rust backend. 303 | The web client code is provided under the MIT license. 304 | Note that parts of this code are based on [AudioCraft](https://github.com/facebookresearch/audiocraft), released under 305 | the MIT license. 306 | 307 | The weights for the speech-to-text models are released under the CC-BY 4.0 license. 308 | 309 | ## Developing 310 | 311 | Install the [pre-commit hooks](https://pre-commit.com/) by running: 312 | 313 | ```bash 314 | pip install pre-commit 315 | pre-commit install 316 | ``` 317 | 318 | If you're using `uv`, you can replace the two commands with `uvx pre-commit install`. 319 | -------------------------------------------------------------------------------- /audio/bria.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kyutai-labs/delayed-streams-modeling/baf0c75bba89608e921cb26e03c959981df2ad5f/audio/bria.mp3 -------------------------------------------------------------------------------- /audio/loona.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kyutai-labs/delayed-streams-modeling/baf0c75bba89608e921cb26e03c959981df2ad5f/audio/loona.mp3 -------------------------------------------------------------------------------- /audio/sample_fr_hibiki_crepes.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kyutai-labs/delayed-streams-modeling/baf0c75bba89608e921cb26e03c959981df2ad5f/audio/sample_fr_hibiki_crepes.mp3 -------------------------------------------------------------------------------- /configs/config-stt-en-hf.toml: -------------------------------------------------------------------------------- 1 | static_dir = "./static/" 2 | log_dir = "$HOME/tmp/tts-logs" 3 | instance_name = "tts" 4 | authorized_ids = ["public_token"] 5 | 6 | [modules.asr] 7 | path = "/api/asr-streaming" 8 | type = "BatchedAsr" 9 | lm_model_file = "hf://kyutai/stt-2.6b-en-candle/model.safetensors" 10 | text_tokenizer_file = "hf://kyutai/stt-2.6b-en-candle/tokenizer_en_audio_4000.model" 11 | audio_tokenizer_file = "hf://kyutai/stt-2.6b-en-candle/mimi-pytorch-e351c8d8@125.safetensors" 12 | asr_delay_in_tokens = 32 13 | batch_size = 16 14 | conditioning_learnt_padding = true 15 | temperature = 0 16 | 17 | [modules.asr.model] 18 | audio_vocab_size = 2049 19 | text_in_vocab_size = 4001 20 | text_out_vocab_size = 4000 21 | audio_codebooks = 32 22 | 23 | [modules.asr.model.transformer] 24 | d_model = 2048 25 | num_heads = 32 26 | num_layers = 48 27 | dim_feedforward = 8192 28 | causal = true 29 | norm_first = true 30 | bias_ff = false 31 | bias_attn = false 32 | context = 375 33 | max_period = 100000 34 | use_conv_block = false 35 | use_conv_bias = true 36 | gating = "silu" 37 | norm = "RmsNorm" 38 | positional_embedding = "Rope" 39 | conv_layout = false 40 | conv_kernel_size = 3 41 | kv_repeat = 1 42 | max_seq_len = 40960 43 | -------------------------------------------------------------------------------- /configs/config-stt-en_fr-hf.toml: -------------------------------------------------------------------------------- 1 | static_dir = "./static/" 2 | log_dir =
"$HOME/tmp/tts-logs" 3 | instance_name = "tts" 4 | authorized_ids = ["public_token"] 5 | 6 | [modules.asr] 7 | path = "/api/asr-streaming" 8 | type = "BatchedAsr" 9 | lm_model_file = "hf://kyutai/stt-1b-en_fr-candle/model.safetensors" 10 | text_tokenizer_file = "hf://kyutai/stt-1b-en_fr-candle/tokenizer_en_fr_audio_8000.model" 11 | audio_tokenizer_file = "hf://kyutai/stt-1b-en_fr-candle/mimi-pytorch-e351c8d8@125.safetensors" 12 | asr_delay_in_tokens = 6 13 | batch_size = 64 14 | conditioning_learnt_padding = true 15 | temperature = 0.0 16 | 17 | [modules.asr.model] 18 | audio_vocab_size = 2049 19 | text_in_vocab_size = 8001 20 | text_out_vocab_size = 8000 21 | audio_codebooks = 32 22 | 23 | [modules.asr.model.transformer] 24 | d_model = 2048 25 | num_heads = 16 26 | num_layers = 16 27 | dim_feedforward = 8192 28 | causal = true 29 | norm_first = true 30 | bias_ff = false 31 | bias_attn = false 32 | context = 750 33 | max_period = 100000 34 | use_conv_block = false 35 | use_conv_bias = true 36 | gating = "silu" 37 | norm = "RmsNorm" 38 | positional_embedding = "Rope" 39 | conv_layout = false 40 | conv_kernel_size = 3 41 | kv_repeat = 1 42 | max_seq_len = 40960 43 | 44 | [modules.asr.model.extra_heads] 45 | num_heads = 4 46 | dim = 6 47 | -------------------------------------------------------------------------------- /configs/config-tts.toml: -------------------------------------------------------------------------------- 1 | static_dir = "./static/" 2 | log_dir = "$HOME/tmp/tts-logs" 3 | instance_name = "tts" 4 | authorized_ids = ["public_token"] 5 | 6 | [modules.tts_py] 7 | type = "Py" 8 | path = "/api/tts_streaming" 9 | text_tokenizer_file = "hf://kyutai/tts-1.6b-en_fr/tokenizer_spm_8k_en_fr_audio.model" 10 | batch_size = 8 # Adjust to your GPU memory capacity 11 | text_bos_token = 1 12 | 13 | [modules.tts_py.py] 14 | log_folder = "$HOME/tmp/moshi-server-logs" 15 | voice_folder = "hf-snapshot://kyutai/tts-voices/**/*.safetensors" 16 | default_voice = "unmute-prod-website/default_voice.wav" 17 | cfg_coef = 2.0 18 | cfg_is_no_text = true 19 | padding_between = 1 20 | n_q = 24 21 | -------------------------------------------------------------------------------- /scripts/stt_evaluate_on_dataset.py: -------------------------------------------------------------------------------- 1 | # /// script 2 | # requires-python = ">=3.12" 3 | # dependencies = [ 4 | # "datasets", 5 | # "jiwer==3.1.0", 6 | # "julius", 7 | # "librosa", 8 | # "moshi", 9 | # "openai-whisper", 10 | # "soundfile", 11 | # ] 12 | # /// 13 | """ 14 | Example implementation of the streaming STT example. Here we group 15 | test utterances in batches (pre- and post-padded with silence) and 16 | and then feed these batches into the streaming STT model frame-by-frame. 
17 | """ 18 | 19 | # The outputs I get on my H100 using this code with the 2.6B model, 20 | # bsz 32: 21 | 22 | # LibriVox === cer: 4.09% wer: 7.33% corpus_wer: 6.78% RTF = 52.72 23 | # Ami === cer: 15.99% wer: 18.78% corpus_wer: 12.20% RTF = 28.37 24 | # LibriSpeech other === cer: 2.31% wer: 5.24% corpus_wer: 4.33% RTF = 44.76 25 | # LibriSpeech clean === cer: 0.67% wer: 1.95% corpus_wer: 1.69% RTF = 68.19 26 | # Tedlium (short) === cer: 2.15% wer: 3.65% corpus_wer: 3.33% RTF = 67.44 27 | # spgispeech === cer: 0.99% wer: 2.00% corpus_wer: 2.03% RTF = 78.64 28 | # gigaspeech === cer: 6.80% wer: 11.31% corpus_wer: 9.81% RTF = 64.04 29 | # earnings22 (short) === cer: 12.63% wer: 15.70% corpus_wer: 11.02% RTF = 50.13 30 | 31 | # Meanwhile === cer: 2.02% wer: 5.50% corpus_wer: 5.60% RTF = 69.19 32 | # Tedlium (long) == cer: 1.53% wer: 2.56% corpus_wer: 2.97% RTF = 33.92 33 | # Rev16 === cer: 6.57% wer: 10.08% corpus_wer: 11.43% RTF = 40.34 34 | # Earnings21 === cer: 5.73% wer: 9.84% corpus_wer: 10.38% RTF = 73.15 35 | 36 | import argparse 37 | import dataclasses 38 | import time 39 | 40 | import jiwer 41 | import julius 42 | import moshi.models 43 | import torch 44 | import tqdm 45 | from datasets import Dataset, load_dataset 46 | from whisper.normalizers import EnglishTextNormalizer 47 | 48 | _NORMALIZER = EnglishTextNormalizer() 49 | 50 | 51 | def get_text(sample): 52 | possible_keys = [ 53 | "text", 54 | "sentence", 55 | "normalized_text", 56 | "transcript", 57 | "transcription", 58 | ] 59 | for key in possible_keys: 60 | if key in sample: 61 | return sample[key] 62 | raise ValueError( 63 | f"Expected transcript column of either {possible_keys}." 64 | f"Got sample with keys: {', '.join(sample.keys())}. Ensure a text column name is present in the dataset." 
65 | ) 66 | 67 | 68 | # The two functions below are adapted from https://github.com/huggingface/open_asr_leaderboard/blob/main/normalizer/data_utils.py 69 | 70 | 71 | def normalize(batch): 72 | batch["original_text"] = get_text(batch) 73 | batch["norm_text"] = _NORMALIZER(batch["original_text"]) 74 | return batch 75 | 76 | 77 | def is_target_text_in_range(ref): 78 | if ref.strip() == "ignore time segment in scoring": 79 | return False 80 | else: 81 | return ref.strip() != "" 82 | 83 | 84 | # End of the adapted part 85 | 86 | 87 | class AsrMetrics: 88 | def __init__(self): 89 | self.cer_sum = 0.0 90 | self.wer_sum = 0.0 91 | self.errors_sum = 0.0 92 | self.total_words_sum = 0.0 93 | self.num_sequences = 0.0 94 | 95 | def update(self, hyp: str, ref: str) -> None: 96 | normalized_ref = _NORMALIZER(ref) 97 | normalized_hyp = _NORMALIZER(hyp) 98 | 99 | this_wer = jiwer.wer(normalized_ref, normalized_hyp) 100 | this_cer = jiwer.cer(normalized_ref, normalized_hyp) 101 | measures = jiwer.compute_measures(normalized_ref, normalized_hyp) 102 | 103 | self.wer_sum += this_wer 104 | self.cer_sum += this_cer 105 | self.errors_sum += ( 106 | measures["substitutions"] + measures["deletions"] + measures["insertions"] 107 | ) 108 | self.total_words_sum += ( 109 | measures["substitutions"] + measures["deletions"] + measures["hits"] 110 | ) 111 | self.num_sequences += 1 112 | 113 | def compute(self) -> dict: 114 | assert self.num_sequences > 0, ( 115 | "Unable to compute with total number of comparisons <= 0" 116 | ) # type: ignore 117 | return { 118 | "cer": (self.cer_sum / self.num_sequences), 119 | "wer": (self.wer_sum / self.num_sequences), 120 | "corpus_wer": (self.errors_sum / self.total_words_sum), 121 | } 122 | 123 | def __str__(self) -> str: 124 | result = self.compute() 125 | return " ".join(f"{k}: {100 * v:.2f}%" for k, v in result.items()) 126 | 127 | 128 | class Timer: 129 | def __init__(self): 130 | self.total = 0 131 | self._start_time = None 132 | 133 | def __enter__(self): 134 | self._start_time = time.perf_counter() 135 | return self 136 | 137 | def __exit__(self, *_): 138 | self.total += time.perf_counter() - self._start_time 139 | self._start_time = None 140 | 141 | 142 | @dataclasses.dataclass 143 | class _DatasetInfo: 144 | alias: str 145 | 146 | name: str 147 | config: str 148 | split: str = "test" 149 | 150 | 151 | _DATASETS = [ 152 | # Long-form datasets from distil-whisper 153 | _DatasetInfo("rev16", "distil-whisper/rev16", "whisper_subset"), 154 | _DatasetInfo("earnings21", "distil-whisper/earnings21", "full"), 155 | _DatasetInfo("earnings22", "distil-whisper/earnings22", "full"), 156 | _DatasetInfo("tedlium", "distil-whisper/tedlium-long-form", None), 157 | _DatasetInfo("meanwhile", "distil-whisper/meanwhile", None), 158 | # Short-form datasets from OpenASR leaderboard 159 | _DatasetInfo("ami", "hf-audio/esb-datasets-test-only-sorted", "ami"), 160 | _DatasetInfo( 161 | "librispeech.clean", 162 | "hf-audio/esb-datasets-test-only-sorted", 163 | "librispeech", 164 | split="test.clean", 165 | ), 166 | _DatasetInfo( 167 | "librispeech.other", 168 | "hf-audio/esb-datasets-test-only-sorted", 169 | "librispeech", 170 | split="test.other", 171 | ), 172 | _DatasetInfo("voxpopuli", "hf-audio/esb-datasets-test-only-sorted", "voxpopuli"), 173 | _DatasetInfo("spgispeech", "hf-audio/esb-datasets-test-only-sorted", "spgispeech"), 174 | _DatasetInfo("gigaspeech", "hf-audio/esb-datasets-test-only-sorted", "gigaspeech"), 175 | _DatasetInfo("tedlium-short", "hf-audio/esb-datasets-test-only-sorted", 
"tedlium"), 176 | _DatasetInfo( 177 | "earnings22-short", "hf-audio/esb-datasets-test-only-sorted", "earnings22" 178 | ), 179 | ] 180 | DATASET_MAP = {dataset.alias: dataset for dataset in _DATASETS} 181 | 182 | 183 | def get_dataset(args) -> Dataset: 184 | if args.dataset not in DATASET_MAP: 185 | raise RuntimeError(f"Unknown dataset: {args.dataset}") 186 | 187 | info = DATASET_MAP[args.dataset] 188 | 189 | dataset = load_dataset( 190 | info.name, 191 | info.config, 192 | split=info.split, 193 | cache_dir=args.hf_cache_dir, 194 | streaming=False, 195 | token=True, 196 | ) 197 | dataset = dataset.map(normalize) 198 | dataset = dataset.filter(is_target_text_in_range, input_columns=["norm_text"]) 199 | 200 | return dataset 201 | 202 | 203 | @torch.no_grad 204 | def get_padded_batch( 205 | audios: list[tuple[torch.Tensor, int]], 206 | before_padding: float, 207 | after_padding: float, 208 | audio_encoder, 209 | ): 210 | sample_rate = audio_encoder.sample_rate 211 | 212 | max_len = 0 213 | batch = [] 214 | durations = [] 215 | for audio, sr in audios: 216 | durations.append(audio.shape[-1] / sr) 217 | audio = julius.resample_frac(audio, int(sr), int(sample_rate)) 218 | audio = torch.nn.functional.pad( 219 | audio, (int(before_padding * sample_rate), int(after_padding * sample_rate)) 220 | ) 221 | max_len = max(max_len, audio.shape[-1]) 222 | batch.append(audio) 223 | 224 | target = max_len 225 | if target % audio_encoder.frame_size != 0: 226 | target = target + ( 227 | audio_encoder.frame_size - max_len % audio_encoder.frame_size 228 | ) 229 | padded_batch = torch.stack( 230 | [ 231 | torch.nn.functional.pad(audio, (0, target - audio.shape[-1])) 232 | for audio in batch 233 | ] 234 | ) 235 | return padded_batch 236 | 237 | 238 | @torch.no_grad 239 | def streaming_transcribe( 240 | padded_batch: torch.Tensor, 241 | mimi, 242 | lm_gen, 243 | ): 244 | bsz = padded_batch.shape[0] 245 | 246 | text_tokens_acc = [] 247 | 248 | with mimi.streaming(bsz), lm_gen.streaming(bsz): 249 | for offset in range(0, padded_batch.shape[-1], mimi.frame_size): 250 | audio_chunk = padded_batch[:, offset : offset + mimi.frame_size] 251 | audio_chunk = audio_chunk[:, None, :] 252 | 253 | audio_tokens = mimi.encode(audio_chunk) 254 | text_tokens = lm_gen.step(audio_tokens) 255 | if text_tokens is not None: 256 | text_tokens_acc.append(text_tokens) 257 | 258 | return torch.concat(text_tokens_acc, axis=-1) 259 | 260 | 261 | def run_inference( 262 | dataset, 263 | mimi, 264 | lm_gen, 265 | tokenizer, 266 | padding_token_id, 267 | before_padding_sec, 268 | after_padding_sec, 269 | ): 270 | metrics = AsrMetrics() 271 | audio_time = 0.0 272 | inference_timer = Timer() 273 | 274 | for batch in tqdm.tqdm(dataset.iter(args.batch_size)): 275 | audio_data = list( 276 | zip( 277 | [torch.tensor(x["array"]).float() for x in batch["audio"]], 278 | [x["sampling_rate"] for x in batch["audio"]], 279 | ) 280 | ) 281 | 282 | audio_time += sum(audio.shape[-1] / sr for (audio, sr) in audio_data) 283 | 284 | gt_transcripts = batch["original_text"] 285 | 286 | padded_batch = get_padded_batch( 287 | audio_data, 288 | before_padding=before_padding_sec, 289 | after_padding=after_padding_sec, 290 | audio_encoder=mimi, 291 | ) 292 | padded_batch = padded_batch.cuda() 293 | 294 | with inference_timer: 295 | text_tokens = streaming_transcribe( 296 | padded_batch, 297 | mimi=mimi, 298 | lm_gen=lm_gen, 299 | ) 300 | 301 | for batch_index in range(text_tokens.shape[0]): 302 | utterance_tokens = text_tokens[batch_index, ...] 
303 | utterance_tokens = utterance_tokens[utterance_tokens > padding_token_id] 304 | text = tokenizer.decode(utterance_tokens.cpu().numpy().tolist()) 305 | metrics.update(hyp=text, ref=gt_transcripts[batch_index]) 306 | 307 | return metrics, inference_timer.total, audio_time 308 | 309 | 310 | def main(args): 311 | torch.set_float32_matmul_precision("high") 312 | 313 | info = moshi.models.loaders.CheckpointInfo.from_hf_repo( 314 | args.hf_repo, 315 | moshi_weights=args.moshi_weight, 316 | mimi_weights=args.mimi_weight, 317 | tokenizer=args.tokenizer, 318 | config_path=args.config_path, 319 | ) 320 | 321 | mimi = info.get_mimi(device=args.device) 322 | tokenizer = info.get_text_tokenizer() 323 | lm = info.get_moshi( 324 | device=args.device, 325 | dtype=torch.bfloat16, 326 | ) 327 | lm_gen = moshi.models.LMGen(lm, temp=0, temp_text=0.0) 328 | dataset = get_dataset(args) 329 | 330 | padding_token_id = info.raw_config.get("text_padding_token_id", 3) 331 | # Putting in some conservative defaults 332 | audio_silence_prefix_seconds = info.stt_config.get( 333 | "audio_silence_prefix_seconds", 1.0 334 | ) 335 | audio_delay_seconds = info.stt_config.get("audio_delay_seconds", 5.0) 336 | 337 | wer_metric, inference_time, audio_time = run_inference( 338 | dataset, 339 | mimi, 340 | lm_gen, 341 | tokenizer, 342 | padding_token_id, 343 | audio_silence_prefix_seconds, 344 | audio_delay_seconds + 0.5, 345 | ) 346 | 347 | print(wer_metric, f"RTF = {audio_time / inference_time:.2f}") 348 | 349 | 350 | if __name__ == "__main__": 351 | parser = argparse.ArgumentParser(description="Example streaming STT inference.") 352 | parser.add_argument( 353 | "--dataset", 354 | required=True, 355 | choices=DATASET_MAP.keys(), 356 | help="Dataset to run inference on.", 357 | ) 358 | 359 | parser.add_argument( 360 | "--hf-repo", type=str, help="HF repo to load the STT model from." 361 | ) 362 | parser.add_argument("--tokenizer", type=str, help="Path to a local tokenizer file.") 363 | parser.add_argument( 364 | "--moshi-weight", type=str, help="Path to a local checkpoint file." 365 | ) 366 | parser.add_argument( 367 | "--mimi-weight", type=str, help="Path to a local checkpoint file for Mimi." 
368 | ) 369 | parser.add_argument( 370 | "--config-path", type=str, help="Path to a local config file.", default=None 371 | ) 372 | parser.add_argument( 373 | "--batch-size", 374 | type=int, 375 | help="Batch size.", 376 | default=32, 377 | ) 378 | parser.add_argument( 379 | "--device", 380 | type=str, 381 | default="cuda", 382 | help="Device on which to run, defaults to 'cuda'.", 383 | ) 384 | parser.add_argument("--hf-cache-dir", type=str, help="HuggingFace cache folder.") 385 | args = parser.parse_args() 386 | 387 | main(args) 388 | -------------------------------------------------------------------------------- /scripts/stt_from_file_mlx.py: -------------------------------------------------------------------------------- 1 | # /// script 2 | # requires-python = ">=3.12" 3 | # dependencies = [ 4 | # "huggingface_hub", 5 | # "moshi_mlx==0.2.10", 6 | # "numpy", 7 | # "sentencepiece", 8 | # "sounddevice", 9 | # "sphn", 10 | # ] 11 | # /// 12 | 13 | import argparse 14 | import json 15 | 16 | import mlx.core as mx 17 | import mlx.nn as nn 18 | import sentencepiece 19 | import sphn 20 | from huggingface_hub import hf_hub_download 21 | from moshi_mlx import models, utils 22 | 23 | if __name__ == "__main__": 24 | parser = argparse.ArgumentParser() 25 | parser.add_argument("in_file", help="The file to transcribe.") 26 | parser.add_argument("--max-steps", default=4096) 27 | parser.add_argument("--hf-repo", default="kyutai/stt-1b-en_fr-mlx") 28 | parser.add_argument( 29 | "--vad", action="store_true", help="Enable VAD (Voice Activity Detection)." 30 | ) 31 | args = parser.parse_args() 32 | 33 | audio, _ = sphn.read(args.in_file, sample_rate=24000) 34 | lm_config = hf_hub_download(args.hf_repo, "config.json") 35 | with open(lm_config, "r") as fobj: 36 | lm_config = json.load(fobj) 37 | mimi_weights = hf_hub_download(args.hf_repo, lm_config["mimi_name"]) 38 | moshi_name = lm_config.get("moshi_name", "model.safetensors") 39 | moshi_weights = hf_hub_download(args.hf_repo, moshi_name) 40 | text_tokenizer = hf_hub_download(args.hf_repo, lm_config["tokenizer_name"]) 41 | 42 | lm_config = models.LmConfig.from_config_dict(lm_config) 43 | model = models.Lm(lm_config) 44 | model.set_dtype(mx.bfloat16) 45 | if moshi_weights.endswith(".q4.safetensors"): 46 | nn.quantize(model, bits=4, group_size=32) 47 | elif moshi_weights.endswith(".q8.safetensors"): 48 | nn.quantize(model, bits=8, group_size=64) 49 | 50 | print(f"loading model weights from {moshi_weights}") 51 | if args.hf_repo.endswith("-candle"): 52 | model.load_pytorch_weights(moshi_weights, lm_config, strict=True) 53 | else: 54 | model.load_weights(moshi_weights, strict=True) 55 | 56 | print(f"loading the text tokenizer from {text_tokenizer}") 57 | text_tokenizer = sentencepiece.SentencePieceProcessor(text_tokenizer) # type: ignore 58 | 59 | print(f"loading the audio tokenizer {mimi_weights}") 60 | audio_tokenizer = models.mimi.Mimi(models.mimi_202407(32)) 61 | audio_tokenizer.load_pytorch_weights(str(mimi_weights), strict=True) 62 | print("warming up the model") 63 | model.warmup() 64 | gen = models.LmGen( 65 | model=model, 66 | max_steps=args.max_steps, 67 | text_sampler=utils.Sampler(top_k=25, temp=0), 68 | audio_sampler=utils.Sampler(top_k=250, temp=0.8), 69 | check=False, 70 | ) 71 | 72 | print(f"starting inference {audio.shape}") 73 | audio = mx.concat([mx.array(audio), mx.zeros((1, 48000))], axis=-1) 74 | last_print_was_vad = False 75 | for start_idx in range(0, audio.shape[-1] // 1920 * 1920, 1920): 76 | block = audio[:, None, start_idx : 
start_idx + 1920] 77 | other_audio_tokens = audio_tokenizer.encode_step(block).transpose(0, 2, 1) 78 | if args.vad: 79 | text_token, vad_heads = gen.step_with_extra_heads(other_audio_tokens[0]) 80 | if vad_heads: 81 | pr_vad = vad_heads[2][0, 0, 0].item() 82 | if pr_vad > 0.5 and not last_print_was_vad: 83 | print(" [end of turn detected]") 84 | last_print_was_vad = True 85 | else: 86 | text_token = gen.step(other_audio_tokens[0]) 87 | text_token = text_token[0].item() 88 | audio_tokens = gen.last_audio_tokens() 89 | _text = None 90 | if text_token not in (0, 3): 91 | _text = text_tokenizer.id_to_piece(text_token) # type: ignore 92 | _text = _text.replace("▁", " ") 93 | print(_text, end="", flush=True) 94 | last_print_was_vad = False 95 | print() 96 | -------------------------------------------------------------------------------- /scripts/stt_from_file_pytorch.py: -------------------------------------------------------------------------------- 1 | # /// script 2 | # requires-python = ">=3.12" 3 | # dependencies = [ 4 | # "julius", 5 | # "librosa", 6 | # "soundfile", 7 | # "moshi==0.2.9", 8 | # ] 9 | # /// 10 | 11 | """An example script that illustrates how one can get per-word timestamps from 12 | Kyutai STT models. 13 | """ 14 | 15 | import argparse 16 | import dataclasses 17 | import itertools 18 | import math 19 | 20 | import julius 21 | import moshi.models 22 | import sphn 23 | import time 24 | import torch 25 | 26 | 27 | @dataclasses.dataclass 28 | class TimestampedText: 29 | text: str 30 | timestamp: tuple[float, float] 31 | 32 | def __str__(self): 33 | return f"{self.text} ({self.timestamp[0]:.2f}:{self.timestamp[1]:.2f})" 34 | 35 | 36 | def tokens_to_timestamped_text( 37 | text_tokens, 38 | tokenizer, 39 | frame_rate, 40 | end_of_padding_id, 41 | padding_token_id, 42 | offset_seconds, 43 | ) -> list[TimestampedText]: 44 | text_tokens = text_tokens.cpu().view(-1) 45 | 46 | # Normally `end_of_padding` tokens indicate word boundaries. 47 | # Everything between them should be a single word; 48 | # the time offset of the those tokens correspond to word start and 49 | # end timestamps (minus silence prefix and audio delay). 50 | # 51 | # However, in rare cases some complexities could arise. Firstly, 52 | # for words that are said quickly but are represented with 53 | # multiple tokens, the boundary might be omitted. Secondly, 54 | # for the very last word the end boundary might not happen. 55 | # Below is a code snippet that handles those situations a bit 56 | # more carefully. 57 | 58 | sequence_timestamps = [] 59 | 60 | def _tstmp(start_position, end_position): 61 | return ( 62 | max(0, start_position / frame_rate - offset_seconds), 63 | max(0, end_position / frame_rate - offset_seconds), 64 | ) 65 | 66 | def _decode(t): 67 | t = t[t > padding_token_id] 68 | return tokenizer.decode(t.numpy().tolist()) 69 | 70 | def _decode_segment(start, end): 71 | nonlocal text_tokens 72 | nonlocal sequence_timestamps 73 | 74 | text = _decode(text_tokens[start:end]) 75 | words_inside_segment = text.split() 76 | 77 | if len(words_inside_segment) == 0: 78 | return 79 | if len(words_inside_segment) == 1: 80 | # Single word within the boundaries, the general case 81 | sequence_timestamps.append( 82 | TimestampedText(text=text, timestamp=_tstmp(start, end)) 83 | ) 84 | else: 85 | # We're in a rare situation where multiple words are so close they are not separated by `end_of_padding`. 86 | # We tokenize words one-by-one; each word is assigned with as many frames as much tokens it has. 
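            # Illustrative (hypothetical) numbers: if the segment covers token
            # positions 100..106 and decodes to "thank you", with "thank"
            # encoding to 2 tokens, then "thank" is assigned positions 100..102
            # and "you" takes the remainder, 102..106, before both are mapped
            # to seconds by `_tstmp`.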
87 | for adjacent_word in words_inside_segment[:-1]: 88 | n_tokens = len(tokenizer.encode(adjacent_word)) 89 | sequence_timestamps.append( 90 | TimestampedText( 91 | text=adjacent_word, timestamp=_tstmp(start, start + n_tokens) 92 | ) 93 | ) 94 | start += n_tokens 95 | 96 | # The last word takes everything until the boundary 97 | adjacent_word = words_inside_segment[-1] 98 | sequence_timestamps.append( 99 | TimestampedText(text=adjacent_word, timestamp=_tstmp(start, end)) 100 | ) 101 | 102 | (segment_boundaries,) = torch.where(text_tokens == end_of_padding_id) 103 | 104 | if not segment_boundaries.numel(): 105 | return [] 106 | 107 | for i in range(len(segment_boundaries) - 1): 108 | segment_start = int(segment_boundaries[i]) + 1 109 | segment_end = int(segment_boundaries[i + 1]) 110 | 111 | _decode_segment(segment_start, segment_end) 112 | 113 | last_segment_start = segment_boundaries[-1] + 1 114 | 115 | boundary_token = torch.tensor([tokenizer.eos_id()]) 116 | (end_of_last_segment,) = torch.where( 117 | torch.isin(text_tokens[last_segment_start:], boundary_token) 118 | ) 119 | 120 | if not end_of_last_segment.numel(): 121 | # upper-bound either end of the audio or 1 second duration, whicher is smaller 122 | last_segment_end = min(text_tokens.shape[-1], last_segment_start + frame_rate) 123 | else: 124 | last_segment_end = last_segment_start + end_of_last_segment[0] 125 | _decode_segment(last_segment_start, last_segment_end) 126 | 127 | return sequence_timestamps 128 | 129 | 130 | def main(args): 131 | info = moshi.models.loaders.CheckpointInfo.from_hf_repo( 132 | args.hf_repo, 133 | moshi_weights=args.moshi_weight, 134 | mimi_weights=args.mimi_weight, 135 | tokenizer=args.tokenizer, 136 | config_path=args.config_path, 137 | ) 138 | 139 | mimi = info.get_mimi(device=args.device) 140 | tokenizer = info.get_text_tokenizer() 141 | lm = info.get_moshi( 142 | device=args.device, 143 | dtype=torch.bfloat16, 144 | ) 145 | lm_gen = moshi.models.LMGen(lm, temp=0, temp_text=0.0) 146 | 147 | audio_silence_prefix_seconds = info.stt_config.get( 148 | "audio_silence_prefix_seconds", 1.0 149 | ) 150 | audio_delay_seconds = info.stt_config.get("audio_delay_seconds", 5.0) 151 | padding_token_id = info.raw_config.get("text_padding_token_id", 3) 152 | 153 | audio, input_sample_rate = sphn.read(args.in_file) 154 | audio = torch.from_numpy(audio).to(args.device) 155 | audio = julius.resample_frac(audio, input_sample_rate, mimi.sample_rate) 156 | if audio.shape[-1] % mimi.frame_size != 0: 157 | to_pad = mimi.frame_size - audio.shape[-1] % mimi.frame_size 158 | audio = torch.nn.functional.pad(audio, (0, to_pad)) 159 | 160 | text_tokens_accum = [] 161 | 162 | n_prefix_chunks = math.ceil(audio_silence_prefix_seconds * mimi.frame_rate) 163 | n_suffix_chunks = math.ceil(audio_delay_seconds * mimi.frame_rate) 164 | silence_chunk = torch.zeros( 165 | (1, 1, mimi.frame_size), dtype=torch.float32, device=args.device 166 | ) 167 | 168 | chunks = itertools.chain( 169 | itertools.repeat(silence_chunk, n_prefix_chunks), 170 | torch.split(audio[:, None], mimi.frame_size, dim=-1), 171 | itertools.repeat(silence_chunk, n_suffix_chunks), 172 | ) 173 | 174 | start_time = time.time() 175 | nchunks = 0 176 | last_print_was_vad = False 177 | with mimi.streaming(1), lm_gen.streaming(1): 178 | for audio_chunk in chunks: 179 | nchunks += 1 180 | audio_tokens = mimi.encode(audio_chunk) 181 | if args.vad: 182 | text_tokens, vad_heads = lm_gen.step_with_extra_heads(audio_tokens) 183 | if vad_heads: 184 | pr_vad = vad_heads[2][0, 0, 
0].cpu().item() 185 | if pr_vad > 0.5 and not last_print_was_vad: 186 | print(" [end of turn detected]") 187 | last_print_was_vad = True 188 | else: 189 | text_tokens = lm_gen.step(audio_tokens) 190 | text_token = text_tokens[0, 0, 0].cpu().item() 191 | if text_token not in (0, 3): 192 | _text = tokenizer.id_to_piece(text_tokens[0, 0, 0].cpu().item()) # type: ignore 193 | _text = _text.replace("▁", " ") 194 | print(_text, end="", flush=True) 195 | last_print_was_vad = False 196 | text_tokens_accum.append(text_tokens) 197 | 198 | utterance_tokens = torch.concat(text_tokens_accum, dim=-1) 199 | dt = time.time() - start_time 200 | print( 201 | f"\nprocessed {nchunks} chunks in {dt:.2f} seconds, steps per second: {nchunks / dt:.2f}" 202 | ) 203 | timed_text = tokens_to_timestamped_text( 204 | utterance_tokens, 205 | tokenizer, 206 | mimi.frame_rate, 207 | end_of_padding_id=0, 208 | padding_token_id=padding_token_id, 209 | offset_seconds=int(n_prefix_chunks / mimi.frame_rate) + audio_delay_seconds, 210 | ) 211 | 212 | decoded = " ".join([str(t) for t in timed_text]) 213 | print(decoded) 214 | 215 | 216 | if __name__ == "__main__": 217 | parser = argparse.ArgumentParser(description="Example streaming STT w/ timestamps.") 218 | parser.add_argument("in_file", help="The file to transcribe.") 219 | 220 | parser.add_argument( 221 | "--hf-repo", type=str, help="HF repo to load the STT model from. " 222 | ) 223 | parser.add_argument("--tokenizer", type=str, help="Path to a local tokenizer file.") 224 | parser.add_argument( 225 | "--moshi-weight", type=str, help="Path to a local checkpoint file." 226 | ) 227 | parser.add_argument( 228 | "--mimi-weight", type=str, help="Path to a local checkpoint file for Mimi." 229 | ) 230 | parser.add_argument( 231 | "--config-path", type=str, help="Path to a local config file.", default=None 232 | ) 233 | parser.add_argument( 234 | "--vad", action="store_true", help="Enable VAD (Voice Activity Detection)." 235 | ) 236 | parser.add_argument( 237 | "--device", 238 | type=str, 239 | default="cuda", 240 | help="Device on which to run, defaults to 'cuda'.", 241 | ) 242 | args = parser.parse_args() 243 | 244 | main(args) 245 | -------------------------------------------------------------------------------- /scripts/stt_from_file_rust_server.py: -------------------------------------------------------------------------------- 1 | # /// script 2 | # requires-python = ">=3.12" 3 | # dependencies = [ 4 | # "msgpack", 5 | # "numpy", 6 | # "sphn", 7 | # "websockets", 8 | # ] 9 | # /// 10 | import argparse 11 | import asyncio 12 | import time 13 | 14 | import msgpack 15 | import numpy as np 16 | import sphn 17 | import websockets 18 | 19 | SAMPLE_RATE = 24000 20 | FRAME_SIZE = 1920 # Send data in chunks 21 | 22 | 23 | def load_and_process_audio(file_path): 24 | """Load an MP3 file, resample to 24kHz, convert to mono, and extract PCM float32 data.""" 25 | pcm_data, _ = sphn.read(file_path, sample_rate=SAMPLE_RATE) 26 | return pcm_data[0] 27 | 28 | 29 | async def receive_messages(websocket): 30 | transcript = [] 31 | 32 | async for message in websocket: 33 | data = msgpack.unpackb(message, raw=False) 34 | if data["type"] == "Step": 35 | # This message contains the signal from the semantic VAD, and tells us how 36 | # much audio the server has already processed. We don't use either here. 
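            # (If you do need the VAD signal, the Step payload exposes per-head
            # pause probabilities under data["prs"]; stt_from_mic_rust_server.py
            # shows how to use them.)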
37 | continue 38 | if data["type"] == "Word": 39 | print(data["text"], end=" ", flush=True) 40 | transcript.append( 41 | { 42 | "text": data["text"], 43 | "timestamp": [data["start_time"], data["start_time"]], 44 | } 45 | ) 46 | if data["type"] == "EndWord": 47 | if len(transcript) > 0: 48 | transcript[-1]["timestamp"][1] = data["stop_time"] 49 | if data["type"] == "Marker": 50 | # Received marker, stopping stream 51 | break 52 | 53 | return transcript 54 | 55 | 56 | async def send_messages(websocket, rtf: float): 57 | audio_data = load_and_process_audio(args.in_file) 58 | 59 | async def send_audio(audio: np.ndarray): 60 | await websocket.send( 61 | msgpack.packb( 62 | {"type": "Audio", "pcm": [float(x) for x in audio]}, 63 | use_single_float=True, 64 | ) 65 | ) 66 | 67 | # Start with a second of silence. 68 | # This is needed for the 2.6B model for technical reasons. 69 | await send_audio([0.0] * SAMPLE_RATE) 70 | 71 | start_time = time.time() 72 | for i in range(0, len(audio_data), FRAME_SIZE): 73 | await send_audio(audio_data[i : i + FRAME_SIZE]) 74 | 75 | expected_send_time = start_time + (i + 1) / SAMPLE_RATE / rtf 76 | current_time = time.time() 77 | if current_time < expected_send_time: 78 | await asyncio.sleep(expected_send_time - current_time) 79 | else: 80 | await asyncio.sleep(0.001) 81 | 82 | for _ in range(5): 83 | await send_audio([0.0] * SAMPLE_RATE) 84 | 85 | # Send a marker to indicate the end of the stream. 86 | await websocket.send( 87 | msgpack.packb({"type": "Marker", "id": 0}, use_single_float=True) 88 | ) 89 | 90 | # We'll get back the marker once the corresponding audio has been transcribed, 91 | # accounting for the delay of the model. That's why we need to send some silence 92 | # after the marker, because the model will not return the marker immediately. 
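    # The exact amount of trailing silence is not critical as long as it
    # comfortably exceeds that delay; receive_messages() stops as soon as the
    # Marker message comes back.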
93 | for _ in range(35): 94 | await send_audio([0.0] * SAMPLE_RATE) 95 | 96 | 97 | async def stream_audio(url: str, api_key: str, rtf: float): 98 | """Stream audio data to a WebSocket server.""" 99 | headers = {"kyutai-api-key": api_key} 100 | 101 | # Instead of using the header, you can authenticate by adding `?auth_id={api_key}` to the URL 102 | async with websockets.connect(url, additional_headers=headers) as websocket: 103 | send_task = asyncio.create_task(send_messages(websocket, rtf)) 104 | receive_task = asyncio.create_task(receive_messages(websocket)) 105 | _, transcript = await asyncio.gather(send_task, receive_task) 106 | 107 | return transcript 108 | 109 | 110 | if __name__ == "__main__": 111 | parser = argparse.ArgumentParser() 112 | parser.add_argument("in_file") 113 | parser.add_argument( 114 | "--url", 115 | help="The url of the server to which to send the audio", 116 | default="ws://127.0.0.1:8080", 117 | ) 118 | parser.add_argument("--api-key", default="public_token") 119 | parser.add_argument( 120 | "--rtf", 121 | type=float, 122 | default=1.01, 123 | help="The real-time factor of how fast to feed in the audio.", 124 | ) 125 | args = parser.parse_args() 126 | 127 | url = f"{args.url}/api/asr-streaming" 128 | transcript = asyncio.run(stream_audio(url, args.api_key, args.rtf)) 129 | 130 | print() 131 | print() 132 | for word in transcript: 133 | print( 134 | f"{word['timestamp'][0]:7.2f} -{word['timestamp'][1]:7.2f} {word['text']}" 135 | ) 136 | -------------------------------------------------------------------------------- /scripts/stt_from_file_with_prompt_pytorch.py: -------------------------------------------------------------------------------- 1 | """An example script that illustrates how one can prompt Kyutai STT models.""" 2 | 3 | import argparse 4 | import itertools 5 | import math 6 | from collections import deque 7 | 8 | import julius 9 | import moshi.models 10 | import sphn 11 | import torch 12 | import tqdm 13 | 14 | 15 | class PromptHook: 16 | def __init__(self, tokenizer, prefix, padding_tokens=(0, 3)): 17 | self.tokenizer = tokenizer 18 | self.prefix_enforce = deque(self.tokenizer.encode(prefix)) 19 | self.padding_tokens = padding_tokens 20 | 21 | def on_token(self, token): 22 | if not self.prefix_enforce: 23 | return 24 | 25 | token = token.item() 26 | 27 | if token in self.padding_tokens: 28 | pass 29 | elif token == self.prefix_enforce[0]: 30 | self.prefix_enforce.popleft() 31 | else: 32 | assert False 33 | 34 | def on_logits(self, logits): 35 | if not self.prefix_enforce: 36 | return 37 | 38 | mask = torch.zeros_like(logits, dtype=torch.bool) 39 | for t in self.padding_tokens: 40 | mask[..., t] = True 41 | mask[..., self.prefix_enforce[0]] = True 42 | 43 | logits[:] = torch.where(mask, logits, float("-inf")) 44 | 45 | 46 | def main(args): 47 | info = moshi.models.loaders.CheckpointInfo.from_hf_repo( 48 | args.hf_repo, 49 | moshi_weights=args.moshi_weight, 50 | mimi_weights=args.mimi_weight, 51 | tokenizer=args.tokenizer, 52 | config_path=args.config_path, 53 | ) 54 | 55 | mimi = info.get_mimi(device=args.device) 56 | tokenizer = info.get_text_tokenizer() 57 | lm = info.get_moshi( 58 | device=args.device, 59 | dtype=torch.bfloat16, 60 | ) 61 | 62 | if args.prompt_text: 63 | prompt_hook = PromptHook(tokenizer, args.prompt_text) 64 | lm_gen = moshi.models.LMGen( 65 | lm, 66 | temp=0, 67 | temp_text=0.0, 68 | on_text_hook=prompt_hook.on_token, 69 | on_text_logits_hook=prompt_hook.on_logits, 70 | ) 71 | else: 72 | lm_gen = moshi.models.LMGen(lm, temp=0, 
temp_text=0.0) 73 | 74 | audio_silence_prefix_seconds = info.stt_config.get( 75 | "audio_silence_prefix_seconds", 1.0 76 | ) 77 | audio_delay_seconds = info.stt_config.get("audio_delay_seconds", 5.0) 78 | padding_token_id = info.raw_config.get("text_padding_token_id", 3) 79 | 80 | def _load_and_process(path): 81 | audio, input_sample_rate = sphn.read(path) 82 | audio = torch.from_numpy(audio).to(args.device).mean(axis=0, keepdim=True) 83 | audio = julius.resample_frac(audio, input_sample_rate, mimi.sample_rate) 84 | if audio.shape[-1] % mimi.frame_size != 0: 85 | to_pad = mimi.frame_size - audio.shape[-1] % mimi.frame_size 86 | audio = torch.nn.functional.pad(audio, (0, to_pad)) 87 | return audio 88 | 89 | n_prefix_chunks = math.ceil(audio_silence_prefix_seconds * mimi.frame_rate) 90 | n_suffix_chunks = math.ceil(audio_delay_seconds * mimi.frame_rate) 91 | silence_chunk = torch.zeros( 92 | (1, 1, mimi.frame_size), dtype=torch.float32, device=args.device 93 | ) 94 | 95 | audio = _load_and_process(args.file) 96 | if args.prompt_file: 97 | audio_prompt = _load_and_process(args.prompt_file) 98 | else: 99 | audio_prompt = None 100 | 101 | chain = [itertools.repeat(silence_chunk, n_prefix_chunks)] 102 | 103 | if audio_prompt is not None: 104 | chain.append(torch.split(audio_prompt[:, None, :], mimi.frame_size, dim=-1)) 105 | # adding a bit (0.8s) of silence to separate prompt and the actual audio 106 | chain.append(itertools.repeat(silence_chunk, 10)) 107 | 108 | chain += [ 109 | torch.split(audio[:, None, :], mimi.frame_size, dim=-1), 110 | itertools.repeat(silence_chunk, n_suffix_chunks), 111 | ] 112 | 113 | chunks = itertools.chain(*chain) 114 | 115 | text_tokens_accum = [] 116 | with mimi.streaming(1), lm_gen.streaming(1): 117 | for audio_chunk in tqdm.tqdm(chunks): 118 | audio_tokens = mimi.encode(audio_chunk) 119 | text_tokens = lm_gen.step(audio_tokens) 120 | if text_tokens is not None: 121 | text_tokens_accum.append(text_tokens) 122 | 123 | utterance_tokens = torch.concat(text_tokens_accum, dim=-1) 124 | text_tokens = utterance_tokens.cpu().view(-1) 125 | 126 | # if we have an audio prompt and we don't want to have it in the transcript, 127 | # we should cut the corresponding number of frames from the output tokens. 128 | # However, there is also some amount of padding that happens before it 129 | # due to silence_prefix and audio_delay. Normally it is ignored in detokenization, 130 | # but now we should account for it to find the position of the prompt transcript. 
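    # Hypothetical example with the default config (12.5 Hz frame rate, 1 s
    # silence prefix, 5 s audio delay): a 4 s prompt gives prompt_frames = 50
    # and no_prompt_offset = int(6.0 * 12.5) = 75, so the first 125 text
    # tokens are dropped below before decoding.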
131 | if args.cut_prompt_transcript and audio_prompt is not None: 132 | prompt_frames = audio_prompt.shape[1] // mimi.frame_size 133 | no_prompt_offset_seconds = audio_delay_seconds + audio_silence_prefix_seconds 134 | no_prompt_offset = int(no_prompt_offset_seconds * mimi.frame_rate) 135 | text_tokens = text_tokens[prompt_frames + no_prompt_offset :] 136 | 137 | text = tokenizer.decode( 138 | text_tokens[text_tokens > padding_token_id].numpy().tolist() 139 | ) 140 | 141 | print(text) 142 | 143 | 144 | if __name__ == "__main__": 145 | parser = argparse.ArgumentParser(description="Example streaming STT w/ a prompt.") 146 | parser.add_argument( 147 | "--file", 148 | required=True, 149 | help="File to transcribe.", 150 | ) 151 | parser.add_argument( 152 | "--prompt_file", 153 | required=False, 154 | help="Audio of the prompt.", 155 | ) 156 | parser.add_argument( 157 | "--prompt_text", 158 | required=False, 159 | help="Text of the prompt.", 160 | ) 161 | parser.add_argument( 162 | "--cut-prompt-transcript", 163 | action="store_true", 164 | help="Cut the prompt from the output transcript", 165 | ) 166 | parser.add_argument( 167 | "--hf-repo", type=str, help="HF repo to load the STT model from. " 168 | ) 169 | parser.add_argument("--tokenizer", type=str, help="Path to a local tokenizer file.") 170 | parser.add_argument( 171 | "--moshi-weight", type=str, help="Path to a local checkpoint file." 172 | ) 173 | parser.add_argument( 174 | "--mimi-weight", type=str, help="Path to a local checkpoint file for Mimi." 175 | ) 176 | parser.add_argument( 177 | "--config-path", type=str, help="Path to a local config file.", default=None 178 | ) 179 | parser.add_argument( 180 | "--device", 181 | type=str, 182 | default="cuda", 183 | help="Device on which to run, defaults to 'cuda'.", 184 | ) 185 | args = parser.parse_args() 186 | 187 | main(args) 188 | -------------------------------------------------------------------------------- /scripts/stt_from_mic_mlx.py: -------------------------------------------------------------------------------- 1 | # /// script 2 | # requires-python = ">=3.12" 3 | # dependencies = [ 4 | # "huggingface_hub", 5 | # "moshi_mlx==0.2.10", 6 | # "numpy", 7 | # "rustymimi", 8 | # "sentencepiece", 9 | # "sounddevice", 10 | # ] 11 | # /// 12 | 13 | import argparse 14 | import json 15 | import queue 16 | 17 | import mlx.core as mx 18 | import mlx.nn as nn 19 | import rustymimi 20 | import sentencepiece 21 | import sounddevice as sd 22 | from huggingface_hub import hf_hub_download 23 | from moshi_mlx import models, utils 24 | 25 | if __name__ == "__main__": 26 | parser = argparse.ArgumentParser() 27 | parser.add_argument("--max-steps", default=4096) 28 | parser.add_argument("--hf-repo", default="kyutai/stt-1b-en_fr-mlx") 29 | parser.add_argument( 30 | "--vad", action="store_true", help="Enable VAD (Voice Activity Detection)." 
31 | ) 32 | args = parser.parse_args() 33 | 34 | lm_config = hf_hub_download(args.hf_repo, "config.json") 35 | with open(lm_config, "r") as fobj: 36 | lm_config = json.load(fobj) 37 | mimi_weights = hf_hub_download(args.hf_repo, lm_config["mimi_name"]) 38 | moshi_name = lm_config.get("moshi_name", "model.safetensors") 39 | moshi_weights = hf_hub_download(args.hf_repo, moshi_name) 40 | tokenizer = hf_hub_download(args.hf_repo, lm_config["tokenizer_name"]) 41 | 42 | lm_config = models.LmConfig.from_config_dict(lm_config) 43 | model = models.Lm(lm_config) 44 | model.set_dtype(mx.bfloat16) 45 | if moshi_weights.endswith(".q4.safetensors"): 46 | nn.quantize(model, bits=4, group_size=32) 47 | elif moshi_weights.endswith(".q8.safetensors"): 48 | nn.quantize(model, bits=8, group_size=64) 49 | 50 | print(f"loading model weights from {moshi_weights}") 51 | if args.hf_repo.endswith("-candle"): 52 | model.load_pytorch_weights(moshi_weights, lm_config, strict=True) 53 | else: 54 | model.load_weights(moshi_weights, strict=True) 55 | 56 | print(f"loading the text tokenizer from {tokenizer}") 57 | text_tokenizer = sentencepiece.SentencePieceProcessor(tokenizer) # type: ignore 58 | 59 | print(f"loading the audio tokenizer {mimi_weights}") 60 | generated_codebooks = lm_config.generated_codebooks 61 | other_codebooks = lm_config.other_codebooks 62 | mimi_codebooks = max(generated_codebooks, other_codebooks) 63 | audio_tokenizer = rustymimi.Tokenizer(mimi_weights, num_codebooks=mimi_codebooks) # type: ignore 64 | print("warming up the model") 65 | model.warmup() 66 | gen = models.LmGen( 67 | model=model, 68 | max_steps=args.max_steps, 69 | text_sampler=utils.Sampler(top_k=25, temp=0), 70 | audio_sampler=utils.Sampler(top_k=250, temp=0.8), 71 | check=False, 72 | ) 73 | 74 | block_queue = queue.Queue() 75 | 76 | def audio_callback(indata, _frames, _time, _status): 77 | block_queue.put(indata.copy()) 78 | 79 | print("recording audio from microphone, speak to get your words transcribed") 80 | last_print_was_vad = False 81 | with sd.InputStream( 82 | channels=1, 83 | dtype="float32", 84 | samplerate=24000, 85 | blocksize=1920, 86 | callback=audio_callback, 87 | ): 88 | while True: 89 | block = block_queue.get() 90 | block = block[None, :, 0] 91 | other_audio_tokens = audio_tokenizer.encode_step(block[None, 0:1]) 92 | other_audio_tokens = mx.array(other_audio_tokens).transpose(0, 2, 1)[ 93 | :, :, :other_codebooks 94 | ] 95 | if args.vad: 96 | text_token, vad_heads = gen.step_with_extra_heads(other_audio_tokens[0]) 97 | if vad_heads: 98 | pr_vad = vad_heads[2][0, 0, 0].item() 99 | if pr_vad > 0.5 and not last_print_was_vad: 100 | print(" [end of turn detected]") 101 | last_print_was_vad = True 102 | else: 103 | text_token = gen.step(other_audio_tokens[0]) 104 | text_token = text_token[0].item() 105 | audio_tokens = gen.last_audio_tokens() 106 | _text = None 107 | if text_token not in (0, 3): 108 | _text = text_tokenizer.id_to_piece(text_token) # type: ignore 109 | _text = _text.replace("▁", " ") 110 | print(_text, end="", flush=True) 111 | last_print_was_vad = False 112 | -------------------------------------------------------------------------------- /scripts/stt_from_mic_rust_server.py: -------------------------------------------------------------------------------- 1 | # /// script 2 | # requires-python = ">=3.12" 3 | # dependencies = [ 4 | # "msgpack", 5 | # "numpy", 6 | # "sounddevice", 7 | # "websockets", 8 | # ] 9 | # /// 10 | import argparse 11 | import asyncio 12 | import signal 13 | 14 | import msgpack 15 
| import numpy as np 16 | import sounddevice as sd 17 | import websockets 18 | 19 | SAMPLE_RATE = 24000 20 | 21 | # The VAD has several prediction heads, each of which tries to determine whether there 22 | # has been a pause of a given length. The lengths are 0.5, 1.0, 2.0, and 3.0 seconds. 23 | # Lower indices predict pauses more aggressively. In Unmute, we use 2.0 seconds = index 2. 24 | PAUSE_PREDICTION_HEAD_INDEX = 2 25 | 26 | 27 | async def receive_messages(websocket, show_vad: bool = False): 28 | """Receive and process messages from the WebSocket server.""" 29 | try: 30 | speech_started = False 31 | async for message in websocket: 32 | data = msgpack.unpackb(message, raw=False) 33 | 34 | # The Step message only gets sent if the model has semantic VAD available 35 | if data["type"] == "Step" and show_vad: 36 | pause_prediction = data["prs"][PAUSE_PREDICTION_HEAD_INDEX] 37 | if pause_prediction > 0.5 and speech_started: 38 | print("| ", end="", flush=True) 39 | speech_started = False 40 | 41 | elif data["type"] == "Word": 42 | print(data["text"], end=" ", flush=True) 43 | speech_started = True 44 | except websockets.ConnectionClosed: 45 | print("Connection closed while receiving messages.") 46 | 47 | 48 | async def send_messages(websocket, audio_queue): 49 | """Send audio data from microphone to WebSocket server.""" 50 | try: 51 | # Start by draining the queue to avoid lags 52 | while not audio_queue.empty(): 53 | await audio_queue.get() 54 | 55 | print("Starting the transcription") 56 | 57 | while True: 58 | audio_data = await audio_queue.get() 59 | chunk = {"type": "Audio", "pcm": [float(x) for x in audio_data]} 60 | msg = msgpack.packb(chunk, use_bin_type=True, use_single_float=True) 61 | await websocket.send(msg) 62 | 63 | except websockets.ConnectionClosed: 64 | print("Connection closed while sending messages.") 65 | 66 | 67 | async def stream_audio(url: str, api_key: str, show_vad: bool): 68 | """Stream audio data to a WebSocket server.""" 69 | print("Starting microphone recording...") 70 | print("Press Ctrl+C to stop recording") 71 | audio_queue = asyncio.Queue() 72 | 73 | loop = asyncio.get_event_loop() 74 | 75 | def audio_callback(indata, frames, time, status): 76 | loop.call_soon_threadsafe( 77 | audio_queue.put_nowait, indata[:, 0].astype(np.float32).copy() 78 | ) 79 | 80 | # Start audio stream 81 | with sd.InputStream( 82 | samplerate=SAMPLE_RATE, 83 | channels=1, 84 | dtype="float32", 85 | callback=audio_callback, 86 | blocksize=1920, # 80ms blocks 87 | ): 88 | headers = {"kyutai-api-key": api_key} 89 | # Instead of using the header, you can authenticate by adding `?auth_id={api_key}` to the URL 90 | async with websockets.connect(url, additional_headers=headers) as websocket: 91 | send_task = asyncio.create_task(send_messages(websocket, audio_queue)) 92 | receive_task = asyncio.create_task( 93 | receive_messages(websocket, show_vad=show_vad) 94 | ) 95 | await asyncio.gather(send_task, receive_task) 96 | 97 | 98 | if __name__ == "__main__": 99 | parser = argparse.ArgumentParser(description="Real-time microphone transcription") 100 | parser.add_argument( 101 | "--url", 102 | help="The URL of the server to which to send the audio", 103 | default="ws://127.0.0.1:8080", 104 | ) 105 | parser.add_argument("--api-key", default="public_token") 106 | parser.add_argument( 107 | "--list-devices", action="store_true", help="List available audio devices" 108 | ) 109 | parser.add_argument( 110 | "--device", type=int, help="Input device ID (use --list-devices to see options)" 111 | ) 112 
| parser.add_argument( 113 | "--show-vad", 114 | action="store_true", 115 | help="Visualize the predictions of the semantic voice activity detector with a '|' symbol", 116 | ) 117 | 118 | args = parser.parse_args() 119 | 120 | def handle_sigint(signum, frame): 121 | print("Interrupted by user") # Don't complain about KeyboardInterrupt 122 | exit(0) 123 | 124 | signal.signal(signal.SIGINT, handle_sigint) 125 | 126 | if args.list_devices: 127 | print("Available audio devices:") 128 | print(sd.query_devices()) 129 | exit(0) 130 | 131 | if args.device is not None: 132 | sd.default.device[0] = args.device # Set input device 133 | 134 | url = f"{args.url}/api/asr-streaming" 135 | asyncio.run(stream_audio(url, args.api_key, args.show_vad)) 136 | -------------------------------------------------------------------------------- /scripts/tts_mlx.py: -------------------------------------------------------------------------------- 1 | # /// script 2 | # requires-python = ">=3.12" 3 | # dependencies = [ 4 | # "huggingface_hub", 5 | # "moshi_mlx==0.2.9", 6 | # "numpy", 7 | # "sounddevice", 8 | # ] 9 | # /// 10 | 11 | import argparse 12 | import json 13 | import queue 14 | import sys 15 | import time 16 | 17 | import mlx.core as mx 18 | import mlx.nn as nn 19 | import numpy as np 20 | import sentencepiece 21 | import sounddevice as sd 22 | import sphn 23 | from moshi_mlx import models 24 | from moshi_mlx.client_utils import make_log 25 | from moshi_mlx.models.tts import ( 26 | DEFAULT_DSM_TTS_REPO, 27 | DEFAULT_DSM_TTS_VOICE_REPO, 28 | TTSModel, 29 | ) 30 | from moshi_mlx.utils.loaders import hf_get 31 | 32 | 33 | def log(level: str, msg: str): 34 | print(make_log(level, msg)) 35 | 36 | 37 | def main(): 38 | parser = argparse.ArgumentParser( 39 | description="Run Kyutai TTS using the PyTorch implementation" 40 | ) 41 | parser.add_argument("inp", type=str, help="Input file, use - for stdin") 42 | parser.add_argument( 43 | "out", type=str, help="Output file to generate, use - for playing the audio" 44 | ) 45 | parser.add_argument( 46 | "--hf-repo", 47 | type=str, 48 | default=DEFAULT_DSM_TTS_REPO, 49 | help="HF repo in which to look for the pretrained models.", 50 | ) 51 | parser.add_argument( 52 | "--voice-repo", 53 | default=DEFAULT_DSM_TTS_VOICE_REPO, 54 | help="HF repo in which to look for pre-computed voice embeddings.", 55 | ) 56 | parser.add_argument( 57 | "--voice", default="expresso/ex03-ex01_happy_001_channel1_334s.wav" 58 | ) 59 | parser.add_argument( 60 | "--quantize", 61 | type=int, 62 | help="The quantization to be applied, e.g. 
8 for 8 bits.", 63 | ) 64 | args = parser.parse_args() 65 | 66 | mx.random.seed(299792458) 67 | 68 | log("info", "retrieving checkpoints") 69 | 70 | raw_config = hf_get("config.json", args.hf_repo) 71 | with open(hf_get(raw_config), "r") as fobj: 72 | raw_config = json.load(fobj) 73 | 74 | mimi_weights = hf_get(raw_config["mimi_name"], args.hf_repo) 75 | moshi_name = raw_config.get("moshi_name", "model.safetensors") 76 | moshi_weights = hf_get(moshi_name, args.hf_repo) 77 | tokenizer = hf_get(raw_config["tokenizer_name"], args.hf_repo) 78 | lm_config = models.LmConfig.from_config_dict(raw_config) 79 | model = models.Lm(lm_config) 80 | model.set_dtype(mx.bfloat16) 81 | 82 | log("info", f"loading model weights from {moshi_weights}") 83 | model.load_pytorch_weights(str(moshi_weights), lm_config, strict=True) 84 | 85 | if args.quantize is not None: 86 | log("info", f"quantizing model to {args.quantize} bits") 87 | nn.quantize(model.depformer, bits=args.quantize) 88 | for layer in model.transformer.layers: 89 | nn.quantize(layer.self_attn, bits=args.quantize) 90 | nn.quantize(layer.gating, bits=args.quantize) 91 | 92 | log("info", f"loading the text tokenizer from {tokenizer}") 93 | text_tokenizer = sentencepiece.SentencePieceProcessor(str(tokenizer)) # type: ignore 94 | 95 | log("info", f"loading the audio tokenizer {mimi_weights}") 96 | generated_codebooks = lm_config.generated_codebooks 97 | audio_tokenizer = models.mimi.Mimi(models.mimi_202407(generated_codebooks)) 98 | audio_tokenizer.load_pytorch_weights(str(mimi_weights), strict=True) 99 | 100 | cfg_coef_conditioning = None 101 | tts_model = TTSModel( 102 | model, 103 | audio_tokenizer, 104 | text_tokenizer, 105 | voice_repo=args.voice_repo, 106 | temp=0.6, 107 | cfg_coef=1, 108 | max_padding=8, 109 | initial_padding=2, 110 | final_padding=2, 111 | padding_bonus=0, 112 | raw_config=raw_config, 113 | ) 114 | if tts_model.valid_cfg_conditionings: 115 | # Model was trained with CFG distillation. 
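        # In that case the CFG coefficient is passed as a conditioning
        # attribute (via make_condition_attributes below) rather than by
        # running a second, unconditioned forward pass at inference time.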
116 | cfg_coef_conditioning = tts_model.cfg_coef 117 | tts_model.cfg_coef = 1.0 118 | cfg_is_no_text = False 119 | cfg_is_no_prefix = False 120 | else: 121 | cfg_is_no_text = True 122 | cfg_is_no_prefix = True 123 | mimi = tts_model.mimi 124 | 125 | log("info", f"reading input from {args.inp}") 126 | if args.inp == "-": 127 | if sys.stdin.isatty(): # Interactive 128 | print("Enter text to synthesize (Ctrl+D to end input):") 129 | text_to_tts = sys.stdin.read().strip() 130 | else: 131 | with open(args.inp, "r") as fobj: 132 | text_to_tts = fobj.read().strip() 133 | 134 | all_entries = [tts_model.prepare_script([text_to_tts])] 135 | if tts_model.multi_speaker: 136 | voices = [tts_model.get_voice_path(args.voice)] 137 | else: 138 | voices = [] 139 | all_attributes = [ 140 | tts_model.make_condition_attributes(voices, cfg_coef_conditioning) 141 | ] 142 | 143 | wav_frames = queue.Queue() 144 | 145 | def _on_frame(frame): 146 | if (frame == -1).any(): 147 | return 148 | _pcm = tts_model.mimi.decode_step(frame[:, :, None]) 149 | _pcm = np.array(mx.clip(_pcm[0, 0], -1, 1)) 150 | wav_frames.put_nowait(_pcm) 151 | 152 | def run(): 153 | log("info", "starting the inference loop") 154 | begin = time.time() 155 | result = tts_model.generate( 156 | all_entries, 157 | all_attributes, 158 | cfg_is_no_prefix=cfg_is_no_prefix, 159 | cfg_is_no_text=cfg_is_no_text, 160 | on_frame=_on_frame, 161 | ) 162 | frames = mx.concat(result.frames, axis=-1) 163 | total_duration = frames.shape[0] * frames.shape[-1] / mimi.frame_rate 164 | time_taken = time.time() - begin 165 | total_speed = total_duration / time_taken 166 | log("info", f"[LM] took {time_taken:.2f}s, total speed {total_speed:.2f}x") 167 | return result 168 | 169 | if args.out == "-": 170 | 171 | def audio_callback(outdata, _a, _b, _c): 172 | try: 173 | pcm_data = wav_frames.get(block=False) 174 | outdata[:, 0] = pcm_data 175 | except queue.Empty: 176 | outdata[:] = 0 177 | 178 | with sd.OutputStream( 179 | samplerate=mimi.sample_rate, 180 | blocksize=1920, 181 | channels=1, 182 | callback=audio_callback, 183 | ): 184 | run() 185 | time.sleep(3) 186 | while True: 187 | if wav_frames.qsize() == 0: 188 | break 189 | time.sleep(1) 190 | else: 191 | run() 192 | frames = [] 193 | while True: 194 | try: 195 | frames.append(wav_frames.get_nowait()) 196 | except queue.Empty: 197 | break 198 | wav = np.concat(frames, -1) 199 | sphn.write_wav(args.out, wav, mimi.sample_rate) 200 | 201 | 202 | if __name__ == "__main__": 203 | main() 204 | -------------------------------------------------------------------------------- /scripts/tts_pytorch.py: -------------------------------------------------------------------------------- 1 | # /// script 2 | # requires-python = ">=3.12" 3 | # dependencies = [ 4 | # "moshi==0.2.8", 5 | # "torch", 6 | # "sphn", 7 | # "sounddevice", 8 | # ] 9 | # /// 10 | import argparse 11 | import sys 12 | 13 | import numpy as np 14 | import queue 15 | import sphn 16 | import time 17 | import torch 18 | from moshi.models.loaders import CheckpointInfo 19 | from moshi.models.tts import DEFAULT_DSM_TTS_REPO, DEFAULT_DSM_TTS_VOICE_REPO, TTSModel 20 | 21 | 22 | def main(): 23 | parser = argparse.ArgumentParser( 24 | description="Run Kyutai TTS using the PyTorch implementation" 25 | ) 26 | parser.add_argument("inp", type=str, help="Input file, use - for stdin.") 27 | parser.add_argument( 28 | "out", type=str, help="Output file to generate, use - for playing the audio" 29 | ) 30 | parser.add_argument( 31 | "--hf-repo", 32 | type=str, 33 | 
default=DEFAULT_DSM_TTS_REPO, 34 | help="HF repo in which to look for the pretrained models.", 35 | ) 36 | parser.add_argument( 37 | "--voice-repo", 38 | default=DEFAULT_DSM_TTS_VOICE_REPO, 39 | help="HF repo in which to look for pre-computed voice embeddings.", 40 | ) 41 | parser.add_argument( 42 | "--voice", 43 | default="expresso/ex03-ex01_happy_001_channel1_334s.wav", 44 | help="The voice to use, relative to the voice repo root. " 45 | f"See {DEFAULT_DSM_TTS_VOICE_REPO}", 46 | ) 47 | parser.add_argument( 48 | "--device", 49 | type=str, 50 | default="cuda", 51 | help="Device on which to run, defaults to 'cuda'.", 52 | ) 53 | args = parser.parse_args() 54 | 55 | print("Loading model...") 56 | checkpoint_info = CheckpointInfo.from_hf_repo(args.hf_repo) 57 | tts_model = TTSModel.from_checkpoint_info( 58 | checkpoint_info, n_q=32, temp=0.6, device=args.device 59 | ) 60 | 61 | if args.inp == "-": 62 | if sys.stdin.isatty(): # Interactive 63 | print("Enter text to synthesize (Ctrl+D to end input):") 64 | text = sys.stdin.read().strip() 65 | else: 66 | with open(args.inp, "r") as fobj: 67 | text = fobj.read().strip() 68 | 69 | # If you want to make a dialog, you can pass more than one turn [text_speaker_1, text_speaker_2, text_2_speaker_1, ...] 70 | entries = tts_model.prepare_script([text], padding_between=1) 71 | voice_path = tts_model.get_voice_path(args.voice) 72 | # CFG coef goes here because the model was trained with CFG distillation, 73 | # so it's not _actually_ doing CFG at inference time. 74 | # Also, if you are generating a dialog, you should have two voices in the list. 75 | condition_attributes = tts_model.make_condition_attributes( 76 | [voice_path], cfg_coef=2.0 77 | ) 78 | 79 | if args.out == "-": 80 | # Stream the audio to the speakers using sounddevice. 
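        # Frames are decoded to PCM as soon as the model produces them
        # (_on_frame) and pushed onto a queue that the sounddevice output
        # callback drains, so playback starts before generation finishes.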
81 | import sounddevice as sd 82 | 83 | pcms = queue.Queue() 84 | 85 | def _on_frame(frame): 86 | if (frame != -1).all(): 87 | pcm = tts_model.mimi.decode(frame[:, 1:, :]).cpu().numpy() 88 | pcms.put_nowait(np.clip(pcm[0, 0], -1, 1)) 89 | 90 | def audio_callback(outdata, _a, _b, _c): 91 | try: 92 | pcm_data = pcms.get(block=False) 93 | outdata[:, 0] = pcm_data 94 | except queue.Empty: 95 | outdata[:] = 0 96 | 97 | with sd.OutputStream( 98 | samplerate=tts_model.mimi.sample_rate, 99 | blocksize=1920, 100 | channels=1, 101 | callback=audio_callback, 102 | ): 103 | with tts_model.mimi.streaming(1): 104 | tts_model.generate( 105 | [entries], [condition_attributes], on_frame=_on_frame 106 | ) 107 | time.sleep(3) 108 | while True: 109 | if pcms.qsize() == 0: 110 | break 111 | time.sleep(1) 112 | else: 113 | result = tts_model.generate([entries], [condition_attributes]) 114 | with tts_model.mimi.streaming(1), torch.no_grad(): 115 | pcms = [] 116 | for frame in result.frames[tts_model.delay_steps :]: 117 | pcm = tts_model.mimi.decode(frame[:, 1:, :]).cpu().numpy() 118 | pcms.append(np.clip(pcm[0, 0], -1, 1)) 119 | pcm = np.concatenate(pcms, axis=-1) 120 | sphn.write_wav(args.out, pcm, tts_model.mimi.sample_rate) 121 | 122 | 123 | if __name__ == "__main__": 124 | main() 125 | -------------------------------------------------------------------------------- /scripts/tts_rust_server.py: -------------------------------------------------------------------------------- 1 | # /// script 2 | # requires-python = ">=3.12" 3 | # dependencies = [ 4 | # "msgpack", 5 | # "numpy", 6 | # "sphn", 7 | # "websockets", 8 | # "sounddevice", 9 | # "tqdm", 10 | # ] 11 | # /// 12 | import argparse 13 | import asyncio 14 | import sys 15 | from urllib.parse import urlencode 16 | 17 | import msgpack 18 | import numpy as np 19 | import sounddevice as sd 20 | import sphn 21 | import tqdm 22 | import websockets 23 | 24 | SAMPLE_RATE = 24000 25 | 26 | TTS_TEXT = "Hello, this is a test of the moshi text to speech system, this should result in some nicely sounding generated voice." 
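# Note: the text that actually gets synthesized below comes from the `inp`
# argument (or stdin); TTS_TEXT above is only a sample sentence.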
27 | DEFAULT_DSM_TTS_VOICE_REPO = "kyutai/tts-voices" 28 | AUTH_TOKEN = "public_token" 29 | 30 | 31 | async def receive_messages(websocket: websockets.ClientConnection, output_queue): 32 | with tqdm.tqdm(desc="Receiving audio", unit=" seconds generated") as pbar: 33 | accumulated_samples = 0 34 | last_seconds = 0 35 | 36 | async for message_bytes in websocket: 37 | msg = msgpack.unpackb(message_bytes) 38 | 39 | if msg["type"] == "Audio": 40 | pcm = np.array(msg["pcm"]).astype(np.float32) 41 | await output_queue.put(pcm) 42 | 43 | accumulated_samples += len(msg["pcm"]) 44 | current_seconds = accumulated_samples // SAMPLE_RATE 45 | if current_seconds > last_seconds: 46 | pbar.update(current_seconds - last_seconds) 47 | last_seconds = current_seconds 48 | 49 | print("End of audio.") 50 | await output_queue.put(None) # Signal end of audio 51 | 52 | 53 | async def output_audio(out: str, output_queue: asyncio.Queue[np.ndarray | None]): 54 | if out == "-": 55 | should_exit = False 56 | 57 | def audio_callback(outdata, _a, _b, _c): 58 | nonlocal should_exit 59 | 60 | try: 61 | pcm_data = output_queue.get_nowait() 62 | if pcm_data is not None: 63 | outdata[:, 0] = pcm_data 64 | else: 65 | should_exit = True 66 | outdata[:] = 0 67 | except asyncio.QueueEmpty: 68 | outdata[:] = 0 69 | 70 | with sd.OutputStream( 71 | samplerate=SAMPLE_RATE, 72 | blocksize=1920, 73 | channels=1, 74 | callback=audio_callback, 75 | ): 76 | while True: 77 | if should_exit: 78 | break 79 | await asyncio.sleep(1) 80 | else: 81 | frames = [] 82 | while True: 83 | item = await output_queue.get() 84 | if item is None: 85 | break 86 | frames.append(item) 87 | 88 | sphn.write_wav(out, np.concat(frames, -1), SAMPLE_RATE) 89 | print(f"Saved audio to {out}") 90 | 91 | 92 | async def websocket_client(): 93 | parser = argparse.ArgumentParser(description="Use the TTS streaming API") 94 | parser.add_argument("inp", type=str, help="Input file, use - for stdin.") 95 | parser.add_argument( 96 | "out", type=str, help="Output file to generate, use - for playing the audio" 97 | ) 98 | parser.add_argument( 99 | "--voice", 100 | default="expresso/ex03-ex01_happy_001_channel1_334s.wav", 101 | help="The voice to use, relative to the voice repo root. 
" 102 | f"See {DEFAULT_DSM_TTS_VOICE_REPO}", 103 | ) 104 | parser.add_argument( 105 | "--url", 106 | help="The URL of the server to which to send the audio", 107 | default="ws://127.0.0.1:8080", 108 | ) 109 | parser.add_argument("--api-key", default="public_token") 110 | args = parser.parse_args() 111 | 112 | params = {"voice": args.voice, "format": "PcmMessagePack"} 113 | uri = f"{args.url}/api/tts_streaming?{urlencode(params)}" 114 | print(uri) 115 | 116 | # TODO: stream the text instead of sending it all at once 117 | if args.inp == "-": 118 | if sys.stdin.isatty(): # Interactive 119 | print("Enter text to synthesize (Ctrl+D to end input):") 120 | text_to_tts = sys.stdin.read().strip() 121 | else: 122 | with open(args.inp, "r") as fobj: 123 | text_to_tts = fobj.read().strip() 124 | 125 | headers = {"kyutai-api-key": args.api_key} 126 | 127 | async with websockets.connect(uri, additional_headers=headers) as websocket: 128 | await websocket.send(msgpack.packb({"type": "Text", "text": text_to_tts})) 129 | await websocket.send(msgpack.packb({"type": "Eos"})) 130 | 131 | output_queue = asyncio.Queue() 132 | receive_task = asyncio.create_task(receive_messages(websocket, output_queue)) 133 | output_audio_task = asyncio.create_task(output_audio(args.out, output_queue)) 134 | await asyncio.gather(receive_task, output_audio_task) 135 | 136 | 137 | if __name__ == "__main__": 138 | asyncio.run(websocket_client()) 139 | -------------------------------------------------------------------------------- /stt-rs/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "kyutai-stt-rs" 3 | version = "0.1.0" 4 | edition = "2024" 5 | 6 | [dependencies] 7 | anyhow = "1.0" 8 | candle = { version = "0.9.1", package = "candle-core" } 9 | candle-nn = "0.9.1" 10 | clap = { version = "4.4.12", features = ["derive"] } 11 | hf-hub = "0.4.3" 12 | kaudio = "0.2.1" 13 | moshi = "0.6.1" 14 | sentencepiece = "0.11.3" 15 | serde = { version = "1.0.210", features = ["derive"] } 16 | serde_json = "1.0.115" 17 | 18 | [features] 19 | default = [] 20 | cuda = ["candle/cuda", "candle-nn/cuda"] 21 | cudnn = ["candle/cudnn", "candle-nn/cudnn"] 22 | metal = ["candle/metal", "candle-nn/metal"] 23 | 24 | [profile.release] 25 | debug = true 26 | 27 | [profile.release-no-debug] 28 | inherits = "release" 29 | debug = false 30 | 31 | -------------------------------------------------------------------------------- /stt-rs/src/main.rs: -------------------------------------------------------------------------------- 1 | // Copyright (c) Kyutai, all rights reserved. 2 | // This source code is licensed under the license found in the 3 | // LICENSE file in the root directory of this source tree. 4 | 5 | use anyhow::Result; 6 | use candle::{Device, Tensor}; 7 | use clap::Parser; 8 | 9 | #[derive(Debug, Parser)] 10 | struct Args { 11 | /// The audio input file, in wav/mp3/ogg/... format. 12 | in_file: String, 13 | 14 | /// The repo where to get the model from. 15 | #[arg(long, default_value = "kyutai/stt-1b-en_fr-candle")] 16 | hf_repo: String, 17 | 18 | /// Run the model on cpu. 19 | #[arg(long)] 20 | cpu: bool, 21 | 22 | /// Display word level timestamps. 23 | #[arg(long)] 24 | timestamps: bool, 25 | 26 | /// Display the level of voice activity detection (VAD). 27 | #[arg(long)] 28 | vad: bool, 29 | } 30 | 31 | fn device(cpu: bool) -> Result<Device> { 32 | if cpu { 33 | Ok(Device::Cpu) 34 | } else if candle::utils::cuda_is_available() { 35 | Ok(Device::new_cuda(0)?) 
36 | } else if candle::utils::metal_is_available() { 37 | Ok(Device::new_metal(0)?) 38 | } else { 39 | Ok(Device::Cpu) 40 | } 41 | } 42 | 43 | #[derive(Debug, serde::Deserialize)] 44 | struct SttConfig { 45 | audio_silence_prefix_seconds: f64, 46 | audio_delay_seconds: f64, 47 | } 48 | 49 | #[derive(Debug, serde::Deserialize)] 50 | struct Config { 51 | mimi_name: String, 52 | tokenizer_name: String, 53 | card: usize, 54 | text_card: usize, 55 | dim: usize, 56 | n_q: usize, 57 | context: usize, 58 | max_period: f64, 59 | num_heads: usize, 60 | num_layers: usize, 61 | causal: bool, 62 | stt_config: SttConfig, 63 | } 64 | 65 | impl Config { 66 | fn model_config(&self, vad: bool) -> moshi::lm::Config { 67 | let lm_cfg = moshi::transformer::Config { 68 | d_model: self.dim, 69 | num_heads: self.num_heads, 70 | num_layers: self.num_layers, 71 | dim_feedforward: self.dim * 4, 72 | causal: self.causal, 73 | norm_first: true, 74 | bias_ff: false, 75 | bias_attn: false, 76 | layer_scale: None, 77 | context: self.context, 78 | max_period: self.max_period as usize, 79 | use_conv_block: false, 80 | use_conv_bias: true, 81 | cross_attention: None, 82 | gating: Some(candle_nn::Activation::Silu), 83 | norm: moshi::NormType::RmsNorm, 84 | positional_embedding: moshi::transformer::PositionalEmbedding::Rope, 85 | conv_layout: false, 86 | conv_kernel_size: 3, 87 | kv_repeat: 1, 88 | max_seq_len: 4096 * 4, 89 | shared_cross_attn: false, 90 | }; 91 | let extra_heads = if vad { 92 | Some(moshi::lm::ExtraHeadsConfig { 93 | num_heads: 4, 94 | dim: 6, 95 | }) 96 | } else { 97 | None 98 | }; 99 | moshi::lm::Config { 100 | transformer: lm_cfg, 101 | depformer: None, 102 | audio_vocab_size: self.card + 1, 103 | text_in_vocab_size: self.text_card + 1, 104 | text_out_vocab_size: self.text_card, 105 | audio_codebooks: self.n_q, 106 | conditioners: Default::default(), 107 | extra_heads, 108 | } 109 | } 110 | } 111 | 112 | struct Model { 113 | state: moshi::asr::State, 114 | text_tokenizer: sentencepiece::SentencePieceProcessor, 115 | timestamps: bool, 116 | vad: bool, 117 | config: Config, 118 | dev: Device, 119 | } 120 | 121 | impl Model { 122 | fn load_from_hf(args: &Args, dev: &Device) -> Result<Self> { 123 | let dtype = dev.bf16_default_to_f32(); 124 | 125 | // Retrieve the model files from the Hugging Face Hub 126 | let api = hf_hub::api::sync::Api::new()?; 127 | let repo = api.model(args.hf_repo.to_string()); 128 | let config_file = repo.get("config.json")?; 129 | let config: Config = serde_json::from_str(&std::fs::read_to_string(&config_file)?)?; 130 | let tokenizer_file = repo.get(&config.tokenizer_name)?; 131 | let model_file = repo.get("model.safetensors")?; 132 | let mimi_file = repo.get(&config.mimi_name)?; 133 | 134 | let text_tokenizer = sentencepiece::SentencePieceProcessor::open(&tokenizer_file)?; 135 | let vb_lm = 136 | unsafe { candle_nn::VarBuilder::from_mmaped_safetensors(&[&model_file], dtype, dev)? 
}; 137 | let audio_tokenizer = moshi::mimi::load(mimi_file.to_str().unwrap(), Some(32), dev)?; 138 | let lm = moshi::lm::LmModel::new( 139 | &config.model_config(args.vad), 140 | moshi::nn::MaybeQuantizedVarBuilder::Real(vb_lm), 141 | )?; 142 | let asr_delay_in_tokens = (config.stt_config.audio_delay_seconds * 12.5) as usize; 143 | let state = moshi::asr::State::new(1, asr_delay_in_tokens, 0., audio_tokenizer, lm)?; 144 | Ok(Model { 145 | state, 146 | config, 147 | text_tokenizer, 148 | timestamps: args.timestamps, 149 | vad: args.vad, 150 | dev: dev.clone(), 151 | }) 152 | } 153 | 154 | fn run(&mut self, mut pcm: Vec<f32>) -> Result<()> { 155 | use std::io::Write; 156 | 157 | // Add the silence prefix to the audio. 158 | if self.config.stt_config.audio_silence_prefix_seconds > 0.0 { 159 | let silence_len = 160 | (self.config.stt_config.audio_silence_prefix_seconds * 24000.0) as usize; 161 | pcm.splice(0..0, vec![0.0; silence_len]); 162 | } 163 | // Add some silence at the end to ensure all the audio is processed. 164 | let suffix = (self.config.stt_config.audio_delay_seconds * 24000.0) as usize; 165 | pcm.resize(pcm.len() + suffix + 24000, 0.0); 166 | 167 | let mut last_word = None; 168 | let mut printed_eot = false; 169 | for pcm in pcm.chunks(1920) { 170 | let pcm = Tensor::new(pcm, &self.dev)?.reshape((1, 1, ()))?; 171 | let asr_msgs = self.state.step_pcm(pcm, None, &().into(), |_, _, _| ())?; 172 | for asr_msg in asr_msgs.iter() { 173 | match asr_msg { 174 | moshi::asr::AsrMsg::Step { prs, .. } => { 175 | // prs is the probability of having no voice activity for different time 176 | // horizons. 177 | // In kyutai/stt-1b-en_fr-candle, these horizons are 0.5s, 1s, 2s, and 3s. 178 | if self.vad && prs[2][0] > 0.5 && !printed_eot { 179 | printed_eot = true; 180 | if !self.timestamps { 181 | print!(" <endofturn pr={}>", prs[2][0]); 182 | } else { 183 | println!("<endofturn pr={}>", prs[2][0]); 184 | } 185 | } 186 | } 187 | moshi::asr::AsrMsg::EndWord { stop_time, .. } => { 188 | printed_eot = false; 189 | if self.timestamps { 190 | if let Some((word, start_time)) = last_word.take() { 191 | println!("[{start_time:5.2}-{stop_time:5.2}] {word}"); 192 | } 193 | } 194 | } 195 | moshi::asr::AsrMsg::Word { 196 | tokens, start_time, .. 197 | } => { 198 | printed_eot = false; 199 | let word = self 200 | .text_tokenizer 201 | .decode_piece_ids(tokens) 202 | .unwrap_or_else(|_| String::new()); 203 | if !self.timestamps { 204 | print!(" {word}"); 205 | std::io::stdout().flush()? 206 | } else { 207 | if let Some((word, prev_start_time)) = last_word.take() { 208 | println!("[{prev_start_time:5.2}-{start_time:5.2}] {word}"); 209 | } 210 | last_word = Some((word, *start_time)); 211 | } 212 | } 213 | } 214 | } 215 | } 216 | if let Some((word, start_time)) = last_word.take() { 217 | println!("[{start_time:5.2}- ] {word}"); 218 | } 219 | println!(); 220 | Ok(()) 221 | } 222 | } 223 | 224 | fn main() -> Result<()> { 225 | let args = Args::parse(); 226 | let device = device(args.cpu)?; 227 | println!("Using device: {:?}", device); 228 | 229 | println!("Loading audio file from: {}", args.in_file); 230 | let (pcm, sample_rate) = kaudio::pcm_decode(&args.in_file)?; 231 | let pcm = if sample_rate != 24_000 { 232 | kaudio::resample(&pcm, sample_rate as usize, 24_000)? 
233 | } else { 234 | pcm 235 | }; 236 | println!("Loading model from repository: {}", args.hf_repo); 237 | let mut model = Model::load_from_hf(&args, &device)?; 238 | println!("Running inference"); 239 | model.run(pcm)?; 240 | Ok(()) 241 | } 242 | -------------------------------------------------------------------------------- /stt_pytorch.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": { 7 | "colab": { 8 | "base_uri": "https://localhost:8080/" 9 | }, 10 | "id": "gJEMjPgeI-rw", 11 | "outputId": "7491c067-b1be-4505-b3f5-19ba4c00a593" 12 | }, 13 | "outputs": [], 14 | "source": [ 15 | "!pip install moshi" 16 | ] 17 | }, 18 | { 19 | "cell_type": "code", 20 | "execution_count": null, 21 | "metadata": { 22 | "colab": { 23 | "base_uri": "https://localhost:8080/" 24 | }, 25 | "id": "CA4K5iDFJcqJ", 26 | "outputId": "b609843a-a193-4729-b099-5f8780532333" 27 | }, 28 | "outputs": [], 29 | "source": [ 30 | "!wget https://github.com/kyutai-labs/moshi/raw/refs/heads/main/data/sample_fr_hibiki_crepes.mp3" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": null, 36 | "metadata": { 37 | "id": "VA3Haix3IZ8Q" 38 | }, 39 | "outputs": [], 40 | "source": [ 41 | "from dataclasses import dataclass\n", 42 | "import time\n", 43 | "import sentencepiece\n", 44 | "import sphn\n", 45 | "import textwrap\n", 46 | "import torch\n", 47 | "\n", 48 | "from moshi.models import loaders, MimiModel, LMModel, LMGen" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": null, 54 | "metadata": { 55 | "id": "9AK5zBMTI9bw" 56 | }, 57 | "outputs": [], 58 | "source": [ 59 | "@dataclass\n", 60 | "class InferenceState:\n", 61 | " mimi: MimiModel\n", 62 | " text_tokenizer: sentencepiece.SentencePieceProcessor\n", 63 | " lm_gen: LMGen\n", 64 | "\n", 65 | " def __init__(\n", 66 | " self,\n", 67 | " mimi: MimiModel,\n", 68 | " text_tokenizer: sentencepiece.SentencePieceProcessor,\n", 69 | " lm: LMModel,\n", 70 | " batch_size: int,\n", 71 | " device: str | torch.device,\n", 72 | " ):\n", 73 | " self.mimi = mimi\n", 74 | " self.text_tokenizer = text_tokenizer\n", 75 | " self.lm_gen = LMGen(lm, temp=0, temp_text=0, use_sampling=False)\n", 76 | " self.device = device\n", 77 | " self.frame_size = int(self.mimi.sample_rate / self.mimi.frame_rate)\n", 78 | " self.batch_size = batch_size\n", 79 | " self.mimi.streaming_forever(batch_size)\n", 80 | " self.lm_gen.streaming_forever(batch_size)\n", 81 | "\n", 82 | " def run(self, in_pcms: torch.Tensor):\n", 83 | " device = self.lm_gen.lm_model.device\n", 84 | " ntokens = 0\n", 85 | " first_frame = True\n", 86 | " chunks = [\n", 87 | " c\n", 88 | " for c in in_pcms.split(self.frame_size, dim=2)\n", 89 | " if c.shape[-1] == self.frame_size\n", 90 | " ]\n", 91 | " start_time = time.time()\n", 92 | " all_text = []\n", 93 | " for chunk in chunks:\n", 94 | " codes = self.mimi.encode(chunk)\n", 95 | " if first_frame:\n", 96 | " # Ensure that the first slice of codes is properly seen by the transformer\n", 97 | " # as otherwise the first slice is replaced by the initial tokens.\n", 98 | " tokens = self.lm_gen.step(codes)\n", 99 | " first_frame = False\n", 100 | " tokens = self.lm_gen.step(codes)\n", 101 | " if tokens is None:\n", 102 | " continue\n", 103 | " assert tokens.shape[1] == 1\n", 104 | " one_text = tokens[0, 0].cpu()\n", 105 | " if one_text.item() not in [0, 3]:\n", 106 | " text = 
107 |     "                text = text.replace(\"▁\", \" \")\n",
108 |     "                all_text.append(text)\n",
109 |     "            ntokens += 1\n",
110 |     "        dt = time.time() - start_time\n",
111 |     "        print(\n",
112 |     "            f\"processed {ntokens} steps in {dt:.0f}s, {1000 * dt / ntokens:.2f}ms/step\"\n",
113 |     "        )\n",
114 |     "        return \"\".join(all_text)"
115 |    ]
116 |   },
117 |   {
118 |    "cell_type": "code",
119 |    "execution_count": null,
120 |    "metadata": {
121 |     "colab": {
122 |      "base_uri": "https://localhost:8080/",
123 |      "height": 353,
124 |      "referenced_widgets": [
125 |       "0a5f6f887e2b4cd1990a0e9ec0153ed9",
126 |       "f7893826fcba4bdc87539589d669249b",
127 |       "8805afb12c484781be85082ff02dad13",
128 |       "97679c0d9ab44bed9a3456f2fcb541fd",
129 |       "d73c0321bed54a52b5e1da0a7788e32a",
130 |       "d67be13a920d4fc89e5570b5b29fc1d2",
131 |       "6b377c2d7bf945fb89e46c39d246a332",
132 |       "b82ff365c78e41ad8094b46daf79449d",
133 |       "477aa7fa82dc42d5bce6f1743c45d626",
134 |       "cbd288510c474430beb66f346f382c45",
135 |       "aafc347cdf28428ea6a7abe5b46b726f",
136 |       "fca09acd5d0d45468c8b04bfb2de7646",
137 |       "79e35214b51b4a9e9b3f7144b0b34f7b",
138 |       "89e9a37f69904bd48b954d627bff6687",
139 |       "57028789c78248a7b0ad4f031c9545c9",
140 |       "1150fcb427994c2984d4d0f4e4745fe5",
141 |       "e24b1fc52f294f849019c9b3befb613f",
142 |       "8724878682cf4c3ca992667c45009398",
143 |       "36a22c977d5242008871310133b7d2af",
144 |       "5b3683cad5cb4877b43fadd003edf97f",
145 |       "703f98272e4d469d8f27f5a465715dd8",
146 |       "9dbe02ef5fac41cfaee3d02946e65c88",
147 |       "37faa87ad03a4271992c21ce6a629e18",
148 |       "570c547e48cd421b814b2c5e028e4c0b",
149 |       "b173768580fc4c0a8e3abf272e4c363a",
150 |       "e57d1620f0a9427b85d8b4885ef4e8e3",
151 |       "5dd4474df70743498b616608182714dd",
152 |       "cc907676a65f4ad1bf68a77b4a00e89b",
153 |       "a34abc3b118e4305951a466919c28ff6",
154 |       "a77ccfcdb90146c7a63b4b2d232bc494",
155 |       "f7313e6e3a27475993cab3961d6ae363",
156 |       "39b47fad9c554839868fe9e4bbf7def2",
157 |       "14e9511ea0bd44c49f0cf3abf1a6d40e",
158 |       "a4ea8e0c4cac4d5e88b7e3f527e4fe90",
159 |       "571afc0f4b2840c9830d6b5a307ed1f9",
160 |       "6ec593cab5b64f0ea638bb175b9daa5c",
161 |       "77a52aed00ae408bb24524880e19ec8a",
162 |       "0b2de4b29b4b44fe9d96361a40c793d0",
163 |       "3c5b5fb1a5ac468a89c1058bd90cfb58",
164 |       "e53e0a2a240e43cfa562c89b3d703dea",
165 |       "35966343cf9249ef8bc028a0d5c5f97d",
166 |       "e36a37e0d41c47ccb8bc6d56c19fb17c",
167 |       "279ccf7de43847a1a6579c9182a46cc8",
168 |       "41b5d6ab0b7d43c790a55f125c0e7494"
169 |      ]
170 |     },
171 |     "id": "UsQJdAgkLp9n",
172 |     "outputId": "9b7131c3-69c5-4323-8312-2ce7621d8869"
173 |    },
174 |    "outputs": [],
175 |    "source": [
176 |     "device = \"cuda\"\n",
177 |     "# Use the en+fr low-latency model; an alternative is kyutai/stt-2.6b-en\n",
178 |     "checkpoint_info = loaders.CheckpointInfo.from_hf_repo(\"kyutai/stt-1b-en_fr\")\n",
179 |     "mimi = checkpoint_info.get_mimi(device=device)\n",
180 |     "text_tokenizer = checkpoint_info.get_text_tokenizer()\n",
181 |     "lm = checkpoint_info.get_moshi(device=device)\n",
182 |     "in_pcms, _ = sphn.read(\"sample_fr_hibiki_crepes.mp3\", sample_rate=mimi.sample_rate)\n",
183 |     "in_pcms = torch.from_numpy(in_pcms).to(device=device)\n",
184 |     "\n",
185 |     "stt_config = checkpoint_info.stt_config\n",
186 |     "pad_left = int(stt_config.get(\"audio_silence_prefix_seconds\", 0.0) * 24000)\n",
187 |     "pad_right = int((stt_config.get(\"audio_delay_seconds\", 0.0) + 1.0) * 24000)\n",
188 |     "in_pcms = torch.nn.functional.pad(in_pcms, (pad_left, pad_right), mode=\"constant\")\n",
189 |     "in_pcms = in_pcms[None, 0:1].expand(1, -1, -1)\n",
190 |     "\n",
191 |     "state = InferenceState(mimi, text_tokenizer, lm, batch_size=1, device=device)\n",
192 |     "text = state.run(in_pcms)\n",
193 |     "print(textwrap.fill(text, width=100))"
194 |    ]
195 |   },
196 |   {
197 |    "cell_type": "code",
198 |    "execution_count": null,
199 |    "metadata": {
200 |     "colab": {
201 |      "base_uri": "https://localhost:8080/",
202 |      "height": 75
203 |     },
204 |     "id": "CIAXs9oaPrtj",
205 |     "outputId": "94cc208c-2454-4dd4-a64e-d79025144af5"
206 |    },
207 |    "outputs": [],
208 |    "source": [
209 |     "from IPython.display import Audio\n",
210 |     "\n",
211 |     "Audio(\"sample_fr_hibiki_crepes.mp3\")"
212 |    ]
213 |   },
214 |   {
215 |    "cell_type": "code",
216 |    "execution_count": null,
217 |    "metadata": {
218 |     "id": "qkUZ6CBKOdTa"
219 |    },
220 |    "outputs": [],
221 |    "source": []
222 |   }
223 |  ],
224 |  "metadata": {
225 |   "accelerator": "GPU",
226 |   "colab": {
227 |    "gpuType": "L4",
228 |    "provenance": []
229 |   },
230 |   "kernelspec": {
231 |    "display_name": "Python 3 (ipykernel)",
232 |    "language": "python",
233 |    "name": "python3"
234 |   }
235 |  },
236 |  "nbformat": 4,
237 |  "nbformat_minor": 0
238 | }
239 | 
--------------------------------------------------------------------------------
/tts_pytorch.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |  "cells": [
3 |   {
4 |    "cell_type": "code",
5 |    "execution_count": null,
6 |    "id": "0",
7 |    "metadata": {},
8 |    "outputs": [],
9 |    "source": [
10 |     "# Fast install; might break in the future.\n",
11 |     "!pip install 'sphn<0.2'\n",
12 |     "!pip install --no-deps \"moshi==0.2.8\"\n",
13 |     "# Slow install (will download torch and cuda), but future-proof.\n",
14 |     "# !pip install \"moshi==0.2.8\""
15 |    ]
16 |   },
17 |   {
18 |    "cell_type": "code",
19 |    "execution_count": null,
20 |    "id": "1",
21 |    "metadata": {},
22 |    "outputs": [],
23 |    "source": [
24 |     "import argparse\n",
25 |     "import sys\n",
26 |     "\n",
27 |     "import numpy as np\n",
28 |     "import torch\n",
29 |     "from moshi.models.loaders import CheckpointInfo\n",
30 |     "from moshi.models.tts import DEFAULT_DSM_TTS_REPO, DEFAULT_DSM_TTS_VOICE_REPO, TTSModel\n",
31 |     "\n",
32 |     "from IPython.display import display, Audio"
33 |    ]
34 |   },
35 |   {
36 |    "cell_type": "code",
37 |    "execution_count": null,
38 |    "id": "2",
39 |    "metadata": {},
40 |    "outputs": [],
41 |    "source": [
42 |     "# Configuration\n",
43 |     "text = \"Hey there! How are you? I had the craziest day today.\"\n",
44 |     "voice = \"expresso/ex03-ex01_happy_001_channel1_334s.wav\"\n",
45 |     "print(f\"See https://huggingface.co/{DEFAULT_DSM_TTS_VOICE_REPO} for available voices.\")"
46 |    ]
47 |   },
48 |   {
49 |    "cell_type": "code",
50 |    "execution_count": null,
51 |    "id": "3",
52 |    "metadata": {},
53 |    "outputs": [],
54 |    "source": [
55 |     "# Set everything up\n",
56 |     "checkpoint_info = CheckpointInfo.from_hf_repo(DEFAULT_DSM_TTS_REPO)\n",
57 |     "tts_model = TTSModel.from_checkpoint_info(\n",
58 |     "    checkpoint_info, n_q=32, temp=0.6, device=torch.device(\"cuda\")\n",
59 |     ")\n",
60 |     "\n",
61 |     "# If you want to make a dialog, you can pass more than one turn [text_speaker_1, text_speaker_2, text_2_speaker_1, ...]\n",
62 |     "entries = tts_model.prepare_script([text], padding_between=1)\n",
63 |     "voice_path = tts_model.get_voice_path(voice)\n",
64 |     "# CFG coef goes here because the model was trained with CFG distillation,\n",
65 |     "# so it's not _actually_ doing CFG at inference time.\n",
66 |     "# Also, if you are generating a dialog, you should have two voices in the list.\n",
67 |     "condition_attributes = tts_model.make_condition_attributes(\n",
68 |     "    [voice_path], cfg_coef=2.0\n",
69 |     ")"
70 |    ]
71 |   },
72 |   {
73 |    "cell_type": "code",
74 |    "execution_count": null,
75 |    "id": "4",
76 |    "metadata": {},
77 |    "outputs": [],
78 |    "source": [
79 |     "print(\"Generating audio...\")\n",
80 |     "\n",
81 |     "pcms = []\n",
82 |     "def _on_frame(frame):\n",
83 |     "    print(\"Step\", len(pcms), end=\"\\r\")\n",
84 |     "    if (frame != -1).all():\n",
85 |     "        pcm = tts_model.mimi.decode(frame[:, 1:, :]).cpu().numpy()\n",
86 |     "        pcms.append(np.clip(pcm[0, 0], -1, 1))\n",
87 |     "\n",
88 |     "# You could also generate multiple audios at once by extending the following lists.\n",
89 |     "all_entries = [entries]\n",
90 |     "all_condition_attributes = [condition_attributes]\n",
91 |     "with tts_model.mimi.streaming(len(all_entries)):\n",
92 |     "    result = tts_model.generate(all_entries, all_condition_attributes, on_frame=_on_frame)\n",
93 |     "\n",
94 |     "print(\"Done generating.\")\n",
95 |     "audio = np.concatenate(pcms, axis=-1)"
96 |    ]
97 |   },
98 |   {
99 |    "cell_type": "code",
100 |    "execution_count": null,
101 |    "id": "5",
102 |    "metadata": {},
103 |    "outputs": [],
104 |    "source": [
105 |     "display(\n",
106 |     "    Audio(audio, rate=tts_model.mimi.sample_rate, autoplay=True)\n",
107 |     ")"
108 |    ]
109 |   },
110 |   {
111 |    "cell_type": "code",
112 |    "execution_count": null,
113 |    "id": "6",
114 |    "metadata": {},
115 |    "outputs": [],
116 |    "source": []
117 |   }
118 |  ],
119 |  "metadata": {
120 |   "kernelspec": {
121 |    "display_name": "Python 3 (ipykernel)",
122 |    "language": "python",
123 |    "name": "python3"
124 |   },
125 |   "language_info": {
126 |    "codemirror_mode": {
127 |     "name": "ipython",
128 |     "version": 3
129 |    },
130 |    "file_extension": ".py",
131 |    "mimetype": "text/x-python",
132 |    "name": "python",
133 |    "nbconvert_exporter": "python",
134 |    "pygments_lexer": "ipython3",
135 |    "version": "3.13.2"
136 |   }
137 |  },
138 |  "nbformat": 4,
139 |  "nbformat_minor": 5
140 | }
141 | 
--------------------------------------------------------------------------------
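
The comments in tts_pytorch.ipynb above note that a dialog is produced by passing several turns to prepare_script and one voice per speaker to make_condition_attributes. A minimal sketch of that two-speaker variant, assuming the tts_model and imports already set up in the notebook; the second speaker's turn text and voice file name are hypothetical placeholders (pick any entry from the DEFAULT_DSM_TTS_VOICE_REPO listing), not files confirmed to exist.

# Sketch only: two-speaker dialog, reusing tts_model from tts_pytorch.ipynb above.
# Turns alternate between speaker 1 and speaker 2, as described in the notebook comments.
turns = [
    "Hey there! How are you?",          # speaker 1
    "Pretty good, thanks for asking.",  # speaker 2 (placeholder text)
    "I had the craziest day today.",    # speaker 1 again
]
entries = tts_model.prepare_script(turns, padding_between=1)
# A dialog needs one voice per speaker; the second file name is a hypothetical placeholder.
voice_paths = [
    tts_model.get_voice_path("expresso/ex03-ex01_happy_001_channel1_334s.wav"),
    tts_model.get_voice_path("expresso/ex03-ex02_placeholder.wav"),  # replace with a real voice file
]
# cfg_coef as in the single-speaker cell; generation then proceeds exactly as in the notebook's streaming cell.
condition_attributes = tts_model.make_condition_attributes(voice_paths, cfg_coef=2.0)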