├── .github
    ├── PULL_REQUEST_TEMPLATE.md
    ├── actions
    │   └── moshi_build
    │   │   └── action.yml
    └── workflows
    │   └── precommit.yml
├── .gitignore
├── .pre-commit-config.yaml
├── .vscode
    └── settings.json
├── CONTRIBUTING.md
├── LICENSE-APACHE
├── LICENSE-MIT
├── README.md
├── audio
    ├── bria.mp3
    ├── loona.mp3
    └── sample_fr_hibiki_crepes.mp3
├── configs
    ├── config-stt-en-hf.toml
    ├── config-stt-en_fr-hf.toml
    └── config-tts.toml
├── scripts
    ├── stt_evaluate_on_dataset.py
    ├── stt_from_file_mlx.py
    ├── stt_from_file_pytorch.py
    ├── stt_from_file_rust_server.py
    ├── stt_from_file_with_prompt_pytorch.py
    ├── stt_from_mic_mlx.py
    ├── stt_from_mic_rust_server.py
    ├── tts_mlx.py
    ├── tts_pytorch.py
    └── tts_rust_server.py
├── stt-rs
    ├── Cargo.lock
    ├── Cargo.toml
    └── src
    │   └── main.rs
├── stt_pytorch.ipynb
└── tts_pytorch.ipynb


/.github/PULL_REQUEST_TEMPLATE.md:
--------------------------------------------------------------------------------
 1 | ## Checklist
 2 | 
 3 | - [ ] Read CONTRIBUTING.md and accept the CLA by including the provided snippet. We will not accept PRs without this.
 4 | - [ ] Run pre-commit hook.
 5 | - [ ] If you changed Rust code, run `cargo check`, `cargo clippy`, `cargo test`.
 6 | 
 7 | ## PR Description
 8 | 
 9 | <!-- Description for the PR -->
10 | 


--------------------------------------------------------------------------------
/.github/actions/moshi_build/action.yml:
--------------------------------------------------------------------------------
 1 | name: moshi_build
 2 | description: 'Build env.'
 3 | runs:
 4 |   using: "composite"
 5 |   steps:
 6 |   - uses: actions/setup-python@v2
 7 |     with:
 8 |       python-version: '3.10.14'
 9 |   - uses: actions/cache@v3
10 |     id: cache
11 |     with:
12 |       path: env
13 |       key: env-${{ hashFiles('moshi/pyproject.toml') }}
14 |   - name: Install dependencies
15 |     if: steps.cache.outputs.cache-hit != 'true'
16 |     shell: bash
17 |     run: |
18 |       python3 -m venv env
19 |       .  env/bin/activate
20 |       python -m pip install --upgrade pip
21 |       pip install torch==2.4.0 --index-url https://download.pytorch.org/whl/cpu
22 |       pip install moshi==0.2.7
23 |       pip install pre-commit
24 |   - name: Setup env
25 |     shell: bash
26 |     run: |
27 |       source  env/bin/activate
28 |       pre-commit install
29 | 


--------------------------------------------------------------------------------
/.github/workflows/precommit.yml:
--------------------------------------------------------------------------------
 1 | name: precommit
 2 | on:
 3 |   push:
 4 |     branches: [ main ]
 5 |   pull_request:
 6 |     branches: [ main ]
 7 | 
 8 | jobs:
 9 |   run_precommit:
10 |     name: Run precommit
11 |     runs-on: ubuntu-latest
12 |     steps:
13 |       - uses: actions/checkout@v2
14 |       - uses: ./.github/actions/moshi_build
15 |       - run: |
16 |           source env/bin/activate
17 |           pre-commit run --all-files
18 | 


--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
  1 | # Byte-compiled / optimized / DLL files
  2 | __pycache__/
  3 | *.py[cod]
  4 | *$py.class
  5 | 
  6 | # C extensions
  7 | *.so
  8 | 
  9 | # Distribution / packaging
 10 | .Python
 11 | build/
 12 | develop-eggs/
 13 | dist/
 14 | downloads/
 15 | eggs/
 16 | .eggs/
 17 | lib/
 18 | lib64/
 19 | parts/
 20 | sdist/
 21 | var/
 22 | wheels/
 23 | share/python-wheels/
 24 | *.egg-info/
 25 | .installed.cfg
 26 | *.egg
 27 | MANIFEST
 28 | 
 29 | # PyInstaller
 30 | #  Usually these files are written by a python script from a template
 31 | #  before PyInstaller builds the exe, so as to inject date/other infos into it.
 32 | *.manifest
 33 | *.spec
 34 | 
 35 | # Installer logs
 36 | pip-log.txt
 37 | pip-delete-this-directory.txt
 38 | 
 39 | # Unit test / coverage reports
 40 | htmlcov/
 41 | .tox/
 42 | .nox/
 43 | .coverage
 44 | .coverage.*
 45 | .cache
 46 | nosetests.xml
 47 | coverage.xml
 48 | *.cover
 49 | *.py,cover
 50 | .hypothesis/
 51 | .pytest_cache/
 52 | cover/
 53 | 
 54 | # Translations
 55 | *.mo
 56 | *.pot
 57 | 
 58 | # Django stuff:
 59 | *.log
 60 | local_settings.py
 61 | db.sqlite3
 62 | db.sqlite3-journal
 63 | 
 64 | # Flask stuff:
 65 | instance/
 66 | .webassets-cache
 67 | 
 68 | # Scrapy stuff:
 69 | .scrapy
 70 | 
 71 | # Sphinx documentation
 72 | docs/_build/
 73 | 
 74 | # PyBuilder
 75 | .pybuilder/
 76 | target/
 77 | 
 78 | # Jupyter Notebook
 79 | .ipynb_checkpoints
 80 | 
 81 | # IPython
 82 | profile_default/
 83 | ipython_config.py
 84 | 
 85 | # pyenv
 86 | #   For a library or package, you might want to ignore these files since the code is
 87 | #   intended to run in multiple environments; otherwise, check them in:
 88 | # .python-version
 89 | 
 90 | # pipenv
 91 | #   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
 92 | #   However, in case of collaboration, if having platform-specific dependencies or dependencies
 93 | #   having no cross-platform support, pipenv may install dependencies that don't work, or not
 94 | #   install all needed dependencies.
 95 | #Pipfile.lock
 96 | 
 97 | # UV
 98 | #   Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
 99 | #   This is especially recommended for binary packages to ensure reproducibility, and is more
100 | #   commonly ignored for libraries.
101 | #uv.lock
102 | 
103 | # poetry
104 | #   Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
105 | #   This is especially recommended for binary packages to ensure reproducibility, and is more
106 | #   commonly ignored for libraries.
107 | #   https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
108 | #poetry.lock
109 | 
110 | # pdm
111 | #   Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
112 | #pdm.lock
113 | #   pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
114 | #   in version control.
115 | #   https://pdm.fming.dev/latest/usage/project/#working-with-version-control
116 | .pdm.toml
117 | .pdm-python
118 | .pdm-build/
119 | 
120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
121 | __pypackages__/
122 | 
123 | # Celery stuff
124 | celerybeat-schedule
125 | celerybeat.pid
126 | 
127 | # SageMath parsed files
128 | *.sage.py
129 | 
130 | # Environments
131 | .env
132 | .venv
133 | env/
134 | venv/
135 | ENV/
136 | env.bak/
137 | venv.bak/
138 | 
139 | # Spyder project settings
140 | .spyderproject
141 | .spyproject
142 | 
143 | # Rope project settings
144 | .ropeproject
145 | 
146 | # mkdocs documentation
147 | /site
148 | 
149 | # mypy
150 | .mypy_cache/
151 | .dmypy.json
152 | dmypy.json
153 | 
154 | # Pyre type checker
155 | .pyre/
156 | 
157 | # pytype static type analyzer
158 | .pytype/
159 | 
160 | # Cython debug symbols
161 | cython_debug/
162 | 
163 | # PyCharm
164 | #  JetBrains specific template is maintained in a separate JetBrains.gitignore that can
165 | #  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
166 | #  and can be added to the global gitignore or merged into this file.  For a more nuclear
167 | #  option (not recommended) you can uncomment the following to ignore the entire idea folder.
168 | #.idea/
169 | 
170 | # Abstra
171 | # Abstra is an AI-powered process automation framework.
172 | # Ignore directories containing user credentials, local state, and settings.
173 | # Learn more at https://abstra.io/docs
174 | .abstra/
175 | 
176 | # Visual Studio Code
177 | #  Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore 
178 | #  that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
179 | #  and can be added to the global gitignore or merged into this file. However, if you prefer, 
180 | #  you could uncomment the following to ignore the enitre vscode folder
181 | # .vscode/
182 | 
183 | # Ruff stuff:
184 | .ruff_cache/
185 | 
186 | # PyPI configuration file
187 | .pypirc
188 | 
189 | # Cursor
190 | #  Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to
191 | #  exclude from AI features like autocomplete and code analysis. Recommended for sensitive data
192 | #  refer to https://docs.cursor.com/context/ignore-files
193 | .cursorignore
194 | .cursorindexingignore
195 | 


--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
 1 | repos:
 2 |     # Get rid of Jupyter Notebook output because we don't want to keep it in Git
 3 |   - repo: https://github.com/kynan/nbstripout
 4 |     rev: 0.8.1
 5 |     hooks:
 6 |       - id: nbstripout
 7 |   - repo: https://github.com/pre-commit/pre-commit-hooks
 8 |     rev: v5.0.0
 9 |     hooks:
10 |       - id: check-added-large-files
11 |         args: ["--maxkb=2048"]
12 |   - repo: https://github.com/astral-sh/ruff-pre-commit
13 |     # Ruff version.
14 |     rev: v0.11.7
15 |     hooks:
16 |       # Run the linter.
17 |       - id: ruff
18 |         types_or: [python, pyi] # Don't run on `jupyter` files
19 |         args: [--fix]
20 |       # Run the formatter.
21 |       - id: ruff-format
22 |         types_or: [python, pyi] # Don't run on `jupyter` files
23 | 


--------------------------------------------------------------------------------
/.vscode/settings.json:
--------------------------------------------------------------------------------
1 | {
2 |   "python.analysis.typeCheckingMode": "standard"
3 | }
4 | 


--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
 1 | # Contributing to Delayed-Streams-Modeling
 2 | 
 3 | ## Pull Requests
 4 | 
 5 | Delayed-Streams-Modeling is the implementation of a research paper.
 6 | Therefore, we do not plan on accepting many pull requests for new features.
 7 | However, we certainly welcome them for bug fixes.
 8 | 
 9 | 1. Fork the repo and create your branch from `main`.
10 | 2. If you have changed APIs, update the documentation accordingly.
11 | 3. Ensure pre-commit hooks pass properly, in particular the linting and typing.
12 | 4. When changing the Rust code, run `cargo check`, `cargo clippy`, `cargo test`.
13 | 5. Accept the Contributor License Agreement (see after).
14 | 
15 | Note that in general, we will not accept refactoring of the code.
16 | 
17 | 
18 | ## Contributor License Agreement ("CLA")
19 | 
20 | In order to accept your pull request, we need you to submit a Contributor License Agreement.
21 | 
22 | If you agree with the full CLA provided in the next paragraph, copy the following statement into your PR, filling in your GitHub handle:
23 | 
24 | > I, {your GitHub handle}, confirm that I have read and understood the terms of the CLA of Kyutai-labs, as outlined in the repository's CONTRIBUTING.md, and I agree to be bound by these terms.
25 | 
26 | The full CLA is provided as follows:
27 | 
28 | > I, {your GitHub handle}, hereby grant to Kyutai-labs a perpetual, worldwide, non-exclusive, royalty-free,
29 | > irrevocable license to use, modify, distribute, and sublicense my Contributions.
30 | 
31 | > I understand and accept that Contributions are limited to modifications, improvements, or changes
32 | > to the project’s source code submitted via pull requests. I accept that Kyutai-labs has full discretion to
33 | > review, accept, reject, or request changes to any Contributions I submit, and that submitting
34 | > a pull request does not guarantee its inclusion in the project.
35 | 
36 | > By submitting a Contribution, I grant Kyutai-labs a perpetual, worldwide license to use, modify,
37 | > reproduce, distribute, and create derivative works based on my Contributions.
38 | > I also agree to assign all patent rights for any inventions or improvements that arise from my Contributions,
39 | > giving the Kyutai-labs full rights to file for and enforce patents.
40 | > I understand that the Kyutai-labs may commercialize, relicense, or exploit the project and my Contributions without further notice or obligation to me.
41 | > I confirm that my Contributions are original and that I have the legal right to grant this license.
42 | > If my Contributions include third-party materials, I will ensure that I have the necessary permissions
43 | > and will disclose this information. I accept that once my Contributions are integrated, they may be altered or removed at the Kyutai-labs’s discretion.
44 | 
45 | > I acknowledge that I am making these Contributions voluntarily and will not receive any compensation.
46 | > Furthermore, I understand that all Contributions, including mine, are provided on an "as-is" basis, with no warranties.
47 | > By submitting a pull request, I agree to be bound by these terms.
48 | 
49 | ## Issues
50 | 
51 | Please submit issues on our GitHub repository.
52 | 
53 | ## License
54 | 
55 | By contributing to Delayed-Streams-Modeling, you agree that your contributions
56 | will be licensed under the LICENSE-* files in the root directory of this source
57 | tree. In particular, the Rust code is licensed under Apache, and the Python code
58 | under MIT.
59 | 


--------------------------------------------------------------------------------
/LICENSE-APACHE:
--------------------------------------------------------------------------------
  1 |                                  Apache License
  2 |                            Version 2.0, January 2004
  3 |                         http://www.apache.org/licenses/
  4 | 
  5 |    TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
  6 | 
  7 |    1. Definitions.
  8 | 
  9 |       "License" shall mean the terms and conditions for use, reproduction,
 10 |       and distribution as defined by Sections 1 through 9 of this document.
 11 | 
 12 |       "Licensor" shall mean the copyright owner or entity authorized by
 13 |       the copyright owner that is granting the License.
 14 | 
 15 |       "Legal Entity" shall mean the union of the acting entity and all
 16 |       other entities that control, are controlled by, or are under common
 17 |       control with that entity. For the purposes of this definition,
 18 |       "control" means (i) the power, direct or indirect, to cause the
 19 |       direction or management of such entity, whether by contract or
 20 |       otherwise, or (ii) ownership of fifty percent (50%) or more of the
 21 |       outstanding shares, or (iii) beneficial ownership of such entity.
 22 | 
 23 |       "You" (or "Your") shall mean an individual or Legal Entity
 24 |       exercising permissions granted by this License.
 25 | 
 26 |       "Source" form shall mean the preferred form for making modifications,
 27 |       including but not limited to software source code, documentation
 28 |       source, and configuration files.
 29 | 
 30 |       "Object" form shall mean any form resulting from mechanical
 31 |       transformation or translation of a Source form, including but
 32 |       not limited to compiled object code, generated documentation,
 33 |       and conversions to other media types.
 34 | 
 35 |       "Work" shall mean the work of authorship, whether in Source or
 36 |       Object form, made available under the License, as indicated by a
 37 |       copyright notice that is included in or attached to the work
 38 |       (an example is provided in the Appendix below).
 39 | 
 40 |       "Derivative Works" shall mean any work, whether in Source or Object
 41 |       form, that is based on (or derived from) the Work and for which the
 42 |       editorial revisions, annotations, elaborations, or other modifications
 43 |       represent, as a whole, an original work of authorship. For the purposes
 44 |       of this License, Derivative Works shall not include works that remain
 45 |       separable from, or merely link (or bind by name) to the interfaces of,
 46 |       the Work and Derivative Works thereof.
 47 | 
 48 |       "Contribution" shall mean any work of authorship, including
 49 |       the original version of the Work and any modifications or additions
 50 |       to that Work or Derivative Works thereof, that is intentionally
 51 |       submitted to Licensor for inclusion in the Work by the copyright owner
 52 |       or by an individual or Legal Entity authorized to submit on behalf of
 53 |       the copyright owner. For the purposes of this definition, "submitted"
 54 |       means any form of electronic, verbal, or written communication sent
 55 |       to the Licensor or its representatives, including but not limited to
 56 |       communication on electronic mailing lists, source code control systems,
 57 |       and issue tracking systems that are managed by, or on behalf of, the
 58 |       Licensor for the purpose of discussing and improving the Work, but
 59 |       excluding communication that is conspicuously marked or otherwise
 60 |       designated in writing by the copyright owner as "Not a Contribution."
 61 | 
 62 |       "Contributor" shall mean Licensor and any individual or Legal Entity
 63 |       on behalf of whom a Contribution has been received by Licensor and
 64 |       subsequently incorporated within the Work.
 65 | 
 66 |    2. Grant of Copyright License. Subject to the terms and conditions of
 67 |       this License, each Contributor hereby grants to You a perpetual,
 68 |       worldwide, non-exclusive, no-charge, royalty-free, irrevocable
 69 |       copyright license to reproduce, prepare Derivative Works of,
 70 |       publicly display, publicly perform, sublicense, and distribute the
 71 |       Work and such Derivative Works in Source or Object form.
 72 | 
 73 |    3. Grant of Patent License. Subject to the terms and conditions of
 74 |       this License, each Contributor hereby grants to You a perpetual,
 75 |       worldwide, non-exclusive, no-charge, royalty-free, irrevocable
 76 |       (except as stated in this section) patent license to make, have made,
 77 |       use, offer to sell, sell, import, and otherwise transfer the Work,
 78 |       where such license applies only to those patent claims licensable
 79 |       by such Contributor that are necessarily infringed by their
 80 |       Contribution(s) alone or by combination of their Contribution(s)
 81 |       with the Work to which such Contribution(s) was submitted. If You
 82 |       institute patent litigation against any entity (including a
 83 |       cross-claim or counterclaim in a lawsuit) alleging that the Work
 84 |       or a Contribution incorporated within the Work constitutes direct
 85 |       or contributory patent infringement, then any patent licenses
 86 |       granted to You under this License for that Work shall terminate
 87 |       as of the date such litigation is filed.
 88 | 
 89 |    4. Redistribution. You may reproduce and distribute copies of the
 90 |       Work or Derivative Works thereof in any medium, with or without
 91 |       modifications, and in Source or Object form, provided that You
 92 |       meet the following conditions:
 93 | 
 94 |       (a) You must give any other recipients of the Work or
 95 |           Derivative Works a copy of this License; and
 96 | 
 97 |       (b) You must cause any modified files to carry prominent notices
 98 |           stating that You changed the files; and
 99 | 
100 |       (c) You must retain, in the Source form of any Derivative Works
101 |           that You distribute, all copyright, patent, trademark, and
102 |           attribution notices from the Source form of the Work,
103 |           excluding those notices that do not pertain to any part of
104 |           the Derivative Works; and
105 | 
106 |       (d) If the Work includes a "NOTICE" text file as part of its
107 |           distribution, then any Derivative Works that You distribute must
108 |           include a readable copy of the attribution notices contained
109 |           within such NOTICE file, excluding those notices that do not
110 |           pertain to any part of the Derivative Works, in at least one
111 |           of the following places: within a NOTICE text file distributed
112 |           as part of the Derivative Works; within the Source form or
113 |           documentation, if provided along with the Derivative Works; or,
114 |           within a display generated by the Derivative Works, if and
115 |           wherever such third-party notices normally appear. The contents
116 |           of the NOTICE file are for informational purposes only and
117 |           do not modify the License. You may add Your own attribution
118 |           notices within Derivative Works that You distribute, alongside
119 |           or as an addendum to the NOTICE text from the Work, provided
120 |           that such additional attribution notices cannot be construed
121 |           as modifying the License.
122 | 
123 |       You may add Your own copyright statement to Your modifications and
124 |       may provide additional or different license terms and conditions
125 |       for use, reproduction, or distribution of Your modifications, or
126 |       for any such Derivative Works as a whole, provided Your use,
127 |       reproduction, and distribution of the Work otherwise complies with
128 |       the conditions stated in this License.
129 | 
130 |    5. Submission of Contributions. Unless You explicitly state otherwise,
131 |       any Contribution intentionally submitted for inclusion in the Work
132 |       by You to the Licensor shall be under the terms and conditions of
133 |       this License, without any additional terms or conditions.
134 |       Notwithstanding the above, nothing herein shall supersede or modify
135 |       the terms of any separate license agreement you may have executed
136 |       with Licensor regarding such Contributions.
137 | 
138 |    6. Trademarks. This License does not grant permission to use the trade
139 |       names, trademarks, service marks, or product names of the Licensor,
140 |       except as required for reasonable and customary use in describing the
141 |       origin of the Work and reproducing the content of the NOTICE file.
142 | 
143 |    7. Disclaimer of Warranty. Unless required by applicable law or
144 |       agreed to in writing, Licensor provides the Work (and each
145 |       Contributor provides its Contributions) on an "AS IS" BASIS,
146 |       WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 |       implied, including, without limitation, any warranties or conditions
148 |       of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 |       PARTICULAR PURPOSE. You are solely responsible for determining the
150 |       appropriateness of using or redistributing the Work and assume any
151 |       risks associated with Your exercise of permissions under this License.
152 | 
153 |    8. Limitation of Liability. In no event and under no legal theory,
154 |       whether in tort (including negligence), contract, or otherwise,
155 |       unless required by applicable law (such as deliberate and grossly
156 |       negligent acts) or agreed to in writing, shall any Contributor be
157 |       liable to You for damages, including any direct, indirect, special,
158 |       incidental, or consequential damages of any character arising as a
159 |       result of this License or out of the use or inability to use the
160 |       Work (including but not limited to damages for loss of goodwill,
161 |       work stoppage, computer failure or malfunction, or any and all
162 |       other commercial damages or losses), even if such Contributor
163 |       has been advised of the possibility of such damages.
164 | 
165 |    9. Accepting Warranty or Additional Liability. While redistributing
166 |       the Work or Derivative Works thereof, You may choose to offer,
167 |       and charge a fee for, acceptance of support, warranty, indemnity,
168 |       or other liability obligations and/or rights consistent with this
169 |       License. However, in accepting such obligations, You may act only
170 |       on Your own behalf and on Your sole responsibility, not on behalf
171 |       of any other Contributor, and only if You agree to indemnify,
172 |       defend, and hold each Contributor harmless for any liability
173 |       incurred by, or claims asserted against, such Contributor by reason
174 |       of your accepting any such warranty or additional liability.
175 | 
176 |    END OF TERMS AND CONDITIONS
177 | 
178 |    APPENDIX: How to apply the Apache License to your work.
179 | 
180 |       To apply the Apache License to your work, attach the following
181 |       boilerplate notice, with the fields enclosed by brackets "[]"
182 |       replaced with your own identifying information. (Don't include
183 |       the brackets!)  The text should be enclosed in the appropriate
184 |       comment syntax for the file format. We also recommend that a
185 |       file or class name and description of purpose be included on the
186 |       same "printed page" as the copyright notice for easier
187 |       identification within third-party archives.
188 | 
189 |    Copyright [yyyy] [name of copyright owner]
190 | 
191 |    Licensed under the Apache License, Version 2.0 (the "License");
192 |    you may not use this file except in compliance with the License.
193 |    You may obtain a copy of the License at
194 | 
195 |        http://www.apache.org/licenses/LICENSE-2.0
196 | 
197 |    Unless required by applicable law or agreed to in writing, software
198 |    distributed under the License is distributed on an "AS IS" BASIS,
199 |    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 |    See the License for the specific language governing permissions and
201 |    limitations under the License.
202 | 


--------------------------------------------------------------------------------
/LICENSE-MIT:
--------------------------------------------------------------------------------
 1 | Permission is hereby granted, free of charge, to any
 2 | person obtaining a copy of this software and associated
 3 | documentation files (the "Software"), to deal in the
 4 | Software without restriction, including without
 5 | limitation the rights to use, copy, modify, merge,
 6 | publish, distribute, sublicense, and/or sell copies of
 7 | the Software, and to permit persons to whom the Software
 8 | is furnished to do so, subject to the following
 9 | conditions:
10 | 
11 | The above copyright notice and this permission notice
12 | shall be included in all copies or substantial portions
13 | of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF
16 | ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED
17 | TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
18 | PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT
19 | SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
20 | CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
21 | OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR
22 | IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
23 | DEALINGS IN THE SOFTWARE.
24 | 


--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
  1 | # Delayed Streams Modeling: Kyutai STT & TTS
  2 | 
  3 | This repo contains instructions and examples of how to run
  4 | [Kyutai Speech-To-Text](#kyutai-speech-to-text)
  5 | and [Kyutai Text-To-Speech](#kyutai-text-to-speech) models.
  6 | These models are powered by delayed streams modeling (DSM),
  7 | a flexible formulation for streaming, multimodal sequence-to-sequence learning.
  8 | See also [Unmute](https://github.com/kyutai-labs/unmute), a voice AI system built using Kyutai STT and Kyutai TTS.
  9 | 
 10 | But wait, what is "Delayed Streams Modeling"? It is a technique for solving many streaming X-to-Y tasks (with X, Y in `{speech, text}`)
 11 | that formalizes the approach we took with Moshi and Hibiki. A pre-print paper is coming soon!
 12 | 
 13 | ## Kyutai Speech-To-Text
 14 | 
 15 | <a href="https://huggingface.co/collections/kyutai/speech-to-text-685403682cf8a23ab9466886" target="_blank" style="margin: 2px;">
 16 |     <img alt="Hugging Face" src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-KyutaiSTT-blue" style="display: inline-block; vertical-align: middle;"/>
 17 | </a>
 18 | <a target="_blank" href="https://colab.research.google.com/github/kyutai-labs/delayed-streams-modeling/blob/main/stt_pytorch.ipynb">
 19 |   <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
 20 | </a>
 21 | 
 22 | **More details can be found on the [project page](https://kyutai.org/next/stt).**
 23 | 
 24 | Kyutai STT models are optimized for real-time usage, can be batched for efficiency, and return word-level timestamps.
 25 | We provide two models:
 26 | - `kyutai/stt-1b-en_fr`, an English and French model with ~1B parameters, a 0.5 second delay, and a [semantic VAD](https://kyutai.org/next/stt#semantic-vad).
 27 | - `kyutai/stt-2.6b-en`, an English-only model with ~2.6B parameters and a 2.5 second delay.
 28 | 
 29 | These speech-to-text models have several advantages:
 30 | - Streaming inference: the models can process audio in chunks, which allows
 31 |   for real-time transcription, and is great for interactive applications.
 32 | - Easy batching for maximum efficiency: an H100 can process 400 streams in
 33 |   real time.
 34 | - They return word-level timestamps.
 35 | - The 1B model has a semantic Voice Activity Detection (VAD) component that
 36 |   can be used to detect when the user is speaking. This is especially useful
 37 |   for building voice agents.
 38 | 
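The 0.5 and 2.5 second delays quoted above correspond to the `asr_delay_in_tokens` values in the server configs under `configs/` (6 for `kyutai/stt-1b-en_fr`, 32 for `kyutai/stt-2.6b-en`). As a rough sanity check, assuming Mimi's 12.5 Hz token rate (one audio token every 80 ms):

```python
# Rough sanity check: ASR delay in tokens -> delay in seconds,
# assuming Mimi produces 12.5 audio tokens per second (80 ms per token).
FRAME_RATE_HZ = 12.5

for model, delay_tokens in [("kyutai/stt-1b-en_fr", 6), ("kyutai/stt-2.6b-en", 32)]:
    print(f"{model}: asr_delay_in_tokens={delay_tokens} -> ~{delay_tokens / FRAME_RATE_HZ:.2f} s")
# kyutai/stt-1b-en_fr: asr_delay_in_tokens=6 -> ~0.48 s
# kyutai/stt-2.6b-en: asr_delay_in_tokens=32 -> ~2.56 s
```
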
 39 | ### Implementations overview
 40 | 
 41 | We provide different implementations of Kyutai STT for different use cases.
 42 | Here is how to choose which one to use:
 43 | 
 44 | - **PyTorch: for research and tinkering.**
 45 |   If you want to call the model from Python for research or experimentation, use our PyTorch implementation.
 46 | - **Rust: for production.**
 47 |   If you want to serve Kyutai STT in a production setting, use our Rust server.
 48 |   Our robust Rust server provides streaming access to the model over websockets.
 49 |   We use this server to run [Unmute](https://unmute.sh/); on an L40S GPU, we can serve 64 simultaneous connections at a real-time factor of 3x.
 50 | - **MLX: for on-device inference on iPhone and Mac.**
 51 |   MLX is Apple's ML framework that allows you to use hardware acceleration on Apple silicon.
 52 |   If you want to run the model on a Mac or an iPhone, choose the MLX implementation.
 53 | 
 54 | <details>
 55 | <summary>PyTorch implementation</summary>
 56 | <a href="https://huggingface.co/kyutai/stt-2.6b-en" target="_blank" style="margin: 2px;">
 57 |     <img alt="Hugging Face" src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-blue" style="display: inline-block; vertical-align: middle;"/>
 58 | </a>
 59 | <a target="_blank" href="https://colab.research.google.com/github/kyutai-labs/delayed-streams-modeling/blob/main/stt_pytorch.ipynb">
 60 |   <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
 61 | </a>
 62 | 
 63 | For an example of how to use the model in a way where you can directly stream in PyTorch tensors,
 64 | [see our Colab notebook](https://colab.research.google.com/github/kyutai-labs/delayed-streams-modeling/blob/main/stt_pytorch.ipynb).
 65 | 
 66 | This requires the [moshi package](https://pypi.org/project/moshi/)
 67 | with version 0.2.6 or later, which can be installed via pip.
 68 | 
 69 | If you just want to run the model on a file, you can use `moshi.run_inference`.
 70 | 
 71 | ```bash
 72 | python -m moshi.run_inference --hf-repo kyutai/stt-2.6b-en audio/bria.mp3
 73 | ```
 74 | 
 75 | If you have [uv](https://docs.astral.sh/uv/) installed, you can skip the installation step
 76 | and just prefix the command above with `uvx --with moshi`.
 77 | 
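If you would rather feed PyTorch tensors directly instead of a file, the core streaming loop is small. The sketch below assumes `mimi` (the audio tokenizer), `lm_gen` (the streaming text generator), `tokenizer`, and `padding_token_id` have already been loaded the same way `scripts/stt_evaluate_on_dataset.py` does; the function name `transcribe_tensor` is ours, not part of the moshi package.

```python
import torch


def transcribe_tensor(audio: torch.Tensor, mimi, lm_gen, tokenizer, padding_token_id: int) -> str:
    """Transcribe a mono waveform of shape (num_samples,) already resampled to mimi.sample_rate."""
    frame = mimi.frame_size
    # Pad the waveform to a whole number of codec frames.
    if audio.shape[-1] % frame != 0:
        audio = torch.nn.functional.pad(audio, (0, frame - audio.shape[-1] % frame))
    audio = audio[None, None, :]  # (batch=1, channels=1, samples)

    token_chunks = []
    with torch.no_grad(), mimi.streaming(1), lm_gen.streaming(1):
        for offset in range(0, audio.shape[-1], frame):
            audio_tokens = mimi.encode(audio[:, :, offset : offset + frame])
            text_tokens = lm_gen.step(audio_tokens)  # None while the model is still warming up
            if text_tokens is not None:
                token_chunks.append(text_tokens)

    tokens = torch.concat(token_chunks, dim=-1)[0]
    tokens = tokens[tokens > padding_token_id]  # drop padding/special tokens
    return tokenizer.decode(tokens.cpu().numpy().tolist())
```

For anything beyond a quick experiment, prefer the notebook and scripts above, which also handle model loading and resampling.
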
 78 | Additionally, we provide two scripts that highlight different usage scenarios. The first script illustrates how to extract word-level timestamps from the model's outputs:
 79 | 
 80 | ```bash
 81 | uv run \
 82 |   scripts/stt_from_file_pytorch.py \
 83 |   --hf-repo kyutai/stt-2.6b-en \
 84 |   audio/bria.mp3
 85 | ```
 86 | 
 87 | The second script can be used to run a model on an existing Hugging Face dataset and calculate its performance metrics: 
 88 | ```bash
 89 | uv run scripts/stt_evaluate_on_dataset.py \
 90 |   --dataset meanwhile  \
 91 |   --hf-repo kyutai/stt-2.6b-en
 92 | ```
 93 | 
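The evaluation script reports both a per-utterance averaged WER and a corpus-level WER (total errors over total reference words), computed with `jiwer`. A toy example of why the two differ, using the same `jiwer` calls as `scripts/stt_evaluate_on_dataset.py` (the sentences are made up):

```python
import jiwer

refs = ["the cat sat", "hello world"]
hyps = ["the cat sat down", "hello word"]

# Per-utterance WER, averaged over utterances.
avg_wer = sum(jiwer.wer(r, h) for r, h in zip(refs, hyps)) / len(refs)  # (1/3 + 1/2) / 2 ~= 0.42

# Corpus-level WER: total errors divided by total reference words.
errors = ref_words = 0
for r, h in zip(refs, hyps):
    m = jiwer.compute_measures(r, h)
    errors += m["substitutions"] + m["deletions"] + m["insertions"]
    ref_words += m["substitutions"] + m["deletions"] + m["hits"]
corpus_wer = errors / ref_words  # 2 / 5 = 0.40

print(f"averaged wer: {avg_wer:.2f}, corpus wer: {corpus_wer:.2f}")
```
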
 94 | Another example shows how one can provide a text-, audio-, or text-audio prompt to our STT model:
 95 | ```bash
 96 | uv run scripts/stt_from_file_with_prompt_pytorch.py \
 97 |   --hf-repo kyutai/stt-2.6b-en \
 98 |   --file audio/bria.mp3 \
 99 |   --prompt_file ./audio/loona.mp3 \
100 |   --prompt_text "Loonah" \
101 |   --cut-prompt-transcript
102 | ```
103 | This produces the transcript of `bria.mp3` with the `Loonah` spelling for the name, instead of the `Luna` spelling used without a prompt:
104 | ```
105 | In the heart of an ancient forest, where the trees whispered secrets of the past, there lived a peculiar rabbit named Loonah (...)
106 | ```
107 | 
108 | Apart from nudging the model towards a specific spelling of a word, other potential use cases include speaker adaptation and steering the model towards a specific formatting style or even a language.
109 | However, please bear in mind that this is an experimental feature and its behavior is very sensitive to the prompt provided.
110 | </details>
111 | 
112 | <details>
113 | <summary>Rust server</summary>
114 | 
115 | <a href="https://huggingface.co/kyutai/stt-2.6b-en-candle" target="_blank" style="margin: 2px;">
116 |     <img alt="Hugging Face" src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-blue" style="display: inline-block; vertical-align: middle;"/>
117 | </a>
118 | 
119 | The Rust implementation provides a server that can process multiple streaming
120 | queries in parallel. Depending on the amount of memory on your GPU, you may
121 | have to adjust the batch size in the config file. For an L40S GPU, a batch size
122 | of 64 works well and requests can be processed at 3x real-time speed.
123 | 
124 | In order to run the server, install the [moshi-server
125 | crate](https://crates.io/crates/moshi-server) via the following command. The
126 | server code can be found in the
127 | [kyutai-labs/moshi](https://github.com/kyutai-labs/moshi/tree/main/rust/moshi-server)
128 | repository.
129 | ```bash
130 | cargo install --features cuda moshi-server
131 | ```
132 | 
133 | Then the server can be started via the following command using the config file
134 | from this repository.
135 | For `kyutai/stt-1b-en_fr`, use `configs/config-stt-en_fr-hf.toml`,
136 | and for `kyutai/stt-2.6b-en`, use `configs/config-stt-en-hf.toml`.
137 | 
138 | ```bash
139 | moshi-server worker --config configs/config-stt-en_fr-hf.toml
140 | ```
141 | 
142 | Once the server has started you can transcribe audio from your microphone with the following script.
143 | ```bash
144 | uv run scripts/stt_from_mic_rust_server.py
145 | ```
146 | 
147 | We also provide a script for transcribing from an audio file.
148 | ```bash
149 | uv run scripts/stt_from_file_rust_server.py audio/bria.mp3
150 | ```
151 | 
152 | The script limits the decoding speed to simulate real-time processing of the audio.
153 | Faster processing can be triggered by setting
154 | the real-time factor, e.g. `--rtf 1000` will process
155 | the data as fast as possible.
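
For intuition, the pacing behind `--rtf` boils down to sleeping between chunks so that audio is never sent faster than `rtf` times real time. A minimal illustration of that logic (not the actual script, which also streams the chunks over a websocket):

```python
import time


def paced_chunks(samples, sample_rate: int, chunk_samples: int, rtf: float = 1.0):
    """Yield chunks of `samples` no faster than `rtf` times real time."""
    start = time.monotonic()
    sent_seconds = 0.0  # seconds of audio yielded so far
    for offset in range(0, len(samples), chunk_samples):
        chunk = samples[offset : offset + chunk_samples]
        sent_seconds += len(chunk) / sample_rate
        # Sleep until wall-clock time catches up with sent_seconds / rtf.
        wait = sent_seconds / rtf - (time.monotonic() - start)
        if wait > 0:
            time.sleep(wait)
        yield chunk
```

With a large value such as `--rtf 1000`, the computed wait is effectively never positive, so the chunks are yielded as fast as possible.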
156 | </details>
157 | 
158 | <details>
159 | <summary>Rust standalone</summary>
160 | <a href="https://huggingface.co/kyutai/stt-2.6b-en-candle" target="_blank" style="margin: 2px;">
161 |     <img alt="Hugging Face" src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-blue" style="display: inline-block; vertical-align: middle;"/>
162 | </a>
163 | 
164 | A standalone Rust example script is provided in the `stt-rs` directory in this repo.
165 | This can be used as follows:
166 | ```bash
167 | cd stt-rs
168 | cargo run --features cuda -r -- ../audio/bria.mp3
169 | ```
170 | You can get the timestamps by adding the `--timestamps` flag, and see the output
171 | of the semantic VAD by adding the `--vad` flag.
172 | </details>
173 | 
174 | <details>
175 | <summary>MLX implementation</summary>
176 | <a href="https://huggingface.co/kyutai/stt-2.6b-en-mlx" target="_blank" style="margin: 2px;">
177 |     <img alt="Hugging Face" src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-blue" style="display: inline-block; vertical-align: middle;"/>
178 | </a>
179 | 
180 | [MLX](https://ml-explore.github.io/mlx/build/html/index.html) is Apple's ML framework that allows you to use
181 | hardware acceleration on Apple silicon.
182 | 
183 | This requires the [moshi-mlx package](https://pypi.org/project/moshi-mlx/)
184 | with version 0.2.6 or later, which can be installed via pip.
185 | 
186 | If you just want to run the model on a file, you can use `moshi_mlx.run_inference`:
187 | 
188 | ```bash
189 | python -m moshi_mlx.run_inference --hf-repo kyutai/stt-2.6b-en-mlx audio/bria.mp3 --temp 0
190 | ```
191 | 
192 | If you have [uv](https://docs.astral.sh/uv/) installed, you can skip the installation step
193 | and just prefix the command above with `uvx --with moshi-mlx`.
194 | 
195 | If you want to transcribe audio from your microphone, use:
196 | 
197 | ```bash
198 | python scripts/stt_from_mic_mlx.py
199 | ```
200 | 
201 | The MLX models can also be used in Swift using the [moshi-swift
202 | codebase](https://github.com/kyutai-labs/moshi-swift); the 1B model has been
203 | tested to work fine on an iPhone 16 Pro.
204 | </details>
205 | 
206 | ## Kyutai Text-to-Speech
207 | 
208 | <a href="https://huggingface.co/collections/kyutai/text-to-speech-6866192e7e004ed04fd39e29" target="_blank" style="margin: 2px;">
209 |     <img alt="Hugging Face" src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-KyutaiTTS-blue" style="display: inline-block; vertical-align: middle;"/>
210 | </a>
211 | <a target="_blank" href="https://colab.research.google.com/github/kyutai-labs/delayed-streams-modeling/blob/main/tts_pytorch.ipynb">
212 |   <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
213 | </a>
214 | 
215 | **More details can be found on the [project page](https://kyutai.org/next/tts).**
216 | 
217 | We provide different implementations of Kyutai TTS for different use cases. Here is how to choose which one to use:
218 | 
219 | - PyTorch: for research and tinkering. If you want to call the model from Python for research or experimentation, use our PyTorch implementation.
220 | - Rust: for production. If you want to serve Kyutai TTS in a production setting, use our Rust server. Our robust Rust server provides streaming access to the model over websockets. We use this server to run Unmute.
221 | - MLX: for on-device inference on iPhone and Mac. MLX is Apple's ML framework that allows you to use hardware acceleration on Apple silicon. If you want to run the model on a Mac or an iPhone, choose the MLX implementation.
222 | 
223 | <details>
224 | <summary>PyTorch implementation</summary>
225 | 
226 | <a target="_blank" href="https://colab.research.google.com/github/kyutai-labs/delayed-streams-modeling/blob/main/tts_pytorch.ipynb">
227 |   <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
228 | </a>
229 | 
230 | Check out our [Colab notebook](https://colab.research.google.com/github/kyutai-labs/delayed-streams-modeling/blob/main/tts_pytorch.ipynb) or use the script:
231 | 
232 | ```bash
233 | # From stdin, plays audio immediately
234 | echo "Hey, how are you?" | python scripts/tts_pytorch.py - -
235 | 
236 | # From text file to audio file
237 | python scripts/tts_pytorch.py text_to_say.txt audio_output.wav
238 | ```
239 | 
240 | This requires the [moshi package](https://pypi.org/project/moshi/), which can be installed via pip.
241 | If you have [uv](https://docs.astral.sh/uv/) installed, you can skip the installation step
242 | and just prefix the command above with `uvx --with moshi`.
243 | </details>
244 | 
245 | <details>
246 | <summary>Rust server</summary>
247 | 
248 | 
249 | The Rust implementation provides a server that can process multiple streaming
250 | queries in parallel.
251 | 
252 | Installing the Rust server is a bit tricky because it uses our Python implementation under the hood,
253 | which also requires installing the Python dependencies.
254 | Use the [start_tts.sh](https://github.com/kyutai-labs/unmute/blob/main/dockerless/start_tts.sh) script to properly install the Rust server.
255 | If you already installed the `moshi-server` crate before and it's not working, you might need to force a reinstall by running `cargo uninstall moshi-server` first.
256 | Feel free to open an issue if the installation is still broken.
257 | 
258 | Once installed, the server can be started via the following command using the config file
259 | from this repository.
260 | 
261 | ```bash
262 | moshi-server worker --config configs/config-tts.toml
263 | ```
264 | 
265 | Once the server has started you can connect to it using our script as follows:
266 | ```bash
267 | # From stdin, plays audio immediately
268 | echo "Hey, how are you?" | python scripts/tts_rust_server.py - -
269 | 
270 | # From text file to audio file
271 | python scripts/tts_rust_server.py text_to_say.txt audio_output.wav
272 | ```
273 | </details>
274 | 
275 | <details>
276 | <summary>MLX implementation</summary>
277 | 
278 | [MLX](https://ml-explore.github.io/mlx/build/html/index.html) is Apple's ML framework that allows you to use
279 | hardware acceleration on Apple silicon.
280 | 
281 | Use our example script to run Kyutai TTS on MLX.
282 | The script takes text from stdin or a file and can output to a file or stream the resulting audio.
283 | When streaming the output, if the model is not fast enough to keep up with
284 | real time, you can use the `--quantize 8` or `--quantize 4` flags to quantize
285 | the model, resulting in faster inference.
286 | 
287 | ```bash
288 | # From stdin, plays audio immediately
289 | echo "Hey, how are you?" | python scripts/tts_mlx.py - - --quantize 8
290 | 
291 | # From text file to audio file
292 | python scripts/tts_mlx.py text_to_say.txt audio_output.wav
293 | ```
294 | 
295 | This requires the [moshi-mlx package](https://pypi.org/project/moshi-mlx/), which can be installed via pip.
296 | If you have [uv](https://docs.astral.sh/uv/) installed, you can skip the installation step
297 | and just prefix the command above with `uvx --with moshi-mlx`.
298 | </details>
299 | 
300 | ## License
301 | 
302 | The present code is provided under the MIT license for the Python parts, and Apache license for the Rust backend.
303 | The web client code is provided under the MIT license.
304 | Note that parts of this code are based on [AudioCraft](https://github.com/facebookresearch/audiocraft), released under
305 | the MIT license.
306 | 
307 | The weights for the speech-to-text models are released under the CC-BY 4.0 license.
308 | 
309 | ## Developing
310 | 
311 | Install the [pre-commit hooks](https://pre-commit.com/) by running:
312 | 
313 | ```bash
314 | pip install pre-commit
315 | pre-commit install
316 | ```
317 | 
318 | If you're using `uv`, you can replace the two commands with `uvx pre-commit install`.
319 | 


--------------------------------------------------------------------------------
/audio/bria.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/delayed-streams-modeling/baf0c75bba89608e921cb26e03c959981df2ad5f/audio/bria.mp3


--------------------------------------------------------------------------------
/audio/loona.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/delayed-streams-modeling/baf0c75bba89608e921cb26e03c959981df2ad5f/audio/loona.mp3


--------------------------------------------------------------------------------
/audio/sample_fr_hibiki_crepes.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/delayed-streams-modeling/baf0c75bba89608e921cb26e03c959981df2ad5f/audio/sample_fr_hibiki_crepes.mp3


--------------------------------------------------------------------------------
/configs/config-stt-en-hf.toml:
--------------------------------------------------------------------------------
 1 | static_dir = "./static/"
 2 | log_dir = "$HOME/tmp/tts-logs"
 3 | instance_name = "tts"
 4 | authorized_ids = ["public_token"]
 5 | 
 6 | [modules.asr]
 7 | path = "/api/asr-streaming"
 8 | type = "BatchedAsr"
 9 | lm_model_file = "hf://kyutai/stt-2.6b-en-candle/model.safetensors"
10 | text_tokenizer_file = "hf://kyutai/stt-2.6b-en-candle/tokenizer_en_audio_4000.model"
11 | audio_tokenizer_file = "hf://kyutai/stt-2.6b-en-candle/mimi-pytorch-e351c8d8@125.safetensors"
12 | asr_delay_in_tokens = 32
13 | batch_size = 16
14 | conditioning_learnt_padding = true
15 | temperature = 0
16 | 
17 | [modules.asr.model]
18 | audio_vocab_size = 2049
19 | text_in_vocab_size = 4001
20 | text_out_vocab_size = 4000
21 | audio_codebooks = 32
22 | 
23 | [modules.asr.model.transformer]
24 | d_model = 2048
25 | num_heads = 32
26 | num_layers = 48
27 | dim_feedforward = 8192
28 | causal = true
29 | norm_first = true
30 | bias_ff = false
31 | bias_attn = false
32 | context = 375
33 | max_period = 100000
34 | use_conv_block = false
35 | use_conv_bias = true
36 | gating = "silu"
37 | norm = "RmsNorm"
38 | positional_embedding = "Rope"
39 | conv_layout = false
40 | conv_kernel_size = 3
41 | kv_repeat = 1
42 | max_seq_len = 40960
43 | 


--------------------------------------------------------------------------------
/configs/config-stt-en_fr-hf.toml:
--------------------------------------------------------------------------------
 1 | static_dir = "./static/"
 2 | log_dir = "$HOME/tmp/tts-logs"
 3 | instance_name = "tts"
 4 | authorized_ids = ["public_token"]
 5 | 
 6 | [modules.asr]
 7 | path = "/api/asr-streaming"
 8 | type = "BatchedAsr"
 9 | lm_model_file = "hf://kyutai/stt-1b-en_fr-candle/model.safetensors"
10 | text_tokenizer_file = "hf://kyutai/stt-1b-en_fr-candle/tokenizer_en_fr_audio_8000.model"
11 | audio_tokenizer_file = "hf://kyutai/stt-1b-en_fr-candle/mimi-pytorch-e351c8d8@125.safetensors"
12 | asr_delay_in_tokens = 6
13 | batch_size = 64
14 | conditioning_learnt_padding = true
15 | temperature = 0.0
16 | 
17 | [modules.asr.model]
18 | audio_vocab_size = 2049
19 | text_in_vocab_size = 8001
20 | text_out_vocab_size = 8000
21 | audio_codebooks = 32
22 | 
23 | [modules.asr.model.transformer]
24 | d_model = 2048
25 | num_heads = 16
26 | num_layers = 16
27 | dim_feedforward = 8192
28 | causal = true
29 | norm_first = true
30 | bias_ff = false
31 | bias_attn = false
32 | context = 750
33 | max_period = 100000
34 | use_conv_block = false
35 | use_conv_bias = true
36 | gating = "silu"
37 | norm = "RmsNorm"
38 | positional_embedding = "Rope"
39 | conv_layout = false
40 | conv_kernel_size = 3
41 | kv_repeat = 1
42 | max_seq_len = 40960
43 | 
44 | [modules.asr.model.extra_heads]
45 | num_heads = 4
46 | dim = 6
47 | 


--------------------------------------------------------------------------------
/configs/config-tts.toml:
--------------------------------------------------------------------------------
 1 | static_dir = "./static/"
 2 | log_dir = "$HOME/tmp/tts-logs"
 3 | instance_name = "tts"
 4 | authorized_ids = ["public_token"]
 5 | 
 6 | [modules.tts_py]
 7 | type = "Py"
 8 | path = "/api/tts_streaming"
 9 | text_tokenizer_file = "hf://kyutai/tts-1.6b-en_fr/tokenizer_spm_8k_en_fr_audio.model"
10 | batch_size = 8  # Adjust to your GPU memory capacity
11 | text_bos_token = 1
12 | 
13 | [modules.tts_py.py]
14 | log_folder = "$HOME/tmp/moshi-server-logs"
15 | voice_folder = "hf-snapshot://kyutai/tts-voices/**/*.safetensors"
16 | default_voice = "unmute-prod-website/default_voice.wav"
17 | cfg_coef = 2.0
18 | cfg_is_no_text = true
19 | padding_between = 1
20 | n_q = 24
21 | 


--------------------------------------------------------------------------------
/scripts/stt_evaluate_on_dataset.py:
--------------------------------------------------------------------------------
  1 | # /// script
  2 | # requires-python = ">=3.12"
  3 | # dependencies = [
  4 | #     "datasets",
  5 | #     "jiwer==3.1.0",
  6 | #     "julius",
  7 | #     "librosa",
  8 | #     "moshi",
  9 | #     "openai-whisper",
 10 | #     "soundfile",
 11 | # ]
 12 | # ///
 13 | """
 14 | Example implementation of streaming STT evaluation. Here we group
 15 | test utterances into batches (pre- and post-padded with silence)
 16 | and then feed these batches into the streaming STT model frame by frame.
 17 | """
 18 | 
 19 | # The outputs I get on my H100 using this code with the 2.6B model,
 20 | # bsz 32:
 21 | 
 22 | # LibriVox === cer: 4.09% wer: 7.33% corpus_wer: 6.78% RTF = 52.72
 23 | # Ami === cer: 15.99% wer: 18.78% corpus_wer: 12.20% RTF = 28.37
 24 | # LibriSpeech other === cer: 2.31% wer: 5.24% corpus_wer: 4.33% RTF = 44.76
 25 | # LibriSpeech clean === cer: 0.67% wer: 1.95% corpus_wer: 1.69% RTF = 68.19
 26 | # Tedlium (short) === cer: 2.15% wer: 3.65% corpus_wer: 3.33% RTF = 67.44
 27 | # spgispeech === cer: 0.99% wer: 2.00% corpus_wer: 2.03% RTF = 78.64
 28 | # gigaspeech === cer: 6.80% wer: 11.31% corpus_wer: 9.81% RTF = 64.04
 29 | # earnings22 (short) === cer: 12.63% wer: 15.70% corpus_wer: 11.02% RTF = 50.13
 30 | 
 31 | # Meanwhile === cer: 2.02% wer: 5.50% corpus_wer: 5.60% RTF = 69.19
 32 | # Tedlium (long) == cer: 1.53% wer: 2.56% corpus_wer: 2.97% RTF = 33.92
 33 | # Rev16 === cer: 6.57% wer: 10.08% corpus_wer: 11.43% RTF = 40.34
 34 | # Earnings21 === cer: 5.73% wer: 9.84% corpus_wer: 10.38% RTF = 73.15
 35 | 
 36 | import argparse
 37 | import dataclasses
 38 | import time
 39 | 
 40 | import jiwer
 41 | import julius
 42 | import moshi.models
 43 | import torch
 44 | import tqdm
 45 | from datasets import Dataset, load_dataset
 46 | from whisper.normalizers import EnglishTextNormalizer
 47 | 
 48 | _NORMALIZER = EnglishTextNormalizer()
 49 | 
 50 | 
 51 | def get_text(sample):
 52 |     possible_keys = [
 53 |         "text",
 54 |         "sentence",
 55 |         "normalized_text",
 56 |         "transcript",
 57 |         "transcription",
 58 |     ]
 59 |     for key in possible_keys:
 60 |         if key in sample:
 61 |             return sample[key]
 62 |     raise ValueError(
 63 |         f"Expected transcript column of either {possible_keys}. "
 64 |         f"Got sample with keys: {', '.join(sample.keys())}. Ensure a text column name is present in the dataset."
 65 |     )
 66 | 
 67 | 
 68 | # The two functions below are adapted from https://github.com/huggingface/open_asr_leaderboard/blob/main/normalizer/data_utils.py
 69 | 
 70 | 
 71 | def normalize(batch):
 72 |     batch["original_text"] = get_text(batch)
 73 |     batch["norm_text"] = _NORMALIZER(batch["original_text"])
 74 |     return batch
 75 | 
 76 | 
 77 | def is_target_text_in_range(ref):
 78 |     if ref.strip() == "ignore time segment in scoring":
 79 |         return False
 80 |     else:
 81 |         return ref.strip() != ""
 82 | 
 83 | 
 84 | # End of the adapted part
 85 | 
 86 | 
 87 | class AsrMetrics:
 88 |     def __init__(self):
 89 |         self.cer_sum = 0.0
 90 |         self.wer_sum = 0.0
 91 |         self.errors_sum = 0.0
 92 |         self.total_words_sum = 0.0
 93 |         self.num_sequences = 0.0
 94 | 
 95 |     def update(self, hyp: str, ref: str) -> None:
 96 |         normalized_ref = _NORMALIZER(ref)
 97 |         normalized_hyp = _NORMALIZER(hyp)
 98 | 
 99 |         this_wer = jiwer.wer(normalized_ref, normalized_hyp)
100 |         this_cer = jiwer.cer(normalized_ref, normalized_hyp)
101 |         measures = jiwer.compute_measures(normalized_ref, normalized_hyp)
102 | 
103 |         self.wer_sum += this_wer
104 |         self.cer_sum += this_cer
105 |         self.errors_sum += (
106 |             measures["substitutions"] + measures["deletions"] + measures["insertions"]
107 |         )
108 |         self.total_words_sum += (
109 |             measures["substitutions"] + measures["deletions"] + measures["hits"]
110 |         )
111 |         self.num_sequences += 1
112 | 
113 |     def compute(self) -> dict:
114 |         assert self.num_sequences > 0, (
115 |             "Unable to compute with total number of comparisons <= 0"
116 |         )  # type: ignore
117 |         return {
118 |             "cer": (self.cer_sum / self.num_sequences),
119 |             "wer": (self.wer_sum / self.num_sequences),
120 |             "corpus_wer": (self.errors_sum / self.total_words_sum),
121 |         }
122 | 
123 |     def __str__(self) -> str:
124 |         result = self.compute()
125 |         return " ".join(f"{k}: {100 * v:.2f}%" for k, v in result.items())
126 | 
127 | 
128 | class Timer:
129 |     def __init__(self):
130 |         self.total = 0
131 |         self._start_time = None
132 | 
133 |     def __enter__(self):
134 |         self._start_time = time.perf_counter()
135 |         return self
136 | 
137 |     def __exit__(self, *_):
138 |         self.total += time.perf_counter() - self._start_time
139 |         self._start_time = None
140 | 
141 | 
142 | @dataclasses.dataclass
143 | class _DatasetInfo:
144 |     alias: str
145 | 
146 |     name: str
147 |     config: str
148 |     split: str = "test"
149 | 
150 | 
151 | _DATASETS = [
152 |     # Long-form datasets from distil-whisper
153 |     _DatasetInfo("rev16", "distil-whisper/rev16", "whisper_subset"),
154 |     _DatasetInfo("earnings21", "distil-whisper/earnings21", "full"),
155 |     _DatasetInfo("earnings22", "distil-whisper/earnings22", "full"),
156 |     _DatasetInfo("tedlium", "distil-whisper/tedlium-long-form", None),
157 |     _DatasetInfo("meanwhile", "distil-whisper/meanwhile", None),
158 |     # Short-form datasets from OpenASR leaderboard
159 |     _DatasetInfo("ami", "hf-audio/esb-datasets-test-only-sorted", "ami"),
160 |     _DatasetInfo(
161 |         "librispeech.clean",
162 |         "hf-audio/esb-datasets-test-only-sorted",
163 |         "librispeech",
164 |         split="test.clean",
165 |     ),
166 |     _DatasetInfo(
167 |         "librispeech.other",
168 |         "hf-audio/esb-datasets-test-only-sorted",
169 |         "librispeech",
170 |         split="test.other",
171 |     ),
172 |     _DatasetInfo("voxpopuli", "hf-audio/esb-datasets-test-only-sorted", "voxpopuli"),
173 |     _DatasetInfo("spgispeech", "hf-audio/esb-datasets-test-only-sorted", "spgispeech"),
174 |     _DatasetInfo("gigaspeech", "hf-audio/esb-datasets-test-only-sorted", "gigaspeech"),
175 |     _DatasetInfo("tedlium-short", "hf-audio/esb-datasets-test-only-sorted", "tedlium"),
176 |     _DatasetInfo(
177 |         "earnings22-short", "hf-audio/esb-datasets-test-only-sorted", "earnings22"
178 |     ),
179 | ]
180 | DATASET_MAP = {dataset.alias: dataset for dataset in _DATASETS}
181 | 
182 | 
183 | def get_dataset(args) -> Dataset:
184 |     if args.dataset not in DATASET_MAP:
185 |         raise RuntimeError(f"Unknown dataset: {args.dataset}")
186 | 
187 |     info = DATASET_MAP[args.dataset]
188 | 
189 |     dataset = load_dataset(
190 |         info.name,
191 |         info.config,
192 |         split=info.split,
193 |         cache_dir=args.hf_cache_dir,
194 |         streaming=False,
195 |         token=True,
196 |     )
197 |     dataset = dataset.map(normalize)
198 |     dataset = dataset.filter(is_target_text_in_range, input_columns=["norm_text"])
199 | 
200 |     return dataset
201 | 
202 | 
203 | @torch.no_grad
204 | def get_padded_batch(
205 |     audios: list[tuple[torch.Tensor, int]],
206 |     before_padding: float,
207 |     after_padding: float,
208 |     audio_encoder,
209 | ):
210 |     sample_rate = audio_encoder.sample_rate
211 | 
212 |     max_len = 0
213 |     batch = []
214 |     durations = []
215 |     for audio, sr in audios:
216 |         durations.append(audio.shape[-1] / sr)
217 |         audio = julius.resample_frac(audio, int(sr), int(sample_rate))
218 |         audio = torch.nn.functional.pad(
219 |             audio, (int(before_padding * sample_rate), int(after_padding * sample_rate))
220 |         )
221 |         max_len = max(max_len, audio.shape[-1])
222 |         batch.append(audio)
223 | 
224 |     target = max_len
225 |     if target % audio_encoder.frame_size != 0:
226 |         target = target + (
227 |             audio_encoder.frame_size - max_len % audio_encoder.frame_size
228 |         )
229 |     padded_batch = torch.stack(
230 |         [
231 |             torch.nn.functional.pad(audio, (0, target - audio.shape[-1]))
232 |             for audio in batch
233 |         ]
234 |     )
235 |     return padded_batch
236 | 
237 | 
238 | @torch.no_grad
239 | def streaming_transcribe(
240 |     padded_batch: torch.Tensor,
241 |     mimi,
242 |     lm_gen,
243 | ):
244 |     bsz = padded_batch.shape[0]
245 | 
246 |     text_tokens_acc = []
247 | 
248 |     with mimi.streaming(bsz), lm_gen.streaming(bsz):
249 |         for offset in range(0, padded_batch.shape[-1], mimi.frame_size):
250 |             audio_chunk = padded_batch[:, offset : offset + mimi.frame_size]
251 |             audio_chunk = audio_chunk[:, None, :]
252 | 
253 |             audio_tokens = mimi.encode(audio_chunk)
254 |             text_tokens = lm_gen.step(audio_tokens)
255 |             if text_tokens is not None:
256 |                 text_tokens_acc.append(text_tokens)
257 | 
258 |     return torch.concat(text_tokens_acc, axis=-1)
259 | 
260 | 
261 | def run_inference(
262 |     dataset,
263 |     mimi,
264 |     lm_gen,
265 |     tokenizer,
266 |     padding_token_id,
267 |     before_padding_sec,
268 |     after_padding_sec,
269 | ):
270 |     metrics = AsrMetrics()
271 |     audio_time = 0.0
272 |     inference_timer = Timer()
273 | 
274 |     for batch in tqdm.tqdm(dataset.iter(args.batch_size)):
275 |         audio_data = list(
276 |             zip(
277 |                 [torch.tensor(x["array"]).float() for x in batch["audio"]],
278 |                 [x["sampling_rate"] for x in batch["audio"]],
279 |             )
280 |         )
281 | 
282 |         audio_time += sum(audio.shape[-1] / sr for (audio, sr) in audio_data)
283 | 
284 |         gt_transcripts = batch["original_text"]
285 | 
286 |         padded_batch = get_padded_batch(
287 |             audio_data,
288 |             before_padding=before_padding_sec,
289 |             after_padding=after_padding_sec,
290 |             audio_encoder=mimi,
291 |         )
292 |         padded_batch = padded_batch.cuda()
293 | 
294 |         with inference_timer:
295 |             text_tokens = streaming_transcribe(
296 |                 padded_batch,
297 |                 mimi=mimi,
298 |                 lm_gen=lm_gen,
299 |             )
300 | 
301 |         for batch_index in range(text_tokens.shape[0]):
302 |             utterance_tokens = text_tokens[batch_index, ...]
303 |             utterance_tokens = utterance_tokens[utterance_tokens > padding_token_id]
304 |             text = tokenizer.decode(utterance_tokens.cpu().numpy().tolist())
305 |             metrics.update(hyp=text, ref=gt_transcripts[batch_index])
306 | 
307 |     return metrics, inference_timer.total, audio_time
308 | 
309 | 
310 | def main(args):
311 |     torch.set_float32_matmul_precision("high")
312 | 
313 |     info = moshi.models.loaders.CheckpointInfo.from_hf_repo(
314 |         args.hf_repo,
315 |         moshi_weights=args.moshi_weight,
316 |         mimi_weights=args.mimi_weight,
317 |         tokenizer=args.tokenizer,
318 |         config_path=args.config_path,
319 |     )
320 | 
321 |     mimi = info.get_mimi(device=args.device)
322 |     tokenizer = info.get_text_tokenizer()
323 |     lm = info.get_moshi(
324 |         device=args.device,
325 |         dtype=torch.bfloat16,
326 |     )
327 |     lm_gen = moshi.models.LMGen(lm, temp=0, temp_text=0.0)
328 |     dataset = get_dataset(args)
329 | 
330 |     padding_token_id = info.raw_config.get("text_padding_token_id", 3)
331 |     # Putting in some conservative defaults
332 |     audio_silence_prefix_seconds = info.stt_config.get(
333 |         "audio_silence_prefix_seconds", 1.0
334 |     )
335 |     audio_delay_seconds = info.stt_config.get("audio_delay_seconds", 5.0)
336 | 
337 |     wer_metric, inference_time, audio_time = run_inference(
338 |         dataset,
339 |         mimi,
340 |         lm_gen,
341 |         tokenizer,
342 |         padding_token_id,
343 |         audio_silence_prefix_seconds,
344 |         audio_delay_seconds + 0.5,
345 |     )
346 | 
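    |     # RTF (real-time factor) = seconds of audio transcribed per second of
    |     # wall-clock inference time; values above 1 mean faster than real time.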
347 |     print(wer_metric, f"RTF = {audio_time / inference_time:.2f}")
348 | 
349 | 
350 | if __name__ == "__main__":
351 |     parser = argparse.ArgumentParser(description="Example streaming STT inference.")
352 |     parser.add_argument(
353 |         "--dataset",
354 |         required=True,
355 |         choices=DATASET_MAP.keys(),
356 |         help="Dataset to run inference on.",
357 |     )
358 | 
359 |     parser.add_argument(
360 |         "--hf-repo", type=str, help="HF repo to load the STT model from."
361 |     )
362 |     parser.add_argument("--tokenizer", type=str, help="Path to a local tokenizer file.")
363 |     parser.add_argument(
364 |         "--moshi-weight", type=str, help="Path to a local checkpoint file."
365 |     )
366 |     parser.add_argument(
367 |         "--mimi-weight", type=str, help="Path to a local checkpoint file for Mimi."
368 |     )
369 |     parser.add_argument(
370 |         "--config-path", type=str, help="Path to a local config file.", default=None
371 |     )
372 |     parser.add_argument(
373 |         "--batch-size",
374 |         type=int,
375 |         help="Batch size.",
376 |         default=32,
377 |     )
378 |     parser.add_argument(
379 |         "--device",
380 |         type=str,
381 |         default="cuda",
382 |         help="Device on which to run, defaults to 'cuda'.",
383 |     )
384 |     parser.add_argument("--hf-cache-dir", type=str, help="HuggingFace cache folder.")
385 |     args = parser.parse_args()
386 | 
387 |     main(args)
388 | 


--------------------------------------------------------------------------------
/scripts/stt_from_file_mlx.py:
--------------------------------------------------------------------------------
 1 | # /// script
 2 | # requires-python = ">=3.12"
 3 | # dependencies = [
 4 | #     "huggingface_hub",
 5 | #     "moshi_mlx==0.2.10",
 6 | #     "numpy",
 7 | #     "sentencepiece",
 8 | #     "sounddevice",
 9 | #     "sphn",
10 | # ]
11 | # ///
12 | 
13 | import argparse
14 | import json
15 | 
16 | import mlx.core as mx
17 | import mlx.nn as nn
18 | import sentencepiece
19 | import sphn
20 | from huggingface_hub import hf_hub_download
21 | from moshi_mlx import models, utils
22 | 
23 | if __name__ == "__main__":
24 |     parser = argparse.ArgumentParser()
25 |     parser.add_argument("in_file", help="The file to transcribe.")
 26 |     parser.add_argument("--max-steps", default=4096, type=int)
27 |     parser.add_argument("--hf-repo", default="kyutai/stt-1b-en_fr-mlx")
28 |     parser.add_argument(
29 |         "--vad", action="store_true", help="Enable VAD (Voice Activity Detection)."
30 |     )
31 |     args = parser.parse_args()
32 | 
33 |     audio, _ = sphn.read(args.in_file, sample_rate=24000)
34 |     lm_config = hf_hub_download(args.hf_repo, "config.json")
35 |     with open(lm_config, "r") as fobj:
36 |         lm_config = json.load(fobj)
37 |     mimi_weights = hf_hub_download(args.hf_repo, lm_config["mimi_name"])
38 |     moshi_name = lm_config.get("moshi_name", "model.safetensors")
39 |     moshi_weights = hf_hub_download(args.hf_repo, moshi_name)
40 |     text_tokenizer = hf_hub_download(args.hf_repo, lm_config["tokenizer_name"])
41 | 
42 |     lm_config = models.LmConfig.from_config_dict(lm_config)
43 |     model = models.Lm(lm_config)
44 |     model.set_dtype(mx.bfloat16)
45 |     if moshi_weights.endswith(".q4.safetensors"):
46 |         nn.quantize(model, bits=4, group_size=32)
47 |     elif moshi_weights.endswith(".q8.safetensors"):
48 |         nn.quantize(model, bits=8, group_size=64)
49 | 
50 |     print(f"loading model weights from {moshi_weights}")
51 |     if args.hf_repo.endswith("-candle"):
52 |         model.load_pytorch_weights(moshi_weights, lm_config, strict=True)
53 |     else:
54 |         model.load_weights(moshi_weights, strict=True)
55 | 
56 |     print(f"loading the text tokenizer from {text_tokenizer}")
57 |     text_tokenizer = sentencepiece.SentencePieceProcessor(text_tokenizer)  # type: ignore
58 | 
59 |     print(f"loading the audio tokenizer {mimi_weights}")
60 |     audio_tokenizer = models.mimi.Mimi(models.mimi_202407(32))
61 |     audio_tokenizer.load_pytorch_weights(str(mimi_weights), strict=True)
62 |     print("warming up the model")
63 |     model.warmup()
64 |     gen = models.LmGen(
65 |         model=model,
66 |         max_steps=args.max_steps,
67 |         text_sampler=utils.Sampler(top_k=25, temp=0),
68 |         audio_sampler=utils.Sampler(top_k=250, temp=0.8),
69 |         check=False,
70 |     )
71 | 
72 |     print(f"starting inference {audio.shape}")
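    |     # Append two seconds of silence (48000 samples at 24 kHz) so that the last
    |     # words are still flushed out despite the model's transcription delay.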
73 |     audio = mx.concat([mx.array(audio), mx.zeros((1, 48000))], axis=-1)
74 |     last_print_was_vad = False
75 |     for start_idx in range(0, audio.shape[-1] // 1920 * 1920, 1920):
76 |         block = audio[:, None, start_idx : start_idx + 1920]
77 |         other_audio_tokens = audio_tokenizer.encode_step(block).transpose(0, 2, 1)
78 |         if args.vad:
79 |             text_token, vad_heads = gen.step_with_extra_heads(other_audio_tokens[0])
80 |             if vad_heads:
81 |                 pr_vad = vad_heads[2][0, 0, 0].item()
82 |                 if pr_vad > 0.5 and not last_print_was_vad:
83 |                     print(" [end of turn detected]")
84 |                     last_print_was_vad = True
85 |         else:
86 |             text_token = gen.step(other_audio_tokens[0])
87 |         text_token = text_token[0].item()
88 |         audio_tokens = gen.last_audio_tokens()
89 |         _text = None
90 |         if text_token not in (0, 3):
91 |             _text = text_tokenizer.id_to_piece(text_token)  # type: ignore
92 |             _text = _text.replace("▁", " ")
93 |             print(_text, end="", flush=True)
94 |             last_print_was_vad = False
95 |     print()
96 | 


--------------------------------------------------------------------------------
/scripts/stt_from_file_pytorch.py:
--------------------------------------------------------------------------------
  1 | # /// script
  2 | # requires-python = ">=3.12"
  3 | # dependencies = [
  4 | #     "julius",
  5 | #     "librosa",
  6 | #     "soundfile",
  7 | #     "moshi==0.2.9",
  8 | # ]
  9 | # ///
 10 | 
 11 | """An example script that illustrates how one can get per-word timestamps from
 12 | Kyutai STT models.
 13 | """
 14 | 
 15 | import argparse
 16 | import dataclasses
 17 | import itertools
 18 | import math
 19 | 
 20 | import julius
 21 | import moshi.models
 22 | import sphn
 23 | import time
 24 | import torch
 25 | 
 26 | 
 27 | @dataclasses.dataclass
 28 | class TimestampedText:
 29 |     text: str
 30 |     timestamp: tuple[float, float]
 31 | 
 32 |     def __str__(self):
 33 |         return f"{self.text} ({self.timestamp[0]:.2f}:{self.timestamp[1]:.2f})"
 34 | 
 35 | 
 36 | def tokens_to_timestamped_text(
 37 |     text_tokens,
 38 |     tokenizer,
 39 |     frame_rate,
 40 |     end_of_padding_id,
 41 |     padding_token_id,
 42 |     offset_seconds,
 43 | ) -> list[TimestampedText]:
 44 |     text_tokens = text_tokens.cpu().view(-1)
 45 | 
 46 |     # Normally, `end_of_padding` tokens indicate word boundaries.
 47 |     # Everything between two of them should be a single word;
 48 |     # the time offsets of those tokens correspond to the word's start and
 49 |     # end timestamps (minus the silence prefix and audio delay).
 50 |     #
 51 |     # However, in rare cases some complexities arise. Firstly,
 52 |     # for words that are said quickly but are represented with
 53 |     # multiple tokens, the boundary might be omitted. Secondly,
 54 |     # for the very last word the end boundary might be missing.
 55 |     # Below is a code snippet that handles those situations a bit
 56 |     # more carefully.
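    |     #
    |     # Illustrative example (hypothetical values): with frame_rate = 12.5 and
    |     # offset_seconds = 0, `end_of_padding` tokens at positions 2 and 5 bracket
    |     # the tokens at positions 3-4, which decode to a single word; its timestamps
    |     # come out as roughly 3 / 12.5 = 0.24s and 5 / 12.5 = 0.40s.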
 57 | 
 58 |     sequence_timestamps = []
 59 | 
 60 |     def _tstmp(start_position, end_position):
 61 |         return (
 62 |             max(0, start_position / frame_rate - offset_seconds),
 63 |             max(0, end_position / frame_rate - offset_seconds),
 64 |         )
 65 | 
 66 |     def _decode(t):
 67 |         t = t[t > padding_token_id]
 68 |         return tokenizer.decode(t.numpy().tolist())
 69 | 
 70 |     def _decode_segment(start, end):
 71 |         nonlocal text_tokens
 72 |         nonlocal sequence_timestamps
 73 | 
 74 |         text = _decode(text_tokens[start:end])
 75 |         words_inside_segment = text.split()
 76 | 
 77 |         if len(words_inside_segment) == 0:
 78 |             return
 79 |         if len(words_inside_segment) == 1:
 80 |             # Single word within the boundaries, the general case
 81 |             sequence_timestamps.append(
 82 |                 TimestampedText(text=text, timestamp=_tstmp(start, end))
 83 |             )
 84 |         else:
 85 |             # We're in a rare situation where multiple words are so close that they are not separated by `end_of_padding`.
 86 |             # We tokenize the words one by one; each word is assigned as many frames as it has tokens.
 87 |             for adjacent_word in words_inside_segment[:-1]:
 88 |                 n_tokens = len(tokenizer.encode(adjacent_word))
 89 |                 sequence_timestamps.append(
 90 |                     TimestampedText(
 91 |                         text=adjacent_word, timestamp=_tstmp(start, start + n_tokens)
 92 |                     )
 93 |                 )
 94 |                 start += n_tokens
 95 | 
 96 |             # The last word takes everything until the boundary
 97 |             adjacent_word = words_inside_segment[-1]
 98 |             sequence_timestamps.append(
 99 |                 TimestampedText(text=adjacent_word, timestamp=_tstmp(start, end))
100 |             )
101 | 
102 |     (segment_boundaries,) = torch.where(text_tokens == end_of_padding_id)
103 | 
104 |     if not segment_boundaries.numel():
105 |         return []
106 | 
107 |     for i in range(len(segment_boundaries) - 1):
108 |         segment_start = int(segment_boundaries[i]) + 1
109 |         segment_end = int(segment_boundaries[i + 1])
110 | 
111 |         _decode_segment(segment_start, segment_end)
112 | 
113 |     last_segment_start = segment_boundaries[-1] + 1
114 | 
115 |     boundary_token = torch.tensor([tokenizer.eos_id()])
116 |     (end_of_last_segment,) = torch.where(
117 |         torch.isin(text_tokens[last_segment_start:], boundary_token)
118 |     )
119 | 
120 |     if not end_of_last_segment.numel():
121 |         # upper bound: either the end of the audio or a 1-second duration, whichever is smaller
122 |         last_segment_end = min(text_tokens.shape[-1], last_segment_start + frame_rate)
123 |     else:
124 |         last_segment_end = last_segment_start + end_of_last_segment[0]
125 |     _decode_segment(last_segment_start, last_segment_end)
126 | 
127 |     return sequence_timestamps
128 | 
129 | 
130 | def main(args):
131 |     info = moshi.models.loaders.CheckpointInfo.from_hf_repo(
132 |         args.hf_repo,
133 |         moshi_weights=args.moshi_weight,
134 |         mimi_weights=args.mimi_weight,
135 |         tokenizer=args.tokenizer,
136 |         config_path=args.config_path,
137 |     )
138 | 
139 |     mimi = info.get_mimi(device=args.device)
140 |     tokenizer = info.get_text_tokenizer()
141 |     lm = info.get_moshi(
142 |         device=args.device,
143 |         dtype=torch.bfloat16,
144 |     )
145 |     lm_gen = moshi.models.LMGen(lm, temp=0, temp_text=0.0)
146 | 
147 |     audio_silence_prefix_seconds = info.stt_config.get(
148 |         "audio_silence_prefix_seconds", 1.0
149 |     )
150 |     audio_delay_seconds = info.stt_config.get("audio_delay_seconds", 5.0)
151 |     padding_token_id = info.raw_config.get("text_padding_token_id", 3)
152 | 
153 |     audio, input_sample_rate = sphn.read(args.in_file)
154 |     audio = torch.from_numpy(audio).to(args.device)
155 |     audio = julius.resample_frac(audio, input_sample_rate, mimi.sample_rate)
156 |     if audio.shape[-1] % mimi.frame_size != 0:
157 |         to_pad = mimi.frame_size - audio.shape[-1] % mimi.frame_size
158 |         audio = torch.nn.functional.pad(audio, (0, to_pad))
159 | 
160 |     text_tokens_accum = []
161 | 
162 |     n_prefix_chunks = math.ceil(audio_silence_prefix_seconds * mimi.frame_rate)
163 |     n_suffix_chunks = math.ceil(audio_delay_seconds * mimi.frame_rate)
164 |     silence_chunk = torch.zeros(
165 |         (1, 1, mimi.frame_size), dtype=torch.float32, device=args.device
166 |     )
167 | 
168 |     chunks = itertools.chain(
169 |         itertools.repeat(silence_chunk, n_prefix_chunks),
170 |         torch.split(audio[:, None], mimi.frame_size, dim=-1),
171 |         itertools.repeat(silence_chunk, n_suffix_chunks),
172 |     )
173 | 
174 |     start_time = time.time()
175 |     nchunks = 0
176 |     last_print_was_vad = False
177 |     with mimi.streaming(1), lm_gen.streaming(1):
178 |         for audio_chunk in chunks:
179 |             nchunks += 1
180 |             audio_tokens = mimi.encode(audio_chunk)
181 |             if args.vad:
182 |                 text_tokens, vad_heads = lm_gen.step_with_extra_heads(audio_tokens)
183 |                 if vad_heads:
184 |                     pr_vad = vad_heads[2][0, 0, 0].cpu().item()
185 |                     if pr_vad > 0.5 and not last_print_was_vad:
186 |                         print(" [end of turn detected]")
187 |                         last_print_was_vad = True
188 |             else:
189 |                 text_tokens = lm_gen.step(audio_tokens)
190 |             text_token = text_tokens[0, 0, 0].cpu().item()
191 |             if text_token not in (0, 3):
192 |                 _text = tokenizer.id_to_piece(text_tokens[0, 0, 0].cpu().item())  # type: ignore
193 |                 _text = _text.replace("▁", " ")
194 |                 print(_text, end="", flush=True)
195 |                 last_print_was_vad = False
196 |             text_tokens_accum.append(text_tokens)
197 | 
198 |     utterance_tokens = torch.concat(text_tokens_accum, dim=-1)
199 |     dt = time.time() - start_time
200 |     print(
201 |         f"\nprocessed {nchunks} chunks in {dt:.2f} seconds, steps per second: {nchunks / dt:.2f}"
202 |     )
203 |     timed_text = tokens_to_timestamped_text(
204 |         utterance_tokens,
205 |         tokenizer,
206 |         mimi.frame_rate,
207 |         end_of_padding_id=0,
208 |         padding_token_id=padding_token_id,
209 |         offset_seconds=int(n_prefix_chunks / mimi.frame_rate) + audio_delay_seconds,
210 |     )
211 | 
212 |     decoded = " ".join([str(t) for t in timed_text])
213 |     print(decoded)
214 | 
215 | 
216 | if __name__ == "__main__":
217 |     parser = argparse.ArgumentParser(description="Example streaming STT w/ timestamps.")
218 |     parser.add_argument("in_file", help="The file to transcribe.")
219 | 
220 |     parser.add_argument(
221 |         "--hf-repo", type=str, help="HF repo to load the STT model from. "
222 |     )
223 |     parser.add_argument("--tokenizer", type=str, help="Path to a local tokenizer file.")
224 |     parser.add_argument(
225 |         "--moshi-weight", type=str, help="Path to a local checkpoint file."
226 |     )
227 |     parser.add_argument(
228 |         "--mimi-weight", type=str, help="Path to a local checkpoint file for Mimi."
229 |     )
230 |     parser.add_argument(
231 |         "--config-path", type=str, help="Path to a local config file.", default=None
232 |     )
233 |     parser.add_argument(
234 |         "--vad", action="store_true", help="Enable VAD (Voice Activity Detection)."
235 |     )
236 |     parser.add_argument(
237 |         "--device",
238 |         type=str,
239 |         default="cuda",
240 |         help="Device on which to run, defaults to 'cuda'.",
241 |     )
242 |     args = parser.parse_args()
243 | 
244 |     main(args)
245 | 


--------------------------------------------------------------------------------
/scripts/stt_from_file_rust_server.py:
--------------------------------------------------------------------------------
  1 | # /// script
  2 | # requires-python = ">=3.12"
  3 | # dependencies = [
  4 | #     "msgpack",
  5 | #     "numpy",
  6 | #     "sphn",
  7 | #     "websockets",
  8 | # ]
  9 | # ///
 10 | import argparse
 11 | import asyncio
 12 | import time
 13 | 
 14 | import msgpack
 15 | import numpy as np
 16 | import sphn
 17 | import websockets
 18 | 
 19 | SAMPLE_RATE = 24000
 20 | FRAME_SIZE = 1920  # Send data in chunks
 21 | 
 22 | 
 23 | def load_and_process_audio(file_path):
 24 |     """Load an audio file, resample it to 24 kHz, and return the first channel as float32 PCM data."""
 25 |     pcm_data, _ = sphn.read(file_path, sample_rate=SAMPLE_RATE)
 26 |     return pcm_data[0]
 27 | 
 28 | 
 29 | async def receive_messages(websocket):
 30 |     transcript = []
 31 | 
 32 |     async for message in websocket:
 33 |         data = msgpack.unpackb(message, raw=False)
 34 |         if data["type"] == "Step":
 35 |             # This message contains the signal from the semantic VAD, and tells us how
 36 |             # much audio the server has already processed. We don't use either here.
 37 |             continue
 38 |         if data["type"] == "Word":
 39 |             print(data["text"], end=" ", flush=True)
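    |             # The end timestamp is initialized to the start time; it gets
    |             # filled in when the matching "EndWord" message arrives below.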
 40 |             transcript.append(
 41 |                 {
 42 |                     "text": data["text"],
 43 |                     "timestamp": [data["start_time"], data["start_time"]],
 44 |                 }
 45 |             )
 46 |         if data["type"] == "EndWord":
 47 |             if len(transcript) > 0:
 48 |                 transcript[-1]["timestamp"][1] = data["stop_time"]
 49 |         if data["type"] == "Marker":
 50 |             # Received marker, stopping stream
 51 |             break
 52 | 
 53 |     return transcript
 54 | 
 55 | 
 56 | async def send_messages(websocket, rtf: float):
 57 |     audio_data = load_and_process_audio(args.in_file)
 58 | 
 59 |     async def send_audio(audio: np.ndarray):
 60 |         await websocket.send(
 61 |             msgpack.packb(
 62 |                 {"type": "Audio", "pcm": [float(x) for x in audio]},
 63 |                 use_single_float=True,
 64 |             )
 65 |         )
 66 | 
 67 |     # Start with a second of silence.
 68 |     # This is needed for the 2.6B model for technical reasons.
 69 |     await send_audio([0.0] * SAMPLE_RATE)
 70 | 
 71 |     start_time = time.time()
 72 |     for i in range(0, len(audio_data), FRAME_SIZE):
 73 |         await send_audio(audio_data[i : i + FRAME_SIZE])
 74 | 
 75 |         expected_send_time = start_time + (i + 1) / SAMPLE_RATE / rtf
 76 |         current_time = time.time()
 77 |         if current_time < expected_send_time:
 78 |             await asyncio.sleep(expected_send_time - current_time)
 79 |         else:
 80 |             await asyncio.sleep(0.001)
 81 | 
 82 |     for _ in range(5):
 83 |         await send_audio([0.0] * SAMPLE_RATE)
 84 | 
 85 |     # Send a marker to indicate the end of the stream.
 86 |     await websocket.send(
 87 |         msgpack.packb({"type": "Marker", "id": 0}, use_single_float=True)
 88 |     )
 89 | 
 90 |     # We'll get back the marker once the corresponding audio has been transcribed,
 91 |     # accounting for the delay of the model. That's why we need to send some silence
 92 |     # after the marker, because the model will not return the marker immediately.
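    |     # Here that means 35 seconds of silence in total, comfortably longer than the model's delay.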
 93 |     for _ in range(35):
 94 |         await send_audio([0.0] * SAMPLE_RATE)
 95 | 
 96 | 
 97 | async def stream_audio(url: str, api_key: str, rtf: float):
 98 |     """Stream audio data to a WebSocket server."""
 99 |     headers = {"kyutai-api-key": api_key}
100 | 
101 |     # Instead of using the header, you can authenticate by adding `?auth_id={api_key}` to the URL
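    |     # e.g. ws://127.0.0.1:8080/api/asr-streaming?auth_id=public_token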
102 |     async with websockets.connect(url, additional_headers=headers) as websocket:
103 |         send_task = asyncio.create_task(send_messages(websocket, rtf))
104 |         receive_task = asyncio.create_task(receive_messages(websocket))
105 |         _, transcript = await asyncio.gather(send_task, receive_task)
106 | 
107 |     return transcript
108 | 
109 | 
110 | if __name__ == "__main__":
111 |     parser = argparse.ArgumentParser()
112 |     parser.add_argument("in_file")
113 |     parser.add_argument(
114 |         "--url",
115 |         help="The url of the server to which to send the audio",
116 |         default="ws://127.0.0.1:8080",
117 |     )
118 |     parser.add_argument("--api-key", default="public_token")
119 |     parser.add_argument(
120 |         "--rtf",
121 |         type=float,
122 |         default=1.01,
123 |         help="The real-time factor of how fast to feed in the audio.",
124 |     )
125 |     args = parser.parse_args()
126 | 
127 |     url = f"{args.url}/api/asr-streaming"
128 |     transcript = asyncio.run(stream_audio(url, args.api_key, args.rtf))
129 | 
130 |     print()
131 |     print()
132 |     for word in transcript:
133 |         print(
134 |             f"{word['timestamp'][0]:7.2f} -{word['timestamp'][1]:7.2f}  {word['text']}"
135 |         )
136 | 


--------------------------------------------------------------------------------
/scripts/stt_from_file_with_prompt_pytorch.py:
--------------------------------------------------------------------------------
  1 | """An example script that illustrates how one can prompt Kyutai STT models."""
  2 | 
  3 | import argparse
  4 | import itertools
  5 | import math
  6 | from collections import deque
  7 | 
  8 | import julius
  9 | import moshi.models
 10 | import sphn
 11 | import torch
 12 | import tqdm
 13 | 
 14 | 
 15 | class PromptHook:
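    |     """Force the decoded transcript to start with a given text prefix.
    |
    |     `on_logits` masks the text logits so that only padding tokens or the next
    |     expected prefix token can be sampled, and `on_token` advances through the
    |     encoded prefix as its tokens are emitted.
    |     """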
 16 |     def __init__(self, tokenizer, prefix, padding_tokens=(0, 3)):
 17 |         self.tokenizer = tokenizer
 18 |         self.prefix_enforce = deque(self.tokenizer.encode(prefix))
 19 |         self.padding_tokens = padding_tokens
 20 | 
 21 |     def on_token(self, token):
 22 |         if not self.prefix_enforce:
 23 |             return
 24 | 
 25 |         token = token.item()
 26 | 
 27 |         if token in self.padding_tokens:
 28 |             pass
 29 |         elif token == self.prefix_enforce[0]:
 30 |             self.prefix_enforce.popleft()
 31 |         else:
 32 |             assert False, f"unexpected token {token} while enforcing the prompt prefix"
 33 | 
 34 |     def on_logits(self, logits):
 35 |         if not self.prefix_enforce:
 36 |             return
 37 | 
 38 |         mask = torch.zeros_like(logits, dtype=torch.bool)
 39 |         for t in self.padding_tokens:
 40 |             mask[..., t] = True
 41 |         mask[..., self.prefix_enforce[0]] = True
 42 | 
 43 |         logits[:] = torch.where(mask, logits, float("-inf"))
 44 | 
 45 | 
 46 | def main(args):
 47 |     info = moshi.models.loaders.CheckpointInfo.from_hf_repo(
 48 |         args.hf_repo,
 49 |         moshi_weights=args.moshi_weight,
 50 |         mimi_weights=args.mimi_weight,
 51 |         tokenizer=args.tokenizer,
 52 |         config_path=args.config_path,
 53 |     )
 54 | 
 55 |     mimi = info.get_mimi(device=args.device)
 56 |     tokenizer = info.get_text_tokenizer()
 57 |     lm = info.get_moshi(
 58 |         device=args.device,
 59 |         dtype=torch.bfloat16,
 60 |     )
 61 | 
 62 |     if args.prompt_text:
 63 |         prompt_hook = PromptHook(tokenizer, args.prompt_text)
 64 |         lm_gen = moshi.models.LMGen(
 65 |             lm,
 66 |             temp=0,
 67 |             temp_text=0.0,
 68 |             on_text_hook=prompt_hook.on_token,
 69 |             on_text_logits_hook=prompt_hook.on_logits,
 70 |         )
 71 |     else:
 72 |         lm_gen = moshi.models.LMGen(lm, temp=0, temp_text=0.0)
 73 | 
 74 |     audio_silence_prefix_seconds = info.stt_config.get(
 75 |         "audio_silence_prefix_seconds", 1.0
 76 |     )
 77 |     audio_delay_seconds = info.stt_config.get("audio_delay_seconds", 5.0)
 78 |     padding_token_id = info.raw_config.get("text_padding_token_id", 3)
 79 | 
 80 |     def _load_and_process(path):
 81 |         audio, input_sample_rate = sphn.read(path)
 82 |         audio = torch.from_numpy(audio).to(args.device).mean(axis=0, keepdim=True)
 83 |         audio = julius.resample_frac(audio, input_sample_rate, mimi.sample_rate)
 84 |         if audio.shape[-1] % mimi.frame_size != 0:
 85 |             to_pad = mimi.frame_size - audio.shape[-1] % mimi.frame_size
 86 |             audio = torch.nn.functional.pad(audio, (0, to_pad))
 87 |         return audio
 88 | 
 89 |     n_prefix_chunks = math.ceil(audio_silence_prefix_seconds * mimi.frame_rate)
 90 |     n_suffix_chunks = math.ceil(audio_delay_seconds * mimi.frame_rate)
 91 |     silence_chunk = torch.zeros(
 92 |         (1, 1, mimi.frame_size), dtype=torch.float32, device=args.device
 93 |     )
 94 | 
 95 |     audio = _load_and_process(args.file)
 96 |     if args.prompt_file:
 97 |         audio_prompt = _load_and_process(args.prompt_file)
 98 |     else:
 99 |         audio_prompt = None
100 | 
101 |     chain = [itertools.repeat(silence_chunk, n_prefix_chunks)]
102 | 
103 |     if audio_prompt is not None:
104 |         chain.append(torch.split(audio_prompt[:, None, :], mimi.frame_size, dim=-1))
105 |         # adding a bit (0.8s) of silence to separate prompt and the actual audio
106 |         chain.append(itertools.repeat(silence_chunk, 10))
107 | 
108 |     chain += [
109 |         torch.split(audio[:, None, :], mimi.frame_size, dim=-1),
110 |         itertools.repeat(silence_chunk, n_suffix_chunks),
111 |     ]
112 | 
113 |     chunks = itertools.chain(*chain)
114 | 
115 |     text_tokens_accum = []
116 |     with mimi.streaming(1), lm_gen.streaming(1):
117 |         for audio_chunk in tqdm.tqdm(chunks):
118 |             audio_tokens = mimi.encode(audio_chunk)
119 |             text_tokens = lm_gen.step(audio_tokens)
120 |             if text_tokens is not None:
121 |                 text_tokens_accum.append(text_tokens)
122 | 
123 |     utterance_tokens = torch.concat(text_tokens_accum, dim=-1)
124 |     text_tokens = utterance_tokens.cpu().view(-1)
125 | 
126 |     # If we have an audio prompt and don't want it in the transcript,
127 |     # we should cut the corresponding number of frames from the output tokens.
128 |     # However, some padding also happens before the prompt due to the
129 |     # silence prefix and audio delay. Normally that padding is ignored during
130 |     # detokenization, but here we must account for it to locate where the prompt's transcript ends.
131 |     if args.cut_prompt_transcript and audio_prompt is not None:
132 |         prompt_frames = audio_prompt.shape[1] // mimi.frame_size
133 |         no_prompt_offset_seconds = audio_delay_seconds + audio_silence_prefix_seconds
134 |         no_prompt_offset = int(no_prompt_offset_seconds * mimi.frame_rate)
135 |         text_tokens = text_tokens[prompt_frames + no_prompt_offset :]
136 | 
137 |     text = tokenizer.decode(
138 |         text_tokens[text_tokens > padding_token_id].numpy().tolist()
139 |     )
140 | 
141 |     print(text)
142 | 
143 | 
144 | if __name__ == "__main__":
145 |     parser = argparse.ArgumentParser(description="Example streaming STT w/ a prompt.")
146 |     parser.add_argument(
147 |         "--file",
148 |         required=True,
149 |         help="File to transcribe.",
150 |     )
151 |     parser.add_argument(
152 |         "--prompt_file",
153 |         required=False,
154 |         help="Audio of the prompt.",
155 |     )
156 |     parser.add_argument(
157 |         "--prompt_text",
158 |         required=False,
159 |         help="Text of the prompt.",
160 |     )
161 |     parser.add_argument(
162 |         "--cut-prompt-transcript",
163 |         action="store_true",
164 |         help="Cut the prompt from the output transcript",
165 |     )
166 |     parser.add_argument(
167 |         "--hf-repo", type=str, help="HF repo to load the STT model from. "
168 |     )
169 |     parser.add_argument("--tokenizer", type=str, help="Path to a local tokenizer file.")
170 |     parser.add_argument(
171 |         "--moshi-weight", type=str, help="Path to a local checkpoint file."
172 |     )
173 |     parser.add_argument(
174 |         "--mimi-weight", type=str, help="Path to a local checkpoint file for Mimi."
175 |     )
176 |     parser.add_argument(
177 |         "--config-path", type=str, help="Path to a local config file.", default=None
178 |     )
179 |     parser.add_argument(
180 |         "--device",
181 |         type=str,
182 |         default="cuda",
183 |         help="Device on which to run, defaults to 'cuda'.",
184 |     )
185 |     args = parser.parse_args()
186 | 
187 |     main(args)
188 | 


--------------------------------------------------------------------------------
/scripts/stt_from_mic_mlx.py:
--------------------------------------------------------------------------------
  1 | # /// script
  2 | # requires-python = ">=3.12"
  3 | # dependencies = [
  4 | #     "huggingface_hub",
  5 | #     "moshi_mlx==0.2.10",
  6 | #     "numpy",
  7 | #     "rustymimi",
  8 | #     "sentencepiece",
  9 | #     "sounddevice",
 10 | # ]
 11 | # ///
 12 | 
 13 | import argparse
 14 | import json
 15 | import queue
 16 | 
 17 | import mlx.core as mx
 18 | import mlx.nn as nn
 19 | import rustymimi
 20 | import sentencepiece
 21 | import sounddevice as sd
 22 | from huggingface_hub import hf_hub_download
 23 | from moshi_mlx import models, utils
 24 | 
 25 | if __name__ == "__main__":
 26 |     parser = argparse.ArgumentParser()
  27 |     parser.add_argument("--max-steps", default=4096, type=int)
 28 |     parser.add_argument("--hf-repo", default="kyutai/stt-1b-en_fr-mlx")
 29 |     parser.add_argument(
 30 |         "--vad", action="store_true", help="Enable VAD (Voice Activity Detection)."
 31 |     )
 32 |     args = parser.parse_args()
 33 | 
 34 |     lm_config = hf_hub_download(args.hf_repo, "config.json")
 35 |     with open(lm_config, "r") as fobj:
 36 |         lm_config = json.load(fobj)
 37 |     mimi_weights = hf_hub_download(args.hf_repo, lm_config["mimi_name"])
 38 |     moshi_name = lm_config.get("moshi_name", "model.safetensors")
 39 |     moshi_weights = hf_hub_download(args.hf_repo, moshi_name)
 40 |     tokenizer = hf_hub_download(args.hf_repo, lm_config["tokenizer_name"])
 41 | 
 42 |     lm_config = models.LmConfig.from_config_dict(lm_config)
 43 |     model = models.Lm(lm_config)
 44 |     model.set_dtype(mx.bfloat16)
 45 |     if moshi_weights.endswith(".q4.safetensors"):
 46 |         nn.quantize(model, bits=4, group_size=32)
 47 |     elif moshi_weights.endswith(".q8.safetensors"):
 48 |         nn.quantize(model, bits=8, group_size=64)
 49 | 
 50 |     print(f"loading model weights from {moshi_weights}")
 51 |     if args.hf_repo.endswith("-candle"):
 52 |         model.load_pytorch_weights(moshi_weights, lm_config, strict=True)
 53 |     else:
 54 |         model.load_weights(moshi_weights, strict=True)
 55 | 
 56 |     print(f"loading the text tokenizer from {tokenizer}")
 57 |     text_tokenizer = sentencepiece.SentencePieceProcessor(tokenizer)  # type: ignore
 58 | 
 59 |     print(f"loading the audio tokenizer {mimi_weights}")
 60 |     generated_codebooks = lm_config.generated_codebooks
 61 |     other_codebooks = lm_config.other_codebooks
 62 |     mimi_codebooks = max(generated_codebooks, other_codebooks)
 63 |     audio_tokenizer = rustymimi.Tokenizer(mimi_weights, num_codebooks=mimi_codebooks)  # type: ignore
 64 |     print("warming up the model")
 65 |     model.warmup()
 66 |     gen = models.LmGen(
 67 |         model=model,
 68 |         max_steps=args.max_steps,
 69 |         text_sampler=utils.Sampler(top_k=25, temp=0),
 70 |         audio_sampler=utils.Sampler(top_k=250, temp=0.8),
 71 |         check=False,
 72 |     )
 73 | 
 74 |     block_queue = queue.Queue()
 75 | 
 76 |     def audio_callback(indata, _frames, _time, _status):
 77 |         block_queue.put(indata.copy())
 78 | 
 79 |     print("recording audio from microphone, speak to get your words transcribed")
 80 |     last_print_was_vad = False
 81 |     with sd.InputStream(
 82 |         channels=1,
 83 |         dtype="float32",
 84 |         samplerate=24000,
 85 |         blocksize=1920,
 86 |         callback=audio_callback,
 87 |     ):
 88 |         while True:
 89 |             block = block_queue.get()
 90 |             block = block[None, :, 0]
 91 |             other_audio_tokens = audio_tokenizer.encode_step(block[None, 0:1])
 92 |             other_audio_tokens = mx.array(other_audio_tokens).transpose(0, 2, 1)[
 93 |                 :, :, :other_codebooks
 94 |             ]
 95 |             if args.vad:
 96 |                 text_token, vad_heads = gen.step_with_extra_heads(other_audio_tokens[0])
 97 |                 if vad_heads:
 98 |                     pr_vad = vad_heads[2][0, 0, 0].item()
 99 |                     if pr_vad > 0.5 and not last_print_was_vad:
100 |                         print(" [end of turn detected]")
101 |                         last_print_was_vad = True
102 |             else:
103 |                 text_token = gen.step(other_audio_tokens[0])
104 |             text_token = text_token[0].item()
105 |             audio_tokens = gen.last_audio_tokens()
106 |             _text = None
107 |             if text_token not in (0, 3):
108 |                 _text = text_tokenizer.id_to_piece(text_token)  # type: ignore
109 |                 _text = _text.replace("▁", " ")
110 |                 print(_text, end="", flush=True)
111 |                 last_print_was_vad = False
112 | 


--------------------------------------------------------------------------------
/scripts/stt_from_mic_rust_server.py:
--------------------------------------------------------------------------------
  1 | # /// script
  2 | # requires-python = ">=3.12"
  3 | # dependencies = [
  4 | #     "msgpack",
  5 | #     "numpy",
  6 | #     "sounddevice",
  7 | #     "websockets",
  8 | # ]
  9 | # ///
 10 | import argparse
 11 | import asyncio
 12 | import signal
 13 | 
 14 | import msgpack
 15 | import numpy as np
 16 | import sounddevice as sd
 17 | import websockets
 18 | 
 19 | SAMPLE_RATE = 24000
 20 | 
 21 | # The VAD has several prediction heads, each of which tries to determine whether there
 22 | # has been a pause of a given length. The lengths are 0.5, 1.0, 2.0, and 3.0 seconds.
 23 | # Lower indices predict pauses more aggressively. In Unmute, we use 2.0 seconds = index 2.
 24 | PAUSE_PREDICTION_HEAD_INDEX = 2
 25 | 
 26 | 
 27 | async def receive_messages(websocket, show_vad: bool = False):
 28 |     """Receive and process messages from the WebSocket server."""
 29 |     try:
 30 |         speech_started = False
 31 |         async for message in websocket:
 32 |             data = msgpack.unpackb(message, raw=False)
 33 | 
 34 |             # The Step message only gets sent if the model has semantic VAD available
 35 |             if data["type"] == "Step" and show_vad:
 36 |                 pause_prediction = data["prs"][PAUSE_PREDICTION_HEAD_INDEX]
 37 |                 if pause_prediction > 0.5 and speech_started:
 38 |                     print("| ", end="", flush=True)
 39 |                     speech_started = False
 40 | 
 41 |             elif data["type"] == "Word":
 42 |                 print(data["text"], end=" ", flush=True)
 43 |                 speech_started = True
 44 |     except websockets.ConnectionClosed:
 45 |         print("Connection closed while receiving messages.")
 46 | 
 47 | 
 48 | async def send_messages(websocket, audio_queue):
 49 |     """Send audio data from microphone to WebSocket server."""
 50 |     try:
 51 |         # Start by draining the queue to avoid lags
 52 |         while not audio_queue.empty():
 53 |             await audio_queue.get()
 54 | 
 55 |         print("Starting the transcription")
 56 | 
 57 |         while True:
 58 |             audio_data = await audio_queue.get()
 59 |             chunk = {"type": "Audio", "pcm": [float(x) for x in audio_data]}
 60 |             msg = msgpack.packb(chunk, use_bin_type=True, use_single_float=True)
 61 |             await websocket.send(msg)
 62 | 
 63 |     except websockets.ConnectionClosed:
 64 |         print("Connection closed while sending messages.")
 65 | 
 66 | 
 67 | async def stream_audio(url: str, api_key: str, show_vad: bool):
 68 |     """Stream audio data to a WebSocket server."""
 69 |     print("Starting microphone recording...")
 70 |     print("Press Ctrl+C to stop recording")
 71 |     audio_queue = asyncio.Queue()
 72 | 
 73 |     loop = asyncio.get_event_loop()
 74 | 
 75 |     def audio_callback(indata, frames, time, status):
 76 |         loop.call_soon_threadsafe(
 77 |             audio_queue.put_nowait, indata[:, 0].astype(np.float32).copy()
 78 |         )
 79 | 
 80 |     # Start audio stream
 81 |     with sd.InputStream(
 82 |         samplerate=SAMPLE_RATE,
 83 |         channels=1,
 84 |         dtype="float32",
 85 |         callback=audio_callback,
 86 |         blocksize=1920,  # 80ms blocks
 87 |     ):
 88 |         headers = {"kyutai-api-key": api_key}
 89 |         # Instead of using the header, you can authenticate by adding `?auth_id={api_key}` to the URL
 90 |         async with websockets.connect(url, additional_headers=headers) as websocket:
 91 |             send_task = asyncio.create_task(send_messages(websocket, audio_queue))
 92 |             receive_task = asyncio.create_task(
 93 |                 receive_messages(websocket, show_vad=show_vad)
 94 |             )
 95 |             await asyncio.gather(send_task, receive_task)
 96 | 
 97 | 
 98 | if __name__ == "__main__":
 99 |     parser = argparse.ArgumentParser(description="Real-time microphone transcription")
100 |     parser.add_argument(
101 |         "--url",
102 |         help="The URL of the server to which to send the audio",
103 |         default="ws://127.0.0.1:8080",
104 |     )
105 |     parser.add_argument("--api-key", default="public_token")
106 |     parser.add_argument(
107 |         "--list-devices", action="store_true", help="List available audio devices"
108 |     )
109 |     parser.add_argument(
110 |         "--device", type=int, help="Input device ID (use --list-devices to see options)"
111 |     )
112 |     parser.add_argument(
113 |         "--show-vad",
114 |         action="store_true",
115 |         help="Visualize the predictions of the semantic voice activity detector with a '|' symbol",
116 |     )
117 | 
118 |     args = parser.parse_args()
119 | 
120 |     def handle_sigint(signum, frame):
121 |         print("Interrupted by user")  # Don't complain about KeyboardInterrupt
122 |         exit(0)
123 | 
124 |     signal.signal(signal.SIGINT, handle_sigint)
125 | 
126 |     if args.list_devices:
127 |         print("Available audio devices:")
128 |         print(sd.query_devices())
129 |         exit(0)
130 | 
131 |     if args.device is not None:
132 |         sd.default.device[0] = args.device  # Set input device
133 | 
134 |     url = f"{args.url}/api/asr-streaming"
135 |     asyncio.run(stream_audio(url, args.api_key, args.show_vad))
136 | 


--------------------------------------------------------------------------------
/scripts/tts_mlx.py:
--------------------------------------------------------------------------------
  1 | # /// script
  2 | # requires-python = ">=3.12"
  3 | # dependencies = [
  4 | #     "huggingface_hub",
  5 | #     "moshi_mlx==0.2.9",
  6 | #     "numpy",
  7 | #     "sounddevice",
  8 | # ]
  9 | # ///
 10 | 
 11 | import argparse
 12 | import json
 13 | import queue
 14 | import sys
 15 | import time
 16 | 
 17 | import mlx.core as mx
 18 | import mlx.nn as nn
 19 | import numpy as np
 20 | import sentencepiece
 21 | import sounddevice as sd
 22 | import sphn
 23 | from moshi_mlx import models
 24 | from moshi_mlx.client_utils import make_log
 25 | from moshi_mlx.models.tts import (
 26 |     DEFAULT_DSM_TTS_REPO,
 27 |     DEFAULT_DSM_TTS_VOICE_REPO,
 28 |     TTSModel,
 29 | )
 30 | from moshi_mlx.utils.loaders import hf_get
 31 | 
 32 | 
 33 | def log(level: str, msg: str):
 34 |     print(make_log(level, msg))
 35 | 
 36 | 
 37 | def main():
 38 |     parser = argparse.ArgumentParser(
 39 |         description="Run Kyutai TTS using the MLX implementation"
 40 |     )
 41 |     parser.add_argument("inp", type=str, help="Input file, use - for stdin")
 42 |     parser.add_argument(
 43 |         "out", type=str, help="Output file to generate, use - for playing the audio"
 44 |     )
 45 |     parser.add_argument(
 46 |         "--hf-repo",
 47 |         type=str,
 48 |         default=DEFAULT_DSM_TTS_REPO,
 49 |         help="HF repo in which to look for the pretrained models.",
 50 |     )
 51 |     parser.add_argument(
 52 |         "--voice-repo",
 53 |         default=DEFAULT_DSM_TTS_VOICE_REPO,
 54 |         help="HF repo in which to look for pre-computed voice embeddings.",
 55 |     )
 56 |     parser.add_argument(
 57 |         "--voice", default="expresso/ex03-ex01_happy_001_channel1_334s.wav"
 58 |     )
 59 |     parser.add_argument(
 60 |         "--quantize",
 61 |         type=int,
 62 |         help="The quantization to be applied, e.g. 8 for 8 bits.",
 63 |     )
 64 |     args = parser.parse_args()
 65 | 
 66 |     mx.random.seed(299792458)
 67 | 
 68 |     log("info", "retrieving checkpoints")
 69 | 
 70 |     raw_config = hf_get("config.json", args.hf_repo)
 71 |     with open(hf_get(raw_config), "r") as fobj:
 72 |         raw_config = json.load(fobj)
 73 | 
 74 |     mimi_weights = hf_get(raw_config["mimi_name"], args.hf_repo)
 75 |     moshi_name = raw_config.get("moshi_name", "model.safetensors")
 76 |     moshi_weights = hf_get(moshi_name, args.hf_repo)
 77 |     tokenizer = hf_get(raw_config["tokenizer_name"], args.hf_repo)
 78 |     lm_config = models.LmConfig.from_config_dict(raw_config)
 79 |     model = models.Lm(lm_config)
 80 |     model.set_dtype(mx.bfloat16)
 81 | 
 82 |     log("info", f"loading model weights from {moshi_weights}")
 83 |     model.load_pytorch_weights(str(moshi_weights), lm_config, strict=True)
 84 | 
 85 |     if args.quantize is not None:
 86 |         log("info", f"quantizing model to {args.quantize} bits")
 87 |         nn.quantize(model.depformer, bits=args.quantize)
 88 |         for layer in model.transformer.layers:
 89 |             nn.quantize(layer.self_attn, bits=args.quantize)
 90 |             nn.quantize(layer.gating, bits=args.quantize)
 91 | 
 92 |     log("info", f"loading the text tokenizer from {tokenizer}")
 93 |     text_tokenizer = sentencepiece.SentencePieceProcessor(str(tokenizer))  # type: ignore
 94 | 
 95 |     log("info", f"loading the audio tokenizer {mimi_weights}")
 96 |     generated_codebooks = lm_config.generated_codebooks
 97 |     audio_tokenizer = models.mimi.Mimi(models.mimi_202407(generated_codebooks))
 98 |     audio_tokenizer.load_pytorch_weights(str(mimi_weights), strict=True)
 99 | 
100 |     cfg_coef_conditioning = None
101 |     tts_model = TTSModel(
102 |         model,
103 |         audio_tokenizer,
104 |         text_tokenizer,
105 |         voice_repo=args.voice_repo,
106 |         temp=0.6,
107 |         cfg_coef=1,
108 |         max_padding=8,
109 |         initial_padding=2,
110 |         final_padding=2,
111 |         padding_bonus=0,
112 |         raw_config=raw_config,
113 |     )
114 |     if tts_model.valid_cfg_conditionings:
115 |         # Model was trained with CFG distillation.
116 |         cfg_coef_conditioning = tts_model.cfg_coef
117 |         tts_model.cfg_coef = 1.0
118 |         cfg_is_no_text = False
119 |         cfg_is_no_prefix = False
120 |     else:
121 |         cfg_is_no_text = True
122 |         cfg_is_no_prefix = True
123 |     mimi = tts_model.mimi
124 | 
125 |     log("info", f"reading input from {args.inp}")
126 |     if args.inp == "-":
127 |         if sys.stdin.isatty():  # Interactive
128 |             print("Enter text to synthesize (Ctrl+D to end input):")
129 |         text_to_tts = sys.stdin.read().strip()
130 |     else:
131 |         with open(args.inp, "r") as fobj:
132 |             text_to_tts = fobj.read().strip()
133 | 
134 |     all_entries = [tts_model.prepare_script([text_to_tts])]
135 |     if tts_model.multi_speaker:
136 |         voices = [tts_model.get_voice_path(args.voice)]
137 |     else:
138 |         voices = []
139 |     all_attributes = [
140 |         tts_model.make_condition_attributes(voices, cfg_coef_conditioning)
141 |     ]
142 | 
143 |     wav_frames = queue.Queue()
144 | 
145 |     def _on_frame(frame):
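    |         # Skip frames that still contain the -1 placeholder: audio tokens are
    |         # produced with a delay, so the first frames are not yet decodable.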
146 |         if (frame == -1).any():
147 |             return
148 |         _pcm = tts_model.mimi.decode_step(frame[:, :, None])
149 |         _pcm = np.array(mx.clip(_pcm[0, 0], -1, 1))
150 |         wav_frames.put_nowait(_pcm)
151 | 
152 |     def run():
153 |         log("info", "starting the inference loop")
154 |         begin = time.time()
155 |         result = tts_model.generate(
156 |             all_entries,
157 |             all_attributes,
158 |             cfg_is_no_prefix=cfg_is_no_prefix,
159 |             cfg_is_no_text=cfg_is_no_text,
160 |             on_frame=_on_frame,
161 |         )
162 |         frames = mx.concat(result.frames, axis=-1)
163 |         total_duration = frames.shape[0] * frames.shape[-1] / mimi.frame_rate
164 |         time_taken = time.time() - begin
165 |         total_speed = total_duration / time_taken
166 |         log("info", f"[LM] took {time_taken:.2f}s, total speed {total_speed:.2f}x")
167 |         return result
168 | 
169 |     if args.out == "-":
170 | 
171 |         def audio_callback(outdata, _a, _b, _c):
172 |             try:
173 |                 pcm_data = wav_frames.get(block=False)
174 |                 outdata[:, 0] = pcm_data
175 |             except queue.Empty:
176 |                 outdata[:] = 0
177 | 
178 |         with sd.OutputStream(
179 |             samplerate=mimi.sample_rate,
180 |             blocksize=1920,
181 |             channels=1,
182 |             callback=audio_callback,
183 |         ):
184 |             run()
185 |             time.sleep(3)
186 |             while True:
187 |                 if wav_frames.qsize() == 0:
188 |                     break
189 |                 time.sleep(1)
190 |     else:
191 |         run()
192 |         frames = []
193 |         while True:
194 |             try:
195 |                 frames.append(wav_frames.get_nowait())
196 |             except queue.Empty:
197 |                 break
198 |         wav = np.concat(frames, -1)
199 |         sphn.write_wav(args.out, wav, mimi.sample_rate)
200 | 
201 | 
202 | if __name__ == "__main__":
203 |     main()
204 | 


--------------------------------------------------------------------------------
/scripts/tts_pytorch.py:
--------------------------------------------------------------------------------
  1 | # /// script
  2 | # requires-python = ">=3.12"
  3 | # dependencies = [
  4 | #     "moshi==0.2.8",
  5 | #     "torch",
  6 | #     "sphn",
  7 | #     "sounddevice",
  8 | # ]
  9 | # ///
 10 | import argparse
 11 | import sys
 12 | 
 13 | import numpy as np
 14 | import queue
 15 | import sphn
 16 | import time
 17 | import torch
 18 | from moshi.models.loaders import CheckpointInfo
 19 | from moshi.models.tts import DEFAULT_DSM_TTS_REPO, DEFAULT_DSM_TTS_VOICE_REPO, TTSModel
 20 | 
 21 | 
 22 | def main():
 23 |     parser = argparse.ArgumentParser(
 24 |         description="Run Kyutai TTS using the PyTorch implementation"
 25 |     )
 26 |     parser.add_argument("inp", type=str, help="Input file, use - for stdin.")
 27 |     parser.add_argument(
 28 |         "out", type=str, help="Output file to generate, use - for playing the audio"
 29 |     )
 30 |     parser.add_argument(
 31 |         "--hf-repo",
 32 |         type=str,
 33 |         default=DEFAULT_DSM_TTS_REPO,
 34 |         help="HF repo in which to look for the pretrained models.",
 35 |     )
 36 |     parser.add_argument(
 37 |         "--voice-repo",
 38 |         default=DEFAULT_DSM_TTS_VOICE_REPO,
 39 |         help="HF repo in which to look for pre-computed voice embeddings.",
 40 |     )
 41 |     parser.add_argument(
 42 |         "--voice",
 43 |         default="expresso/ex03-ex01_happy_001_channel1_334s.wav",
 44 |         help="The voice to use, relative to the voice repo root. "
 45 |         f"See {DEFAULT_DSM_TTS_VOICE_REPO}",
 46 |     )
 47 |     parser.add_argument(
 48 |         "--device",
 49 |         type=str,
 50 |         default="cuda",
 51 |         help="Device on which to run, defaults to 'cuda'.",
 52 |     )
 53 |     args = parser.parse_args()
 54 | 
 55 |     print("Loading model...")
 56 |     checkpoint_info = CheckpointInfo.from_hf_repo(args.hf_repo)
 57 |     tts_model = TTSModel.from_checkpoint_info(
 58 |         checkpoint_info, n_q=32, temp=0.6, device=args.device
 59 |     )
 60 | 
 61 |     if args.inp == "-":
 62 |         if sys.stdin.isatty():  # Interactive
 63 |             print("Enter text to synthesize (Ctrl+D to end input):")
 64 |         text = sys.stdin.read().strip()
 65 |     else:
 66 |         with open(args.inp, "r") as fobj:
 67 |             text = fobj.read().strip()
 68 | 
 69 |     # If you want to make a dialog, you can pass more than one turn [text_speaker_1, text_speaker_2, text_2_speaker_1, ...]
 70 |     entries = tts_model.prepare_script([text], padding_between=1)
 71 |     voice_path = tts_model.get_voice_path(args.voice)
 72 |     # CFG coef goes here because the model was trained with CFG distillation,
 73 |     # so it's not _actually_ doing CFG at inference time.
 74 |     # Also, if you are generating a dialog, you should have two voices in the list.
 75 |     condition_attributes = tts_model.make_condition_attributes(
 76 |         [voice_path], cfg_coef=2.0
 77 |     )
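    |     # A hypothetical two-speaker dialog could look like this (the voice path
    |     # names below are illustrative, not files guaranteed to exist):
    |     #   entries = tts_model.prepare_script(
    |     #       [text_speaker_1, text_speaker_2, text_2_speaker_1], padding_between=1
    |     #   )
    |     #   condition_attributes = tts_model.make_condition_attributes(
    |     #       [voice_path_speaker_1, voice_path_speaker_2], cfg_coef=2.0
    |     #   )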
 78 | 
 79 |     if args.out == "-":
 80 |         # Stream the audio to the speakers using sounddevice.
 81 |         import sounddevice as sd
 82 | 
 83 |         pcms = queue.Queue()
 84 | 
 85 |         def _on_frame(frame):
 86 |             if (frame != -1).all():
 87 |                 pcm = tts_model.mimi.decode(frame[:, 1:, :]).cpu().numpy()
 88 |                 pcms.put_nowait(np.clip(pcm[0, 0], -1, 1))
 89 | 
 90 |         def audio_callback(outdata, _a, _b, _c):
 91 |             try:
 92 |                 pcm_data = pcms.get(block=False)
 93 |                 outdata[:, 0] = pcm_data
 94 |             except queue.Empty:
 95 |                 outdata[:] = 0
 96 | 
 97 |         with sd.OutputStream(
 98 |             samplerate=tts_model.mimi.sample_rate,
 99 |             blocksize=1920,
100 |             channels=1,
101 |             callback=audio_callback,
102 |         ):
103 |             with tts_model.mimi.streaming(1):
104 |                 tts_model.generate(
105 |                     [entries], [condition_attributes], on_frame=_on_frame
106 |                 )
107 |             time.sleep(3)
108 |             while True:
109 |                 if pcms.qsize() == 0:
110 |                     break
111 |                 time.sleep(1)
112 |     else:
113 |         result = tts_model.generate([entries], [condition_attributes])
114 |         with tts_model.mimi.streaming(1), torch.no_grad():
115 |             pcms = []
116 |             for frame in result.frames[tts_model.delay_steps :]:
117 |                 pcm = tts_model.mimi.decode(frame[:, 1:, :]).cpu().numpy()
118 |                 pcms.append(np.clip(pcm[0, 0], -1, 1))
119 |             pcm = np.concatenate(pcms, axis=-1)
120 |         sphn.write_wav(args.out, pcm, tts_model.mimi.sample_rate)
121 | 
122 | 
123 | if __name__ == "__main__":
124 |     main()
125 | 


--------------------------------------------------------------------------------
/scripts/tts_rust_server.py:
--------------------------------------------------------------------------------
  1 | # /// script
  2 | # requires-python = ">=3.12"
  3 | # dependencies = [
  4 | #     "msgpack",
  5 | #     "numpy",
  6 | #     "sphn",
  7 | #     "websockets",
  8 | #     "sounddevice",
  9 | #     "tqdm",
 10 | # ]
 11 | # ///
 12 | import argparse
 13 | import asyncio
 14 | import sys
 15 | from urllib.parse import urlencode
 16 | 
 17 | import msgpack
 18 | import numpy as np
 19 | import sounddevice as sd
 20 | import sphn
 21 | import tqdm
 22 | import websockets
 23 | 
 24 | SAMPLE_RATE = 24000
 25 | 
 26 | TTS_TEXT = "Hello, this is a test of the moshi text to speech system, this should result in a nice-sounding generated voice."
 27 | DEFAULT_DSM_TTS_VOICE_REPO = "kyutai/tts-voices"
 28 | AUTH_TOKEN = "public_token"
 29 | 
 30 | 
 31 | async def receive_messages(websocket: websockets.ClientConnection, output_queue):
 32 |     with tqdm.tqdm(desc="Receiving audio", unit=" seconds generated") as pbar:
 33 |         accumulated_samples = 0
 34 |         last_seconds = 0
 35 | 
 36 |         async for message_bytes in websocket:
 37 |             msg = msgpack.unpackb(message_bytes)
 38 | 
 39 |             if msg["type"] == "Audio":
 40 |                 pcm = np.array(msg["pcm"]).astype(np.float32)
 41 |                 await output_queue.put(pcm)
 42 | 
 43 |                 accumulated_samples += len(msg["pcm"])
 44 |                 current_seconds = accumulated_samples // SAMPLE_RATE
 45 |                 if current_seconds > last_seconds:
 46 |                     pbar.update(current_seconds - last_seconds)
 47 |                     last_seconds = current_seconds
 48 | 
 49 |     print("End of audio.")
 50 |     await output_queue.put(None)  # Signal end of audio
 51 | 
 52 | 
 53 | async def output_audio(out: str, output_queue: asyncio.Queue[np.ndarray | None]):
 54 |     if out == "-":
 55 |         should_exit = False
 56 | 
 57 |         def audio_callback(outdata, _a, _b, _c):
 58 |             nonlocal should_exit
 59 | 
 60 |             try:
 61 |                 pcm_data = output_queue.get_nowait()
 62 |                 if pcm_data is not None:
 63 |                     outdata[:, 0] = pcm_data
 64 |                 else:
 65 |                     should_exit = True
 66 |                     outdata[:] = 0
 67 |             except asyncio.QueueEmpty:
 68 |                 outdata[:] = 0
 69 | 
 70 |         with sd.OutputStream(
 71 |             samplerate=SAMPLE_RATE,
 72 |             blocksize=1920,
 73 |             channels=1,
 74 |             callback=audio_callback,
 75 |         ):
 76 |             while True:
 77 |                 if should_exit:
 78 |                     break
 79 |                 await asyncio.sleep(1)
 80 |     else:
 81 |         frames = []
 82 |         while True:
 83 |             item = await output_queue.get()
 84 |             if item is None:
 85 |                 break
 86 |             frames.append(item)
 87 | 
 88 |         sphn.write_wav(out, np.concatenate(frames, -1), SAMPLE_RATE)
 89 |         print(f"Saved audio to {out}")
 90 | 
 91 | 
 92 | async def websocket_client():
 93 |     parser = argparse.ArgumentParser(description="Use the TTS streaming API")
 94 |     parser.add_argument("inp", type=str, help="Input file, use - for stdin.")
 95 |     parser.add_argument(
 96 |         "out", type=str, help="Output file to generate, use - for playing the audio"
 97 |     )
 98 |     parser.add_argument(
 99 |         "--voice",
100 |         default="expresso/ex03-ex01_happy_001_channel1_334s.wav",
101 |         help="The voice to use, relative to the voice repo root. "
102 |         f"See {DEFAULT_DSM_TTS_VOICE_REPO}",
103 |     )
104 |     parser.add_argument(
105 |         "--url",
106 |         help="The URL of the server to which to send the audio",
107 |         default="ws://127.0.0.1:8080",
108 |     )
109 |     parser.add_argument("--api-key", default="public_token")
110 |     args = parser.parse_args()
111 | 
112 |     params = {"voice": args.voice, "format": "PcmMessagePack"}
113 |     uri = f"{args.url}/api/tts_streaming?{urlencode(params)}"
114 |     print(uri)
115 | 
116 |     # TODO: stream the text instead of sending it all at once
117 |     if args.inp == "-":
118 |         if sys.stdin.isatty():  # Interactive
119 |             print("Enter text to synthesize (Ctrl+D to end input):")
120 |         text_to_tts = sys.stdin.read().strip()
121 |     else:
122 |         with open(args.inp, "r") as fobj:
123 |             text_to_tts = fobj.read().strip()
124 | 
125 |     headers = {"kyutai-api-key": args.api_key}
126 | 
127 |     async with websockets.connect(uri, additional_headers=headers) as websocket:
128 |         await websocket.send(msgpack.packb({"type": "Text", "text": text_to_tts}))
129 |         await websocket.send(msgpack.packb({"type": "Eos"}))
130 | 
131 |         output_queue = asyncio.Queue()
132 |         receive_task = asyncio.create_task(receive_messages(websocket, output_queue))
133 |         output_audio_task = asyncio.create_task(output_audio(args.out, output_queue))
134 |         await asyncio.gather(receive_task, output_audio_task)
135 | 
136 | 
137 | if __name__ == "__main__":
138 |     asyncio.run(websocket_client())
139 | 
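
The TODO in `websocket_client` notes that the text could be streamed rather than sent in a single message. A rough sketch of what that might look like, assuming (unverified here) that the `tts_streaming` endpoint accepts several "Text" messages before the final "Eos":

import msgpack

async def send_text_in_chunks(websocket, text: str):
    # Hedged sketch: send one word per "Text" message, then signal the end of input.
    for word in text.split():
        await websocket.send(msgpack.packb({"type": "Text", "text": word + " "}))
    await websocket.send(msgpack.packb({"type": "Eos"}))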


--------------------------------------------------------------------------------
/stt-rs/Cargo.toml:
--------------------------------------------------------------------------------
 1 | [package]
 2 | name = "kyutai-stt-rs"
 3 | version = "0.1.0"
 4 | edition = "2024"
 5 | 
 6 | [dependencies]
 7 | anyhow = "1.0"
 8 | candle = { version = "0.9.1",  package = "candle-core" }
 9 | candle-nn = "0.9.1"
10 | clap = { version = "4.4.12", features = ["derive"] }
11 | hf-hub = "0.4.3"
12 | kaudio = "0.2.1"
13 | moshi = "0.6.1"
14 | sentencepiece = "0.11.3"
15 | serde = { version = "1.0.210", features = ["derive"] }
16 | serde_json = "1.0.115"
17 | 
18 | [features]
19 | default = []
20 | cuda = ["candle/cuda", "candle-nn/cuda"]
21 | cudnn = ["candle/cudnn", "candle-nn/cudnn"]
22 | metal = ["candle/metal", "candle-nn/metal"]
23 | 
24 | [profile.release]
25 | debug = true
26 | 
27 | [profile.release-no-debug]
28 | inherits = "release"
29 | debug = false
30 | 
31 | 
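
GPU support in this crate is opt-in at build time: the `cuda`, `cudnn`, and `metal` features above simply forward to the corresponding candle/candle-nn features, so a CUDA build would typically be `cargo build --release --features cuda`, and an Apple-silicon build `cargo build --release --features metal`.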


--------------------------------------------------------------------------------
/stt-rs/src/main.rs:
--------------------------------------------------------------------------------
  1 | // Copyright (c) Kyutai, all rights reserved.
  2 | // This source code is licensed under the license found in the
  3 | // LICENSE file in the root directory of this source tree.
  4 | 
  5 | use anyhow::Result;
  6 | use candle::{Device, Tensor};
  7 | use clap::Parser;
  8 | 
  9 | #[derive(Debug, Parser)]
 10 | struct Args {
 11 |     /// The audio input file, in wav/mp3/ogg/... format.
 12 |     in_file: String,
 13 | 
 14 |     /// The repo where to get the model from.
 15 |     #[arg(long, default_value = "kyutai/stt-1b-en_fr-candle")]
 16 |     hf_repo: String,
 17 | 
 18 |     /// Run the model on cpu.
 19 |     #[arg(long)]
 20 |     cpu: bool,
 21 | 
 22 |     /// Display word level timestamps.
 23 |     #[arg(long)]
 24 |     timestamps: bool,
 25 | 
 26 |     /// Display the level of voice activity detection (VAD).
 27 |     #[arg(long)]
 28 |     vad: bool,
 29 | }
 30 | 
 31 | fn device(cpu: bool) -> Result<Device> {
 32 |     if cpu {
 33 |         Ok(Device::Cpu)
 34 |     } else if candle::utils::cuda_is_available() {
 35 |         Ok(Device::new_cuda(0)?)
 36 |     } else if candle::utils::metal_is_available() {
 37 |         Ok(Device::new_metal(0)?)
 38 |     } else {
 39 |         Ok(Device::Cpu)
 40 |     }
 41 | }
 42 | 
 43 | #[derive(Debug, serde::Deserialize)]
 44 | struct SttConfig {
 45 |     audio_silence_prefix_seconds: f64,
 46 |     audio_delay_seconds: f64,
 47 | }
 48 | 
 49 | #[derive(Debug, serde::Deserialize)]
 50 | struct Config {
 51 |     mimi_name: String,
 52 |     tokenizer_name: String,
 53 |     card: usize,
 54 |     text_card: usize,
 55 |     dim: usize,
 56 |     n_q: usize,
 57 |     context: usize,
 58 |     max_period: f64,
 59 |     num_heads: usize,
 60 |     num_layers: usize,
 61 |     causal: bool,
 62 |     stt_config: SttConfig,
 63 | }
 64 | 
 65 | impl Config {
 66 |     fn model_config(&self, vad: bool) -> moshi::lm::Config {
 67 |         let lm_cfg = moshi::transformer::Config {
 68 |             d_model: self.dim,
 69 |             num_heads: self.num_heads,
 70 |             num_layers: self.num_layers,
 71 |             dim_feedforward: self.dim * 4,
 72 |             causal: self.causal,
 73 |             norm_first: true,
 74 |             bias_ff: false,
 75 |             bias_attn: false,
 76 |             layer_scale: None,
 77 |             context: self.context,
 78 |             max_period: self.max_period as usize,
 79 |             use_conv_block: false,
 80 |             use_conv_bias: true,
 81 |             cross_attention: None,
 82 |             gating: Some(candle_nn::Activation::Silu),
 83 |             norm: moshi::NormType::RmsNorm,
 84 |             positional_embedding: moshi::transformer::PositionalEmbedding::Rope,
 85 |             conv_layout: false,
 86 |             conv_kernel_size: 3,
 87 |             kv_repeat: 1,
 88 |             max_seq_len: 4096 * 4,
 89 |             shared_cross_attn: false,
 90 |         };
 91 |         let extra_heads = if vad {
 92 |             Some(moshi::lm::ExtraHeadsConfig {
 93 |                 num_heads: 4,
 94 |                 dim: 6,
 95 |             })
 96 |         } else {
 97 |             None
 98 |         };
 99 |         moshi::lm::Config {
100 |             transformer: lm_cfg,
101 |             depformer: None,
102 |             audio_vocab_size: self.card + 1,
103 |             text_in_vocab_size: self.text_card + 1,
104 |             text_out_vocab_size: self.text_card,
105 |             audio_codebooks: self.n_q,
106 |             conditioners: Default::default(),
107 |             extra_heads,
108 |         }
109 |     }
110 | }
111 | 
112 | struct Model {
113 |     state: moshi::asr::State,
114 |     text_tokenizer: sentencepiece::SentencePieceProcessor,
115 |     timestamps: bool,
116 |     vad: bool,
117 |     config: Config,
118 |     dev: Device,
119 | }
120 | 
121 | impl Model {
122 |     fn load_from_hf(args: &Args, dev: &Device) -> Result<Self> {
123 |         let dtype = dev.bf16_default_to_f32();
124 | 
125 |         // Retrieve the model files from the Hugging Face Hub
126 |         let api = hf_hub::api::sync::Api::new()?;
127 |         let repo = api.model(args.hf_repo.to_string());
128 |         let config_file = repo.get("config.json")?;
129 |         let config: Config = serde_json::from_str(&std::fs::read_to_string(&config_file)?)?;
130 |         let tokenizer_file = repo.get(&config.tokenizer_name)?;
131 |         let model_file = repo.get("model.safetensors")?;
132 |         let mimi_file = repo.get(&config.mimi_name)?;
133 | 
134 |         let text_tokenizer = sentencepiece::SentencePieceProcessor::open(&tokenizer_file)?;
135 |         let vb_lm =
136 |             unsafe { candle_nn::VarBuilder::from_mmaped_safetensors(&[&model_file], dtype, dev)? };
137 |         let audio_tokenizer = moshi::mimi::load(mimi_file.to_str().unwrap(), Some(32), dev)?;
138 |         let lm = moshi::lm::LmModel::new(
139 |             &config.model_config(args.vad),
140 |             moshi::nn::MaybeQuantizedVarBuilder::Real(vb_lm),
141 |         )?;
142 |         let asr_delay_in_tokens = (config.stt_config.audio_delay_seconds * 12.5) as usize;
143 |         let state = moshi::asr::State::new(1, asr_delay_in_tokens, 0., audio_tokenizer, lm)?;
144 |         Ok(Model {
145 |             state,
146 |             config,
147 |             text_tokenizer,
148 |             timestamps: args.timestamps,
149 |             vad: args.vad,
150 |             dev: dev.clone(),
151 |         })
152 |     }
153 | 
154 |     fn run(&mut self, mut pcm: Vec<f32>) -> Result<()> {
155 |         use std::io::Write;
156 | 
157 |         // Add the silence prefix to the audio.
158 |         if self.config.stt_config.audio_silence_prefix_seconds > 0.0 {
159 |             let silence_len =
160 |                 (self.config.stt_config.audio_silence_prefix_seconds * 24000.0) as usize;
161 |             pcm.splice(0..0, vec![0.0; silence_len]);
162 |         }
163 |         // Add some silence at the end to ensure all the audio is processed.
164 |         let suffix = (self.config.stt_config.audio_delay_seconds * 24000.0) as usize;
165 |         pcm.resize(pcm.len() + suffix + 24000, 0.0);
166 | 
167 |         let mut last_word = None;
168 |         let mut printed_eot = false;
169 |         for pcm in pcm.chunks(1920) {  // 1920 samples = one 80 ms frame at 24 kHz
170 |             let pcm = Tensor::new(pcm, &self.dev)?.reshape((1, 1, ()))?;
171 |             let asr_msgs = self.state.step_pcm(pcm, None, &().into(), |_, _, _| ())?;
172 |             for asr_msg in asr_msgs.iter() {
173 |                 match asr_msg {
174 |                     moshi::asr::AsrMsg::Step { prs, .. } => {
175 |                         // prs is the probability of having no voice activity for different time
176 |                         // horizons.
177 |                         // In kyutai/stt-1b-en_fr-candle, these horizons are 0.5s, 1s, 2s, and 3s.
178 |                         if self.vad && prs[2][0] > 0.5 && !printed_eot {
179 |                             printed_eot = true;
180 |                             if !self.timestamps {
181 |                                 print!(" <endofturn pr={}>", prs[2][0]);
182 |                             } else {
183 |                                 println!("<endofturn pr={}>", prs[2][0]);
184 |                             }
185 |                         }
186 |                     }
187 |                     moshi::asr::AsrMsg::EndWord { stop_time, .. } => {
188 |                         printed_eot = false;
189 |                         if self.timestamps {
190 |                             if let Some((word, start_time)) = last_word.take() {
191 |                                 println!("[{start_time:5.2}-{stop_time:5.2}] {word}");
192 |                             }
193 |                         }
194 |                     }
195 |                     moshi::asr::AsrMsg::Word {
196 |                         tokens, start_time, ..
197 |                     } => {
198 |                         printed_eot = false;
199 |                         let word = self
200 |                             .text_tokenizer
201 |                             .decode_piece_ids(tokens)
202 |                             .unwrap_or_else(|_| String::new());
203 |                         if !self.timestamps {
204 |                             print!(" {word}");
205 |                             std::io::stdout().flush()?
206 |                         } else {
207 |                             if let Some((word, prev_start_time)) = last_word.take() {
208 |                                 println!("[{prev_start_time:5.2}-{start_time:5.2}] {word}");
209 |                             }
210 |                             last_word = Some((word, *start_time));
211 |                         }
212 |                     }
213 |                 }
214 |             }
215 |         }
216 |         if let Some((word, start_time)) = last_word.take() {
217 |             println!("[{start_time:5.2}-     ] {word}");
218 |         }
219 |         println!();
220 |         Ok(())
221 |     }
222 | }
223 | 
224 | fn main() -> Result<()> {
225 |     let args = Args::parse();
226 |     let device = device(args.cpu)?;
227 |     println!("Using device: {:?}", device);
228 | 
229 |     println!("Loading audio file from: {}", args.in_file);
230 |     let (pcm, sample_rate) = kaudio::pcm_decode(&args.in_file)?;
231 |     let pcm = if sample_rate != 24_000 {
232 |         kaudio::resample(&pcm, sample_rate as usize, 24_000)?
233 |     } else {
234 |         pcm
235 |     };
236 |     println!("Loading model from repository: {}", args.hf_repo);
237 |     let mut model = Model::load_from_hf(&args, &device)?;
238 |     println!("Running inference");
239 |     model.run(pcm)?;
240 |     Ok(())
241 | }
242 | 
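
Given the `Args` struct above, a typical run might look like `cargo run --release -- path/to/audio.mp3 --timestamps --vad` (the audio path here is a placeholder). Without `--cpu`, the `device()` helper picks CUDA or Metal when available and otherwise falls back to the CPU.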


--------------------------------------------------------------------------------
/stt_pytorch.ipynb:
--------------------------------------------------------------------------------
  1 | {
  2 |  "cells": [
  3 |   {
  4 |    "cell_type": "code",
  5 |    "execution_count": null,
  6 |    "metadata": {
  7 |     "colab": {
  8 |      "base_uri": "https://localhost:8080/"
  9 |     },
 10 |     "id": "gJEMjPgeI-rw",
 11 |     "outputId": "7491c067-b1be-4505-b3f5-19ba4c00a593"
 12 |    },
 13 |    "outputs": [],
 14 |    "source": [
 15 |     "!pip install moshi"
 16 |    ]
 17 |   },
 18 |   {
 19 |    "cell_type": "code",
 20 |    "execution_count": null,
 21 |    "metadata": {
 22 |     "colab": {
 23 |      "base_uri": "https://localhost:8080/"
 24 |     },
 25 |     "id": "CA4K5iDFJcqJ",
 26 |     "outputId": "b609843a-a193-4729-b099-5f8780532333"
 27 |    },
 28 |    "outputs": [],
 29 |    "source": [
 30 |     "!wget https://github.com/kyutai-labs/moshi/raw/refs/heads/main/data/sample_fr_hibiki_crepes.mp3"
 31 |    ]
 32 |   },
 33 |   {
 34 |    "cell_type": "code",
 35 |    "execution_count": null,
 36 |    "metadata": {
 37 |     "id": "VA3Haix3IZ8Q"
 38 |    },
 39 |    "outputs": [],
 40 |    "source": [
 41 |     "from dataclasses import dataclass\n",
 42 |     "import time\n",
 43 |     "import sentencepiece\n",
 44 |     "import sphn\n",
 45 |     "import textwrap\n",
 46 |     "import torch\n",
 47 |     "\n",
 48 |     "from moshi.models import loaders, MimiModel, LMModel, LMGen"
 49 |    ]
 50 |   },
 51 |   {
 52 |    "cell_type": "code",
 53 |    "execution_count": null,
 54 |    "metadata": {
 55 |     "id": "9AK5zBMTI9bw"
 56 |    },
 57 |    "outputs": [],
 58 |    "source": [
 59 |     "@dataclass\n",
 60 |     "class InferenceState:\n",
 61 |     "    mimi: MimiModel\n",
 62 |     "    text_tokenizer: sentencepiece.SentencePieceProcessor\n",
 63 |     "    lm_gen: LMGen\n",
 64 |     "\n",
 65 |     "    def __init__(\n",
 66 |     "        self,\n",
 67 |     "        mimi: MimiModel,\n",
 68 |     "        text_tokenizer: sentencepiece.SentencePieceProcessor,\n",
 69 |     "        lm: LMModel,\n",
 70 |     "        batch_size: int,\n",
 71 |     "        device: str | torch.device,\n",
 72 |     "    ):\n",
 73 |     "        self.mimi = mimi\n",
 74 |     "        self.text_tokenizer = text_tokenizer\n",
 75 |     "        self.lm_gen = LMGen(lm, temp=0, temp_text=0, use_sampling=False)\n",
 76 |     "        self.device = device\n",
 77 |     "        self.frame_size = int(self.mimi.sample_rate / self.mimi.frame_rate)\n",
 78 |     "        self.batch_size = batch_size\n",
 79 |     "        self.mimi.streaming_forever(batch_size)\n",
 80 |     "        self.lm_gen.streaming_forever(batch_size)\n",
 81 |     "\n",
 82 |     "    def run(self, in_pcms: torch.Tensor):\n",
 83 |     "        device = self.lm_gen.lm_model.device\n",
 84 |     "        ntokens = 0\n",
 85 |     "        first_frame = True\n",
 86 |     "        chunks = [\n",
 87 |     "            c\n",
 88 |     "            for c in in_pcms.split(self.frame_size, dim=2)\n",
 89 |     "            if c.shape[-1] == self.frame_size\n",
 90 |     "        ]\n",
 91 |     "        start_time = time.time()\n",
 92 |     "        all_text = []\n",
 93 |     "        for chunk in chunks:\n",
 94 |     "            codes = self.mimi.encode(chunk)\n",
 95 |     "            if first_frame:\n",
 96 |     "                # Ensure that the first slice of codes is properly seen by the transformer\n",
 97 |     "                # as otherwise the first slice is replaced by the initial tokens.\n",
 98 |     "                tokens = self.lm_gen.step(codes)\n",
 99 |     "                first_frame = False\n",
100 |     "            tokens = self.lm_gen.step(codes)\n",
101 |     "            if tokens is None:\n",
102 |     "                continue\n",
103 |     "            assert tokens.shape[1] == 1\n",
104 |     "            one_text = tokens[0, 0].cpu()\n",
105 |     "            if one_text.item() not in [0, 3]:\n",
106 |     "                text = self.text_tokenizer.id_to_piece(one_text.item())\n",
107 |     "                text = text.replace(\"▁\", \" \")\n",
108 |     "                all_text.append(text)\n",
109 |     "            ntokens += 1\n",
110 |     "        dt = time.time() - start_time\n",
111 |     "        print(\n",
112 |     "            f\"processed {ntokens} steps in {dt:.0f}s, {1000 * dt / ntokens:.2f}ms/step\"\n",
113 |     "        )\n",
114 |     "        return \"\".join(all_text)"
115 |    ]
116 |   },
117 |   {
118 |    "cell_type": "code",
119 |    "execution_count": null,
120 |    "metadata": {
121 |     "colab": {
122 |      "base_uri": "https://localhost:8080/",
123 |      "height": 353,
124 |      "referenced_widgets": [
125 |       "0a5f6f887e2b4cd1990a0e9ec0153ed9",
126 |       "f7893826fcba4bdc87539589d669249b",
127 |       "8805afb12c484781be85082ff02dad13",
128 |       "97679c0d9ab44bed9a3456f2fcb541fd",
129 |       "d73c0321bed54a52b5e1da0a7788e32a",
130 |       "d67be13a920d4fc89e5570b5b29fc1d2",
131 |       "6b377c2d7bf945fb89e46c39d246a332",
132 |       "b82ff365c78e41ad8094b46daf79449d",
133 |       "477aa7fa82dc42d5bce6f1743c45d626",
134 |       "cbd288510c474430beb66f346f382c45",
135 |       "aafc347cdf28428ea6a7abe5b46b726f",
136 |       "fca09acd5d0d45468c8b04bfb2de7646",
137 |       "79e35214b51b4a9e9b3f7144b0b34f7b",
138 |       "89e9a37f69904bd48b954d627bff6687",
139 |       "57028789c78248a7b0ad4f031c9545c9",
140 |       "1150fcb427994c2984d4d0f4e4745fe5",
141 |       "e24b1fc52f294f849019c9b3befb613f",
142 |       "8724878682cf4c3ca992667c45009398",
143 |       "36a22c977d5242008871310133b7d2af",
144 |       "5b3683cad5cb4877b43fadd003edf97f",
145 |       "703f98272e4d469d8f27f5a465715dd8",
146 |       "9dbe02ef5fac41cfaee3d02946e65c88",
147 |       "37faa87ad03a4271992c21ce6a629e18",
148 |       "570c547e48cd421b814b2c5e028e4c0b",
149 |       "b173768580fc4c0a8e3abf272e4c363a",
150 |       "e57d1620f0a9427b85d8b4885ef4e8e3",
151 |       "5dd4474df70743498b616608182714dd",
152 |       "cc907676a65f4ad1bf68a77b4a00e89b",
153 |       "a34abc3b118e4305951a466919c28ff6",
154 |       "a77ccfcdb90146c7a63b4b2d232bc494",
155 |       "f7313e6e3a27475993cab3961d6ae363",
156 |       "39b47fad9c554839868fe9e4bbf7def2",
157 |       "14e9511ea0bd44c49f0cf3abf1a6d40e",
158 |       "a4ea8e0c4cac4d5e88b7e3f527e4fe90",
159 |       "571afc0f4b2840c9830d6b5a307ed1f9",
160 |       "6ec593cab5b64f0ea638bb175b9daa5c",
161 |       "77a52aed00ae408bb24524880e19ec8a",
162 |       "0b2de4b29b4b44fe9d96361a40c793d0",
163 |       "3c5b5fb1a5ac468a89c1058bd90cfb58",
164 |       "e53e0a2a240e43cfa562c89b3d703dea",
165 |       "35966343cf9249ef8bc028a0d5c5f97d",
166 |       "e36a37e0d41c47ccb8bc6d56c19fb17c",
167 |       "279ccf7de43847a1a6579c9182a46cc8",
168 |       "41b5d6ab0b7d43c790a55f125c0e7494"
169 |      ]
170 |     },
171 |     "id": "UsQJdAgkLp9n",
172 |     "outputId": "9b7131c3-69c5-4323-8312-2ce7621d8869"
173 |    },
174 |    "outputs": [],
175 |    "source": [
176 |     "device = \"cuda\"\n",
177 |     "# Use the en+fr low latency model, an alternative is kyutai/stt-2.6b-en\n",
178 |     "checkpoint_info = loaders.CheckpointInfo.from_hf_repo(\"kyutai/stt-1b-en_fr\")\n",
179 |     "mimi = checkpoint_info.get_mimi(device=device)\n",
180 |     "text_tokenizer = checkpoint_info.get_text_tokenizer()\n",
181 |     "lm = checkpoint_info.get_moshi(device=device)\n",
182 |     "in_pcms, _ = sphn.read(\"sample_fr_hibiki_crepes.mp3\", sample_rate=mimi.sample_rate)\n",
183 |     "in_pcms = torch.from_numpy(in_pcms).to(device=device)\n",
184 |     "\n",
185 |     "stt_config = checkpoint_info.stt_config\n",
186 |     "pad_left = int(stt_config.get(\"audio_silence_prefix_seconds\", 0.0) * 24000)\n",
187 |     "pad_right = int((stt_config.get(\"audio_delay_seconds\", 0.0) + 1.0) * 24000)\n",
188 |     "in_pcms = torch.nn.functional.pad(in_pcms, (pad_left, pad_right), mode=\"constant\")\n",
189 |     "in_pcms = in_pcms[None, 0:1].expand(1, -1, -1)\n",
190 |     "\n",
191 |     "state = InferenceState(mimi, text_tokenizer, lm, batch_size=1, device=device)\n",
192 |     "text = state.run(in_pcms)\n",
193 |     "print(textwrap.fill(text, width=100))"
194 |    ]
195 |   },
196 |   {
197 |    "cell_type": "code",
198 |    "execution_count": null,
199 |    "metadata": {
200 |     "colab": {
201 |      "base_uri": "https://localhost:8080/",
202 |      "height": 75
203 |     },
204 |     "id": "CIAXs9oaPrtj",
205 |     "outputId": "94cc208c-2454-4dd4-a64e-d79025144af5"
206 |    },
207 |    "outputs": [],
208 |    "source": [
209 |     "from IPython.display import Audio\n",
210 |     "\n",
211 |     "Audio(\"sample_fr_hibiki_crepes.mp3\")"
212 |    ]
213 |   },
214 |   {
215 |    "cell_type": "code",
216 |    "execution_count": null,
217 |    "metadata": {
218 |     "id": "qkUZ6CBKOdTa"
219 |    },
220 |    "outputs": [],
221 |    "source": []
222 |   }
223 |  ],
224 |  "metadata": {
225 |   "accelerator": "GPU",
226 |   "colab": {
227 |    "gpuType": "L4",
228 |    "provenance": []
229 |   },
230 |   "kernelspec": {
231 |    "display_name": "Python 3 (ipykernel)",
232 |    "language": "python",
233 |    "name": "python3"
234 |   }
235 |  },
236 |  "nbformat": 4,
237 |  "nbformat_minor": 0
238 | }
239 | 
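
The notebook's comment mentions kyutai/stt-2.6b-en as an English-only alternative to the en+fr model. Swapping it in is a small change; a sketch, reusing the notebook's imports and its InferenceState class:

# Hedged sketch: the same pipeline with the English-only checkpoint named in the notebook.
checkpoint_info = loaders.CheckpointInfo.from_hf_repo("kyutai/stt-2.6b-en")
mimi = checkpoint_info.get_mimi(device=device)
text_tokenizer = checkpoint_info.get_text_tokenizer()
lm = checkpoint_info.get_moshi(device=device)
state = InferenceState(mimi, text_tokenizer, lm, batch_size=1, device=device)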


--------------------------------------------------------------------------------
/tts_pytorch.ipynb:
--------------------------------------------------------------------------------
  1 | {
  2 |  "cells": [
  3 |   {
  4 |    "cell_type": "code",
  5 |    "execution_count": null,
  6 |    "id": "0",
  7 |    "metadata": {},
  8 |    "outputs": [],
  9 |    "source": [
 10 |     "# Fast install, might break in the future.\n",
 11 |     "!pip install 'sphn<0.2'\n",
 12 |     "!pip install --no-deps \"moshi==0.2.8\"\n",
 13 |     "# Slow install (will download torch and cuda), but future proof.\n",
 14 |     "# !pip install \"moshi==0.2.8\""
 15 |    ]
 16 |   },
 17 |   {
 18 |    "cell_type": "code",
 19 |    "execution_count": null,
 20 |    "id": "1",
 21 |    "metadata": {},
 22 |    "outputs": [],
 23 |    "source": [
 24 |     "import argparse\n",
 25 |     "import sys\n",
 26 |     "\n",
 27 |     "import numpy as np\n",
 28 |     "import torch\n",
 29 |     "from moshi.models.loaders import CheckpointInfo\n",
 30 |     "from moshi.models.tts import DEFAULT_DSM_TTS_REPO, DEFAULT_DSM_TTS_VOICE_REPO, TTSModel\n",
 31 |     "\n",
 32 |     "from IPython.display import display, Audio"
 33 |    ]
 34 |   },
 35 |   {
 36 |    "cell_type": "code",
 37 |    "execution_count": null,
 38 |    "id": "2",
 39 |    "metadata": {},
 40 |    "outputs": [],
 41 |    "source": [
 42 |     "# Configuration\n",
 43 |     "text = \"Hey there! How are you? I had the craziest day today.\"\n",
 44 |     "voice = \"expresso/ex03-ex01_happy_001_channel1_334s.wav\"\n",
 45 |     "print(f\"See https://huggingface.co/{DEFAULT_DSM_TTS_VOICE_REPO} for available voices.\")"
 46 |    ]
 47 |   },
 48 |   {
 49 |    "cell_type": "code",
 50 |    "execution_count": null,
 51 |    "id": "3",
 52 |    "metadata": {},
 53 |    "outputs": [],
 54 |    "source": [
 55 |     "# Set everything up\n",
 56 |     "checkpoint_info = CheckpointInfo.from_hf_repo(DEFAULT_DSM_TTS_REPO)\n",
 57 |     "tts_model = TTSModel.from_checkpoint_info(\n",
 58 |     "    checkpoint_info, n_q=32, temp=0.6, device=torch.device(\"cuda\")\n",
 59 |     ")\n",
 60 |     "\n",
 61 |     "# If you want to make a dialog, you can pass more than one turn [text_speaker_1, text_speaker_2, text_2_speaker_1, ...]\n",
 62 |     "entries = tts_model.prepare_script([text], padding_between=1)\n",
 63 |     "voice_path = tts_model.get_voice_path(voice)\n",
 64 |     "# CFG coef goes here because the model was trained with CFG distillation,\n",
 65 |     "# so it's not _actually_ doing CFG at inference time.\n",
 66 |     "# Also, if you are generating a dialog, you should have two voices in the list.\n",
 67 |     "condition_attributes = tts_model.make_condition_attributes(\n",
 68 |     "    [voice_path], cfg_coef=2.0\n",
 69 |     ")"
 70 |    ]
 71 |   },
 72 |   {
 73 |    "cell_type": "code",
 74 |    "execution_count": null,
 75 |    "id": "4",
 76 |    "metadata": {},
 77 |    "outputs": [],
 78 |    "source": [
 79 |     "print(\"Generating audio...\")\n",
 80 |     "\n",
 81 |     "pcms = []\n",
 82 |     "def _on_frame(frame):\n",
 83 |     "    print(\"Step\", len(pcms), end=\"\\r\")\n",
 84 |     "    if (frame != -1).all():\n",
 85 |     "        pcm = tts_model.mimi.decode(frame[:, 1:, :]).cpu().numpy()\n",
 86 |     "        pcms.append(np.clip(pcm[0, 0], -1, 1))\n",
 87 |     "\n",
 88 |     "# You could also generate multiple audios at once by extending the following lists.\n",
 89 |     "all_entries = [entries]\n",
 90 |     "all_condition_attributes = [condition_attributes]\n",
 91 |     "with tts_model.mimi.streaming(len(all_entries)):\n",
 92 |     "    result = tts_model.generate(all_entries, all_condition_attributes, on_frame=_on_frame)\n",
 93 |     "\n",
 94 |     "print(\"Done generating.\")\n",
 95 |     "audio = np.concatenate(pcms, axis=-1)"
 96 |    ]
 97 |   },
 98 |   {
 99 |    "cell_type": "code",
100 |    "execution_count": null,
101 |    "id": "5",
102 |    "metadata": {},
103 |    "outputs": [],
104 |    "source": [
105 |     "display(\n",
106 |     "    Audio(audio, rate=tts_model.mimi.sample_rate, autoplay=True)\n",
107 |     ")"
108 |    ]
109 |   },
110 |   {
111 |    "cell_type": "code",
112 |    "execution_count": null,
113 |    "id": "6",
114 |    "metadata": {},
115 |    "outputs": [],
116 |    "source": []
117 |   }
118 |  ],
119 |  "metadata": {
120 |   "kernelspec": {
121 |    "display_name": "Python 3 (ipykernel)",
122 |    "language": "python",
123 |    "name": "python3"
124 |   },
125 |   "language_info": {
126 |    "codemirror_mode": {
127 |     "name": "ipython",
128 |     "version": 3
129 |    },
130 |    "file_extension": ".py",
131 |    "mimetype": "text/x-python",
132 |    "name": "python",
133 |    "nbconvert_exporter": "python",
134 |    "pygments_lexer": "ipython3",
135 |    "version": "3.13.2"
136 |   }
137 |  },
138 |  "nbformat": 4,
139 |  "nbformat_minor": 5
140 | }
141 | 
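
To keep the generated audio rather than only playing it inline, a follow-up cell could write the `audio` array to disk with sphn, the same helper used by the scripts above (the output file name is a placeholder):

# Hedged sketch: save the audio generated in the notebook to a wav file.
import sphn

sphn.write_wav("tts_output.wav", audio, tts_model.mimi.sample_rate)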


--------------------------------------------------------------------------------