├── .gitignore ├── .python-version ├── LICENSE ├── README.md ├── config ├── emilia │ ├── baseline.json │ └── resume_from_checkpoint.json ├── kokoro_v0.19 │ ├── libritts_bytes.json │ ├── libritts_kokoro.json │ ├── pg_kokoro.json │ ├── pg_kokoro_wte.json │ └── pg_kokoro_wte_nsys.json ├── kokoro_v1 │ ├── baseline.json │ ├── deduplicate_code_0.json │ └── scaleup.json ├── librispeech │ ├── librispeech.json │ └── librispeech_cooldown.json └── ljspeech │ └── ljspeech.json ├── data_pipeline ├── README.md ├── __init__.py ├── encode_libritts.py ├── notebooks │ ├── create_bytelevel_init.ipynb │ ├── create_smoltts_init.ipynb │ ├── decode_audio.ipynb │ ├── encode_ljspeech.ipynb │ ├── test_emilia.ipynb │ ├── test_emilia.py │ ├── test_tokenization.ipynb │ ├── tokenize_libritts.ipynb │ ├── tokenize_libritts_byte.ipynb │ ├── tokenize_libritts_byte_kokoro.ipynb │ └── upload_libritts.ipynb ├── preview │ ├── app.py │ └── decoder_config.json ├── scripts │ ├── audio_tokenizer_configs │ │ ├── emilia_v1.json │ │ ├── project_gutenberg_v1.json │ │ ├── project_gutenberg_v2.1.json │ │ └── project_gutenberg_v2.json │ ├── bpe_tokenizer_configs │ │ └── kokoro_bytelevel.json │ ├── chatml_tokenize_dataset.py │ └── create_bytelevel_init.py └── utils │ ├── __init__.py │ ├── codec.py │ └── prompt.py ├── docs ├── dualar.png └── examples │ ├── bella.wav │ ├── emma.wav │ ├── fable.wav │ ├── fenrir.wav │ ├── heart.wav │ ├── isabella.wav │ ├── liam.wav │ ├── michael.wav │ ├── nova.wav │ ├── sarah.wav │ └── sky.wav ├── mlx_inference ├── README.md ├── default_settings.json ├── example.ipynb ├── pyproject.toml ├── src │ └── smoltts_mlx │ │ ├── __init__.py │ │ ├── codec │ │ ├── __init__.py │ │ ├── conv.py │ │ ├── mimi.py │ │ ├── rvq.py │ │ ├── seanet.py │ │ └── transformer.py │ │ ├── io │ │ ├── __init__.py │ │ └── wav.py │ │ ├── lm │ │ ├── __init__.py │ │ ├── cache.py │ │ ├── config.py │ │ ├── generate.py │ │ ├── rq_transformer.py │ │ └── utils │ │ │ ├── __init__.py │ │ │ ├── constraints.py │ │ │ ├── prompt.py │ │ │ └── samplers.py │ │ ├── scripts │ │ ├── __init__.py │ │ └── server.py │ │ └── server │ │ ├── __init__.py │ │ ├── routes │ │ ├── __init__.py │ │ ├── elevenlabs.py │ │ └── openai.py │ │ ├── settings.py │ │ └── tts_core.py ├── static │ └── index.html └── tests │ ├── compare_npy.py │ ├── sky.wav │ ├── test_decoder.py │ ├── test_encoder.py │ └── test_generate.py ├── modeling ├── __init__.py ├── model │ ├── __init__.py │ └── rq_transformer.py └── utils │ └── __init__.py ├── pyproject.toml ├── sample_model_sizes ├── smoltts_byte_150m.json └── smoltts_byte_70m.json ├── train ├── README.md ├── __init__.py ├── config.py ├── convert_safetensors.py ├── data.py ├── main.py ├── optim.py ├── pyproject.toml ├── state.py └── trainer.py └── uv.lock /.gitignore: -------------------------------------------------------------------------------- 1 | .venv 2 | .env 3 | __pycache__ 4 | *.egg-info 5 | wandb 6 | *.wav 7 | *.npy 8 | !checkpoints/smoltts_init/config.json 9 | 10 | # Artifacts 11 | *.safetensors 12 | checkpoints 13 | datasets 14 | inits 15 | mlx_trace* 16 | pytorch_trace.json 17 | !docs/examples/*.wav 18 | !mlx_inference/tests/sky.wav 19 | -------------------------------------------------------------------------------- /.python-version: -------------------------------------------------------------------------------- 1 | 3.9 2 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, 
January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 
123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 
180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SmolTTS: a text-to-speech laboratory 2 | 3 | This repo is a personal laboratory for training autoregressive text-audio models. 4 | 5 | Assume everything will change; quality right now is pretty mid. Will get better. 6 | 7 | ## Using pretrained models 8 | 9 | ### smoltts_v0 10 | 11 | A distillation of [Kokoro TTS](https://huggingface.co/hexgrad/Kokoro-82M) to the RQ Transformer architecture. Released at 70M and 150M scale. 12 | 13 | For MLX inference on Apple Silicon, you'll need a working Python installation. See the `mlx_inference` folder for setup docs! 14 | 15 | ```bash 16 | # tl;dr 17 | uvx --from smoltts_mlx smoltts-server 18 | ``` 19 | 20 | Candle.rs docs coming soon. 21 | 22 | ## Using datasets 23 | 24 | As of Feb 2025, this project currently uses the [Mimi](https://huggingface.co/kyutai/mimi) pretrained codec by Kyutai, due to its low framerate (12.5Hz), high compression ratio, and streaming support. 25 | 26 | ### Synthetic data 27 | 28 | [projectgutenberg-kokoro_v1-mimi](jkeisling/projectgutenberg-kokoro_v1-mimi): 29 | 30 | - ~5500 hours of synthetic audio generated with [Kokoro v1](https://huggingface.co/hexgrad/Kokoro-82M) for US and UK English. 31 | - 3 million utterances of sentences from Project Gutenberg, mostly 3-15s. 3.29GB compressed with Mimi. 32 | - 11 speakers. 33 | 34 | ### Mimi re-encodings of standard datasets 35 | 36 | For convenience, we serialize popular open TTS benchmark datasets in Mimi, to directly have training targets and compress the filesize by ~500x: 37 | 38 | - [LibriTTS-R](https://huggingface.co/datasets/jkeisling/libritts-r-mimi) encoded with [Mimi](https://huggingface.co/kyutai/mimi) codec. ~460 hours of data. 39 | 40 | ## Pretraining a model 41 | 42 | ### Workspace setup 43 | 44 | Unfortunately, HuggingFace Datasets using audio columns require librosa, which has a hard Python 3.9 dependency for inexplicable reasons. 45 | If you are not creating a new dataset using raw audio instead of Mimi codes, please feel free to ignore this. 46 | 47 | Please use [uv](https://docs.astral.sh/uv/). 
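If you are skipping raw-audio dataset creation and want a newer interpreter, one option is to re-pin the Python version before syncing. This is only a sketch: 3.12 is an arbitrary example, and it assumes the project's `requires-python` constraint permits it:

```bash
# Hypothetical override of the checked-in 3.9 pin in .python-version
uv python pin 3.12
```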
48 | 49 | ```bash 50 | # If you are not making new audio datasets, feel free to use a sane Python version instead 51 | uv sync 52 | uv pip install -e . 53 | ``` 54 | 55 | Create a `.env` file and add: 56 | 57 | ```bash 58 | HUGGINGFACE_TOKEN=sk-placeholder 59 | ``` 60 | 61 | For the dataset and init, see `data_pipeline/README.md`. 62 | 63 | ### RQ Transformer 64 | 65 | This architecture is most popularly used as the neural codec seq2seq backbone for: 66 | 67 | - [Fish Speech TTS](https://github.com/fishaudio/fish-speech) (in their [paper](https://arxiv.org/html/2411.01156v2#S3) as "DualAR" or dual-autoregressive) 68 | - Kyutai's [Moshi](https://github.com/kyutai-labs/moshi) model early in pretraining before adaptation to duplex audio. 69 | 70 | Models trained here will be compatible with my DualAR [fish-speech.rs](https://github.com/EndlessReform/fish-speech.rs/blob/main/README.md) inference engine. 71 | -------------------------------------------------------------------------------- /config/emilia/baseline.json: -------------------------------------------------------------------------------- 1 | { 2 | "project_name": "smoltts_emilia", 3 | "checkpoint_path": "../checkpoints", 4 | "init_folder": "../inits/smoltts_byte_kokoro_layer", 5 | "dataset_path": "../datasets/byte-tokenized-emilia-v1", 6 | "model_path": "pretrained_model", 7 | "batch_size": 16, 8 | "accumulate_steps": 8, 9 | "max_epochs": 1, 10 | "num_workers": 4, 11 | "gradient_clip": 1.0, 12 | "learning_rate": 7e-4, 13 | "lr_start": 1.5e-3, 14 | "lr_warmup_steps": 5000, 15 | "weight_decay": 0.01, 16 | "betas": [ 17 | 0.90, 18 | 0.95 19 | ], 20 | "eps": 1e-5, 21 | "val_every_n_steps": 2000, 22 | "save_every_n_steps": 10000, 23 | "max_sequence_length": 768, 24 | "use_bf16": true, 25 | "use_wandb": true, 26 | "use_pretrained": false 27 | } -------------------------------------------------------------------------------- /config/emilia/resume_from_checkpoint.json: -------------------------------------------------------------------------------- 1 | { 2 | "project_name": "smoltts_emilia", 3 | "checkpoint_path": "../checkpoints", 4 | "init_folder": "../inits/smoltts_byte_kokoro_layer", 5 | "dataset_path": "../datasets/byte-tokenized-emilia-v1", 6 | "model_path": "pretrained_model", 7 | "batch_size": 16, 8 | "accumulate_steps": 8, 9 | "max_epochs": 1, 10 | "num_workers": 4, 11 | "gradient_clip": 1.0, 12 | "learning_rate": 7e-4, 13 | "lr_start": 7e-4, 14 | "lr_warmup_steps": 8, 15 | "weight_decay": 0.01, 16 | "betas": [ 17 | 0.90, 18 | 0.95 19 | ], 20 | "eps": 1e-5, 21 | "val_every_n_steps": 2000, 22 | "save_every_n_steps": 10000, 23 | "max_sequence_length": 768, 24 | "use_bf16": true, 25 | "use_wandb": true, 26 | "use_pretrained": true 27 | } -------------------------------------------------------------------------------- /config/kokoro_v0.19/libritts_bytes.json: -------------------------------------------------------------------------------- 1 | { 2 | "project_name": "smoltts_160m_librispeech", 3 | "checkpoint_path": "../checkpoints", 4 | "init_folder": "../inits/smoltts_byte", 5 | "dataset_path": "../datasets/tokenized_libritts_bytes", 6 | "model_path": "pretrained_model", 7 | "batch_size": 20, 8 | "max_epochs": 6, 9 | "num_workers": 4, 10 | "gradient_clip": 1.0, 11 | "learning_rate": 5e-4, 12 | "lr_start": 1e-3, 13 | "lr_warmup_steps": 30000, 14 | "weight_decay": 0.01, 15 | "betas": [ 16 | 0.90, 17 | 0.95 18 | ], 19 | "eps": 1e-5, 20 | "val_every_n_steps": 250, 21 | "save_every_n_steps": 5000, 22 | "max_sequence_length": 768, 23 | 
"use_bf16": true, 24 | "use_wandb": true, 25 | "use_pretrained": false 26 | } -------------------------------------------------------------------------------- /config/kokoro_v0.19/libritts_kokoro.json: -------------------------------------------------------------------------------- 1 | { 2 | "project_name": "smoltts_160m_librispeech", 3 | "checkpoint_path": "../checkpoints", 4 | "init_folder": "../inits/smoltts_byte_kokoro", 5 | "dataset_path": "../datasets/tokenized_libritts_bytes_kokoro_speaker", 6 | "model_path": "pretrained_model", 7 | "batch_size": 22, 8 | "max_epochs": 8, 9 | "num_workers": 4, 10 | "gradient_clip": 1.0, 11 | "learning_rate": 5e-4, 12 | "lr_start": 1e-3, 13 | "lr_warmup_steps": 30000, 14 | "weight_decay": 0.01, 15 | "betas": [ 16 | 0.90, 17 | 0.95 18 | ], 19 | "eps": 1e-5, 20 | "val_every_n_steps": 250, 21 | "save_every_n_steps": 5000, 22 | "max_sequence_length": 768, 23 | "use_bf16": true, 24 | "use_wandb": true, 25 | "use_pretrained": false 26 | } -------------------------------------------------------------------------------- /config/kokoro_v0.19/pg_kokoro.json: -------------------------------------------------------------------------------- 1 | { 2 | "project_name": "smoltts_kokoro", 3 | "checkpoint_path": "../checkpoints", 4 | "init_folder": "../inits/smoltts_byte_kokoro", 5 | "dataset_path": "../datasets/tokenized_project_gutenberg_bytes_kokoro_all", 6 | "model_path": "pretrained_model", 7 | "batch_size": 22, 8 | "max_epochs": 3, 9 | "num_workers": 4, 10 | "gradient_clip": 1.0, 11 | "learning_rate": 3e-4, 12 | "lr_start": 6e-4, 13 | "lr_warmup_steps": 30000, 14 | "weight_decay": 0.01, 15 | "betas": [ 16 | 0.90, 17 | 0.95 18 | ], 19 | "eps": 1e-5, 20 | "val_every_n_steps": 500, 21 | "save_every_n_steps": 5000, 22 | "max_sequence_length": 768, 23 | "use_bf16": true, 24 | "use_wandb": false, 25 | "use_pretrained": false 26 | } -------------------------------------------------------------------------------- /config/kokoro_v0.19/pg_kokoro_wte.json: -------------------------------------------------------------------------------- 1 | { 2 | "project_name": "smoltts_kokoro", 3 | "checkpoint_path": "../checkpoints", 4 | "init_folder": "../inits/smoltts_byte_wte", 5 | "dataset_path": "../datasets/tokenized_project_gutenberg_bytes_kokoro_tau", 6 | "model_path": "pretrained_model", 7 | "batch_size": 48, 8 | "max_epochs": 3, 9 | "num_workers": 4, 10 | "gradient_clip": 1.0, 11 | "learning_rate": 3e-4, 12 | "lr_start": 6e-4, 13 | "lr_warmup_steps": 30000, 14 | "weight_decay": 0.01, 15 | "betas": [ 16 | 0.90, 17 | 0.95 18 | ], 19 | "eps": 1e-5, 20 | "val_every_n_steps": 500, 21 | "save_every_n_steps": 5000, 22 | "max_sequence_length": 768, 23 | "use_bf16": true, 24 | "use_wandb": true, 25 | "use_pretrained": false 26 | } -------------------------------------------------------------------------------- /config/kokoro_v0.19/pg_kokoro_wte_nsys.json: -------------------------------------------------------------------------------- 1 | { 2 | "project_name": "smoltts_kokoro", 3 | "checkpoint_path": "../checkpoints", 4 | "init_folder": "../inits/smoltts_byte_wte", 5 | "dataset_path": "../datasets/tokenized_project_gutenberg_bytes_kokoro_all", 6 | "model_path": "pretrained_model", 7 | "batch_size": 24, 8 | "max_epochs": 1, 9 | "num_workers": 4, 10 | "gradient_clip": 1.0, 11 | "learning_rate": 3e-4, 12 | "lr_start": 6e-4, 13 | "lr_warmup_steps": 30000, 14 | "weight_decay": 0.01, 15 | "betas": [ 16 | 0.90, 17 | 0.95 18 | ], 19 | "eps": 1e-5, 20 | "val_every_n_steps": 500, 21 | 
"save_every_n_steps": 5000, 22 | "max_sequence_length": 768, 23 | "use_bf16": true, 24 | "use_wandb": false, 25 | "use_pretrained": false 26 | } -------------------------------------------------------------------------------- /config/kokoro_v1/baseline.json: -------------------------------------------------------------------------------- 1 | { 2 | "project_name": "smoltts_kokoro", 3 | "checkpoint_path": "../checkpoints", 4 | "init_folder": "../inits/smoltts_byte_kokoro", 5 | "dataset_path": "../datasets/byte-tokenized-pg-kokoro_v1", 6 | "model_path": "pretrained_model", 7 | "batch_size": 22, 8 | "max_epochs": 3, 9 | "num_workers": 4, 10 | "gradient_clip": 1.0, 11 | "learning_rate": 5e-4, 12 | "lr_start": 1e-3, 13 | "lr_warmup_steps": 70000, 14 | "weight_decay": 0.01, 15 | "betas": [ 16 | 0.90, 17 | 0.95 18 | ], 19 | "eps": 1e-5, 20 | "val_every_n_steps": 1000, 21 | "save_every_n_steps": 5000, 22 | "max_sequence_length": 768, 23 | "use_bf16": true, 24 | "use_wandb": false, 25 | "use_pretrained": false 26 | } -------------------------------------------------------------------------------- /config/kokoro_v1/deduplicate_code_0.json: -------------------------------------------------------------------------------- 1 | { 2 | "project_name": "smoltts_kokoro", 3 | "checkpoint_path": "../checkpoints", 4 | "init_folder": "../inits/smoltts_byte_kokoro_layer", 5 | "dataset_path": "../datasets/byte-tokenized-pg-kokoro_v1.2", 6 | "model_path": "pretrained_model", 7 | "batch_size": 22, 8 | "max_epochs": 3, 9 | "num_workers": 4, 10 | "gradient_clip": 1.0, 11 | "learning_rate": 7e-4, 12 | "lr_start": 1.5e-3, 13 | "lr_warmup_steps": 5000, 14 | "weight_decay": 0.01, 15 | "betas": [ 16 | 0.90, 17 | 0.95 18 | ], 19 | "eps": 1e-5, 20 | "val_every_n_steps": 1000, 21 | "save_every_n_steps": 5000, 22 | "max_sequence_length": 768, 23 | "use_bf16": true, 24 | "use_wandb": true, 25 | "use_pretrained": false 26 | } -------------------------------------------------------------------------------- /config/kokoro_v1/scaleup.json: -------------------------------------------------------------------------------- 1 | { 2 | "project_name": "smoltts_kokoro", 3 | "checkpoint_path": "../checkpoints", 4 | "init_folder": "../inits/smoltts_byte_kokoro_150m", 5 | "dataset_path": "../datasets/byte-tokenized-pg-kokoro_v1", 6 | "model_path": "pretrained_model", 7 | "batch_size": 22, 8 | "max_epochs": 2, 9 | "num_workers": 4, 10 | "gradient_clip": 1.0, 11 | "learning_rate": 3e-4, 12 | "lr_start": 1e-3, 13 | "lr_warmup_steps": 70000, 14 | "weight_decay": 0.01, 15 | "betas": [ 16 | 0.90, 17 | 0.95 18 | ], 19 | "eps": 1e-5, 20 | "val_every_n_steps": 1000, 21 | "save_every_n_steps": 5000, 22 | "max_sequence_length": 768, 23 | "use_bf16": true, 24 | "use_wandb": true, 25 | "use_pretrained": false 26 | } -------------------------------------------------------------------------------- /config/librispeech/librispeech.json: -------------------------------------------------------------------------------- 1 | { 2 | "project_name": "smoltts_160m_librispeech", 3 | "checkpoint_path": "../checkpoints", 4 | "init_folder": "../inits/smoltts_init", 5 | "dataset_path": "../datasets/tokenized_libritts", 6 | "model_path": "pretrained_model", 7 | "batch_size": 20, 8 | "max_epochs": 6, 9 | "num_workers": 4, 10 | "gradient_clip": 1.0, 11 | "learning_rate": 5e-4, 12 | "lr_start": 1e-3, 13 | "lr_warmup_steps": 30000, 14 | "weight_decay": 0.01, 15 | "betas": [ 16 | 0.90, 17 | 0.95 18 | ], 19 | "eps": 1e-5, 20 | "val_every_n_steps": 250, 21 | "save_every_n_steps": 
5000, 22 | "max_sequence_length": 768, 23 | "use_bf16": true, 24 | "use_wandb": true, 25 | "use_pretrained": false 26 | } -------------------------------------------------------------------------------- /config/librispeech/librispeech_cooldown.json: -------------------------------------------------------------------------------- 1 | { 2 | "project_name": "smoltts_160m_librispeech", 3 | "checkpoint_path": "../checkpoints", 4 | "model_path": "pretrained_model", 5 | "batch_size": 2, 6 | "max_epochs": 10, 7 | "num_workers": 4, 8 | "gradient_clip": 1.0, 9 | "learning_rate": 5e-6, 10 | "lr_start": 3e-4, 11 | "lr_warmup_steps": 9000, 12 | "weight_decay": 0.01, 13 | "betas": [ 14 | 0.90, 15 | 0.95 16 | ], 17 | "eps": 1e-5, 18 | "val_every_n_steps": 250, 19 | "save_every_n_steps": 2000, 20 | "max_sequence_length": 768, 21 | "use_bf16": true, 22 | "use_wandb": true, 23 | "use_pretrained": false 24 | } -------------------------------------------------------------------------------- /config/ljspeech/ljspeech.json: -------------------------------------------------------------------------------- 1 | { 2 | "project_name": "ljspeech_train", 3 | "checkpoint_path": "../checkpoints", 4 | "model_path": "pretrained_model", 5 | "batch_size": 32, 6 | "max_epochs": 2, 7 | "num_workers": 4, 8 | "gradient_clip": 1.0, 9 | "learning_rate": 5e-4, 10 | "lr_start": 1e-5, 11 | "lr_warmup_steps": 200, 12 | "weight_decay": 0.01, 13 | "betas": [ 14 | 0.9, 15 | 0.95 16 | ], 17 | "eps": 1e-5, 18 | "val_every_n_steps": 100, 19 | "save_every_n_steps": 2000, 20 | "max_sequence_length": 512, 21 | "use_bf16": true, 22 | "use_wandb": true, 23 | "use_pretrained": true 24 | } -------------------------------------------------------------------------------- /data_pipeline/README.md: -------------------------------------------------------------------------------- 1 | # Dataset utils 2 | 3 | > [!INFO] 4 | > Yes, I know, this isn't Meta AI-tier data engineering. Sue me. 5 | 6 | Here's what you need to bootstrap the training setup: 7 | 8 | ## Model config 9 | 10 | ### Byte-level tokenizer 11 | 12 | If you're training a _byte-level_ model from scratch, in the CLI, first navigate to the project root. 13 | Then create an `inits` folder for the model config. 14 | The name of your init can be arbitrary; we'll just assume you're using `smoltts_byte_kokoro`: 15 | 16 | ```bash 17 | # Name can be arbitrary 18 | mkdir -p inits/smoltts_byte_kokoro 19 | ``` 20 | 21 | Then let's create a byte-level HuggingFace tokenizer: 22 | 23 | ```bash 24 | # From project root 25 | uv run data_pipeline/scripts/create_bytelevel_init.py --out-dir inits/smoltts_byte_kokoro 26 | ``` 27 | 28 | Finally, copy your model config: 29 | 30 | ```bash 31 | # Substitute model config as desired 32 | cp sample_model_sizes/smoltts_byte_60m_wte.json ./inits/smoltts_byte_kokoro/config.json 33 | ``` 34 | 35 | ### BPE 36 | 37 | If you're using a real BPE tokenizer, run the `create_smoltts_init` notebook: Convert your LM base and tokenizer to DualAR format. 38 | 39 | ## Audio dataset creation 40 | 41 | ### LibriTTS-R dataset 42 | 43 | - `encode_libritts.py` to encode, then 44 | - `upload_libritts.ipynb` to upload if you feel this to be necessary. 45 | 46 | Thanks to HF Datasets streaming, does not require _persisting_ all ~100-200GB of audio to your hard drive, just _downloading_ it. **Skip this step** if using [pre-encoded LibriTTS-R](https://huggingface.co/datasets/jkeisling/libritts-r-mimi). 
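A minimal sketch of the encode step, run from the project root (assumes a CUDA GPU is available; the output path below is just an example):

```bash
# Streams LibriTTS-R via HF Datasets, encodes each split to Mimi codes, and saves one folder per split
uv run data_pipeline/encode_libritts.py --output datasets/encoded_libritts
```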
47 | 48 | ### Synthetic data with Kokoro 49 | 50 | See a different repo 51 | 52 | ### i love...EMILIA 53 | 54 | ## Tokenizing your data 55 | 56 | Your dataset must contain the following columns: 57 | 58 | - `text_normalized`: String 59 | - `codes`: Tensor of Mimi codes 60 | - `speaker_id`: string, name of speaker (e.g. "alloy") 61 | 62 | Create your config (ideally version-controlled in `data_pipeline/scripts/audio_tokenizer_configs`). Then, for example: 63 | 64 | ```bash 65 | uv run data_pipeline/scripts/chatml_tokenize_dataset.py \ 66 | -c ./data_pipeline/scripts/audio_tokenizer_configs/project_gutenberg_v2.1.json \ 67 | -o ./datasets/byte-tokenized-pg-kokoro_v1 68 | ``` 69 | 70 | For legacy BPE tokenization, use`tokenize_libritts.ipynb`: Tokenize, ChatML format, and pack by speaker index 71 | -------------------------------------------------------------------------------- /data_pipeline/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EndlessReform/smoltts/410d6524966e2d433ce565e4f14f1fac2f983ff9/data_pipeline/__init__.py -------------------------------------------------------------------------------- /data_pipeline/encode_libritts.py: -------------------------------------------------------------------------------- 1 | from datasets import load_dataset, Audio 2 | import datasets 3 | from transformers import MimiModel 4 | from torch.nn.utils.rnn import pad_sequence 5 | import torch 6 | import math 7 | import argparse 8 | from tqdm import tqdm 9 | from pathlib import Path 10 | 11 | 12 | def get_target_length(arr: torch.Tensor) -> int: 13 | return math.ceil(arr.size(-1) / (SAMPLING_RATE / 12.5)) 14 | 15 | 16 | def batch_wav_encoder(batch_dict): 17 | batch = batch_dict["audio"] 18 | target_lengths = [get_target_length(sample["array"]) for sample in batch] 19 | 20 | # Single batch processing 21 | padded_batch = pad_sequence( 22 | [sample["array"] for sample in batch], batch_first=True 23 | ).unsqueeze( 24 | 1 25 | ) # (batch, 1, time) 26 | 27 | padding_mask = (padded_batch != 0).float() 28 | 29 | with torch.no_grad(): 30 | enc_out_cuda = model.encode( 31 | padded_batch.to("cuda"), padding_mask=padding_mask.to("cuda") 32 | ) 33 | enc_out = enc_out_cuda.audio_codes[:, 0:8, :].clone().cpu() 34 | del enc_out_cuda 35 | torch.cuda.empty_cache() 36 | 37 | # Process outputs 38 | chunked = torch.unbind(enc_out, dim=0) 39 | outputs = [t[:, :l] for t, l in zip(chunked, target_lengths)] 40 | 41 | return {"codes": outputs} 42 | 43 | 44 | def main(): 45 | parser = argparse.ArgumentParser() 46 | parser.add_argument( 47 | "--output", 48 | type=str, 49 | default="encoded_libritts", 50 | help="Directory where each split folder will be saved", 51 | ) 52 | args = parser.parse_args() 53 | 54 | global SAMPLING_RATE, model 55 | SAMPLING_RATE = 24_000 56 | 57 | # Load model 58 | model = MimiModel.from_pretrained("kyutai/mimi").to("cuda") 59 | 60 | # The splits we want to process 61 | all_splits = [ 62 | "dev.clean", 63 | "test.clean", 64 | "train.clean.100", 65 | "train.clean.360", 66 | ] 67 | 68 | # Load the source dataset with streaming 69 | dataset_dict = load_dataset("mythicinfinity/libritts_r", "clean", streaming=True) 70 | 71 | output_dir = Path(args.output) 72 | 73 | for split in all_splits: 74 | print(f"\nProcessing {split}...") 75 | 76 | # We'll stream the dataset, with the audio column cast to the correct sample rate 77 | streamed_split = dataset_dict[split] 78 | streamed_split = streamed_split.cast_column( 79 | "audio", 
Audio(sampling_rate=SAMPLING_RATE) 80 | ) 81 | streamed_split = streamed_split.with_format("torch") 82 | 83 | # We'll accumulate encoded rows in memory 84 | # (If you truly can't fit them all, you'd reintroduce sharding.) 85 | encoded_rows = [] 86 | print("Encoding in batches...") 87 | 88 | for batch_out in tqdm( 89 | streamed_split.map( 90 | batch_wav_encoder, 91 | batched=True, 92 | batch_size=24, 93 | remove_columns=["audio"], 94 | ), 95 | desc=f"Encoding {split}", 96 | ): 97 | encoded_rows.append(batch_out) 98 | 99 | new_data = datasets.Dataset.from_list(encoded_rows) 100 | 101 | # Save to disk in a subfolder named after the split 102 | split_folder = output_dir / split 103 | print(f"Saving {split} to {split_folder}...") 104 | split_folder.mkdir(parents=True, exist_ok=True) 105 | new_data.save_to_disk(str(split_folder)) 106 | 107 | print(f"Finished {split}") 108 | 109 | print("\nAll splits processed. Done!") 110 | 111 | 112 | if __name__ == "__main__": 113 | main() 114 | -------------------------------------------------------------------------------- /data_pipeline/notebooks/create_bytelevel_init.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Byte level ablation" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "First, let's make a spurious \"BPE\" tokenizer without any actual byte pairs. \n", 15 | "\n", 16 | "This assumes you copied and pasted a folder to `../checkpoints/smoltts_byte` with the `config.json` of a normal model in it. If you didn't, do that now, by running the regular \"create init\" notebook." 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": null, 22 | "metadata": {}, 23 | "outputs": [], 24 | "source": [ 25 | "from tokenizers import Tokenizer, models, normalizers, decoders, pre_tokenizers\n", 26 | "from tokenizers.trainers import BpeTrainer" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": null, 32 | "metadata": {}, 33 | "outputs": [], 34 | "source": [ 35 | "# Initialize tokenizer\n", 36 | "tokenizer = Tokenizer(models.BPE())\n", 37 | "\n", 38 | "# Configure trainer\n", 39 | "trainer = BpeTrainer(vocab_size=256, special_tokens=[])\n", 40 | "\n", 41 | "# Generate actual bytes for training\n", 42 | "byte_data = [bytes([i]) for i in range(256)] # Create actual bytes\n", 43 | "# Convert to strings that preserve the byte values\n", 44 | "byte_strings = [b.decode('latin-1') for b in byte_data] \n", 45 | "\n", 46 | "# Train the tokenizer\n", 47 | "tokenizer.train_from_iterator(byte_strings, trainer=trainer)\n", 48 | "# tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)\n", 49 | "tokenizer.pre_tokenizer = None\n", 50 | "tokenizer.normalizer = None\n", 51 | "tokenizer.decoder = decoders.ByteLevel()\n", 52 | "\n", 53 | "# Check the result\n", 54 | "print(tokenizer.get_vocab()) # Should show all 256 bytes + special tokens" 55 | ] 56 | }, 57 | { 58 | "cell_type": "markdown", 59 | "metadata": {}, 60 | "source": [ 61 | "Let's test it works quickly as a round-trip:" 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": null, 67 | "metadata": {}, 68 | "outputs": [], 69 | "source": [ 70 | "evil_string = \"心\".encode(\"utf-8\").decode(\"latin-1\")\n", 71 | "print(f\"Evil string: {evil_string}\")\n", 72 | "enc = tokenizer.encode(evil_string)\n", 73 | "print(enc.ids)\n", 74 | "decoded_bytes = 
bytes(enc.ids).decode('utf-8')\n", 75 | "decoded_bytes" 76 | ] 77 | }, 78 | { 79 | "cell_type": "markdown", 80 | "metadata": {}, 81 | "source": [ 82 | "## Special tokens" 83 | ] 84 | }, 85 | { 86 | "cell_type": "code", 87 | "execution_count": null, 88 | "metadata": {}, 89 | "outputs": [], 90 | "source": [ 91 | "CODEBOOK_SIZE=2048\n", 92 | "semantic_tokens = [f\"<|semantic:{i}|>\" for i in range(CODEBOOK_SIZE)]\n", 93 | "control_tokens = [\n", 94 | " \"system\", \n", 95 | " \"user\", \n", 96 | " \"assistant\",\n", 97 | " \"<|british|>\",\n", 98 | " \"<|american|>\",\n", 99 | " \"<|male|>\",\n", 100 | " \"<|female|>\",\n", 101 | " \"<|unknown|>\",\n", 102 | " \"<|endoftext|>\", \n", 103 | " \"<|voice|>\", \n", 104 | " \"<|semantic|>\",\n", 105 | " \"<|pad|>\",\n", 106 | " \"<|epad|>\",\n", 107 | " \"<|im_start|>\", \n", 108 | " \"<|im_end|>\", \n", 109 | "]\n", 110 | "# Reserve individual speaker IDs as control tokens\n", 111 | "unused_tokens = [f\"<|speaker:{i}|>\" for i in range(64 - len(control_tokens))]\n", 112 | "charset = [*control_tokens, *unused_tokens, *semantic_tokens]\n", 113 | "print(len(charset))\n", 114 | "charset[:67]\n" 115 | ] 116 | }, 117 | { 118 | "cell_type": "code", 119 | "execution_count": null, 120 | "metadata": {}, 121 | "outputs": [], 122 | "source": [ 123 | "tokenizer.add_special_tokens(charset)\n", 124 | "tokenizer.pad_token = \"<|pad|>\"\n", 125 | "tokenizer.eos_token = \"<|endoftext|>\"\n", 126 | "tokenizer.bos_token = \"<|im_start|>\"\n", 127 | "tokenizer.unk_token = \"<|unknown|>\"\n", 128 | "tokenizer.chat_template = \"{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\\n' }}{% endif %}\"" 129 | ] 130 | }, 131 | { 132 | "cell_type": "code", 133 | "execution_count": null, 134 | "metadata": {}, 135 | "outputs": [], 136 | "source": [ 137 | "from transformers import PreTrainedTokenizerFast\n", 138 | "\n", 139 | "# Create the fast tokenizer with all settings in one shot\n", 140 | "final_tokenizer = PreTrainedTokenizerFast(\n", 141 | " tokenizer_object=tokenizer, # your existing byte-level tokenizer\n", 142 | " bos_token=\"<|im_start|>\",\n", 143 | " eos_token=\"<|endoftext|>\",\n", 144 | " unk_token=\"<|unknown|>\",\n", 145 | " pad_token=\"<|pad|>\",\n", 146 | " chat_template=\"\"\"{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\\n' }}{% endif %}\"\"\"\n", 147 | ")\n", 148 | "\n", 149 | "# Save it\n", 150 | "final_tokenizer.save_pretrained(\"../checkpoints/smoltts_byte_kokoro\")" 151 | ] 152 | }, 153 | { 154 | "cell_type": "markdown", 155 | "metadata": {}, 156 | "source": [ 157 | "Let's give this a final test before we dump compute into it:" 158 | ] 159 | }, 160 | { 161 | "cell_type": "code", 162 | "execution_count": null, 163 | "metadata": {}, 164 | "outputs": [], 165 | "source": [ 166 | "# Test encoding of ASCII + special tokens + semantic tokens\n", 167 | "test_prompt = \"<|im_start|>system\\n<|american|><|female|><|speaker:4|><|im_end|>\\n<|im_start|>user\\nHello!<|im_end|>\\n<|semantic:42|>\"\n", 168 | "\n", 169 | "# Encode and look at IDs\n", 170 | "ids = 
tokenizer.encode(test_prompt.encode(\"utf-8\").decode(\"latin-1\"))\n", 171 | "print(\"Token IDs:\", ids)\n", 172 | "\n", 173 | "# Test decoding individual tokens\n", 174 | "print(\"\\nDecoding each token:\")\n", 175 | "for id in ids.ids:\n", 176 | " if id <= 255:\n", 177 | " print(f\"Byte {id}: {repr(tokenizer.decode([id]))}\")\n", 178 | " else:\n", 179 | " print(f\"Special {id}: {repr(tokenizer.id_to_token(id))}\")\n", 180 | "\n", 181 | "# Verify our semantic token ID maps correctly\n", 182 | "semantic_42 = tokenizer.encode(\"<|semantic:42|>\")\n", 183 | "print(\"\\nSemantic token 42:\", semantic_42.ids)\n", 184 | "print(\"Decodes back to:\", repr(tokenizer.decode(semantic_42.ids)))" 185 | ] 186 | }, 187 | { 188 | "cell_type": "markdown", 189 | "metadata": {}, 190 | "source": [ 191 | "Let's save back the vocab size:" 192 | ] 193 | }, 194 | { 195 | "cell_type": "code", 196 | "execution_count": null, 197 | "metadata": {}, 198 | "outputs": [], 199 | "source": [ 200 | "import json\n", 201 | "\n", 202 | "# Load config\n", 203 | "with open('../checkpoints/smoltts_byte/config.json', 'r') as f:\n", 204 | " config = json.load(f)\n", 205 | "\n", 206 | "# Get vocab size from tokenizer \n", 207 | "vocab_size = 256 + len(charset) # Base bytes + special tokens\n", 208 | "config['vocab_size'] = vocab_size\n", 209 | "\n", 210 | "# Save updated config\n", 211 | "with open('../checkpoints/smoltts_byte_kokoro/config.json', 'w') as f:\n", 212 | " json.dump(config, f, indent=4)\n", 213 | "\n", 214 | "print(f\"Updated vocab_size to {vocab_size}\")" 215 | ] 216 | }, 217 | { 218 | "cell_type": "code", 219 | "execution_count": null, 220 | "metadata": {}, 221 | "outputs": [], 222 | "source": [ 223 | "# Debug space encoding\n", 224 | "print(\"Raw space char code:\", ord(\" \")) # Should be 32\n", 225 | "print(\"Space as bytes:\", \" \".encode('utf-8')) # Should be b' '\n", 226 | "print(\"Space as latin1:\", \" \".encode('latin1')) # Should be b' '\n", 227 | "\n", 228 | "# Test different space characters\n", 229 | "print(\"\\nTokenizer tests:\")\n", 230 | "print('ASCII space (32):', tokenizer.encode(\" \").ids)\n", 231 | "print('NBSP (160):', tokenizer.encode(\"\\u00A0\"))\n", 232 | "print('Raw byte 32:', tokenizer.encode(bytes([32]).decode('latin1')))\n", 233 | "\n", 234 | "# Look at normalizer config\n", 235 | "print(\"\\nTokenizer config:\")\n", 236 | "\n", 237 | "# Try encoding a string with spaces\n", 238 | "print(\"\\nString with spaces:\")\n", 239 | "test = \"a b c\"\n", 240 | "print(\"String:\", repr(test))\n", 241 | "print(\"Encoded:\", tokenizer.encode(test).ids)" 242 | ] 243 | }, 244 | { 245 | "cell_type": "code", 246 | "execution_count": null, 247 | "metadata": {}, 248 | "outputs": [], 249 | "source": [ 250 | "# Compare AutoTokenizer vs PreTrainedTokenizerFast\n", 251 | "from transformers import AutoTokenizer, PreTrainedTokenizerFast\n", 252 | "\n", 253 | "auto = AutoTokenizer.from_pretrained(\"../checkpoints/smoltts_byte\")\n", 254 | "fast = PreTrainedTokenizerFast.from_pretrained(\"../checkpoints/smoltts_byte\")\n", 255 | "\n", 256 | "test = \"a, b\"\n", 257 | "print(\"AutoTokenizer config:\")\n", 258 | "print(\"Type:\", type(auto))\n", 259 | "print(\"Normalizer:\", auto.backend_tokenizer.normalizer)\n", 260 | "print(\"Pre-tokenizer:\", auto.backend_tokenizer.pre_tokenizer)\n", 261 | "print(\"Post-processor:\", auto.backend_tokenizer.post_processor)\n", 262 | "\n", 263 | "print(\"\\nPreTrainedTokenizerFast config:\")\n", 264 | "print(\"Type:\", type(fast))\n", 265 | "print(\"Normalizer:\", 
fast.backend_tokenizer.normalizer)\n", 266 | "print(\"Pre-tokenizer:\", fast.backend_tokenizer.pre_tokenizer)\n", 267 | "print(\"Post-processor:\", fast.backend_tokenizer.post_processor)\n", 268 | "\n", 269 | "print(\"\\nEncoding tests:\")\n", 270 | "print(\"Auto:\", auto.encode(test))\n", 271 | "print(\"Fast:\", fast.encode(test))\n", 272 | "\n", 273 | "# Check what tokenizer_config.json looks like\n", 274 | "import json\n", 275 | "with open(\"../checkpoints/smoltts_byte/tokenizer_config.json\") as f:\n", 276 | " config = json.load(f)\n", 277 | "print(\"\\nTokenizer config file:\")\n", 278 | "print(json.dumps(config, indent=2))" 279 | ] 280 | }, 281 | { 282 | "cell_type": "code", 283 | "execution_count": null, 284 | "metadata": {}, 285 | "outputs": [], 286 | "source": [ 287 | "tokens = [265, 256, 10, 83, 112, 101, 97, 107, 32, 111, 117, 116, 32, 116, 104, 101, 32, 112, 114, 111, 118, 105, 100, 101, 100, 32, 116, 101, 120, 116, 266, 265, 257, 10, 116, 101, 115, 116, 266, 265, 258, 10]\n", 288 | "tokenizer.decode(tokens)" 289 | ] 290 | }, 291 | { 292 | "cell_type": "code", 293 | "execution_count": null, 294 | "metadata": {}, 295 | "outputs": [], 296 | "source": [ 297 | "tokenizer.decode([265])" 298 | ] 299 | } 300 | ], 301 | "metadata": { 302 | "kernelspec": { 303 | "display_name": ".venv", 304 | "language": "python", 305 | "name": "python3" 306 | }, 307 | "language_info": { 308 | "codemirror_mode": { 309 | "name": "ipython", 310 | "version": 3 311 | }, 312 | "file_extension": ".py", 313 | "mimetype": "text/x-python", 314 | "name": "python", 315 | "nbconvert_exporter": "python", 316 | "pygments_lexer": "ipython3", 317 | "version": "3.9.20" 318 | } 319 | }, 320 | "nbformat": 4, 321 | "nbformat_minor": 2 322 | } 323 | -------------------------------------------------------------------------------- /data_pipeline/notebooks/create_smoltts_init.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Create LM initialization for DualAR transformer\n", 8 | "\n", 9 | "As of 2024-12-30 we're using Huggingface [SmolLM2-135M](https://huggingface.co/HuggingFaceTB/SmolLM2-135M-Instruct) for pretrained LM initialization. However, it needs some minor formatting changes to work with the Fish Speech / fish_speech.rs format." 
10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": null, 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "import os\n", 19 | "from transformers import AutoModelForCausalLM, AutoTokenizer\n", 20 | "\n", 21 | "MODEL = \"HuggingFaceTB/SmolLM2-135M-Instruct\"\n", 22 | "checkpoint_dir = \"../checkpoints\"\n", 23 | "os.makedirs(checkpoint_dir, exist_ok=True)\n", 24 | "checkpoint_pretrained_dir = f\"../checkpoints/{MODEL.split('/')[-1]}\"\n", 25 | "os.makedirs(checkpoint_pretrained_dir, exist_ok=True)\n", 26 | "\n", 27 | "# Step (b): Download the HuggingFace model and save to ../checkpoints\n", 28 | "model_name = \"HuggingFaceTB/SmolLM2-135M-Instruct\"\n", 29 | "\n", 30 | "print(\"Downloading model...\")\n", 31 | "# Load the model and tokenizer\n", 32 | "tokenizer = AutoTokenizer.from_pretrained(model_name)\n", 33 | "model = AutoModelForCausalLM.from_pretrained(model_name)" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": null, 39 | "metadata": {}, 40 | "outputs": [], 41 | "source": [ 42 | "print(f\"Saving model to {checkpoint_dir}...\")\n", 43 | "model.save_pretrained(checkpoint_pretrained_dir)\n", 44 | "tokenizer.save_pretrained(checkpoint_pretrained_dir)\n", 45 | "\n", 46 | "print(\"Model downloaded and saved successfully!\")" 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": null, 52 | "metadata": {}, 53 | "outputs": [], 54 | "source": [ 55 | "from safetensors.torch import load_file\n", 56 | "\n", 57 | "tensors = load_file(\"../checkpoints/SmolLM2-135M-Instruct/model.safetensors\")\n", 58 | "list(tensors.keys())" 59 | ] 60 | }, 61 | { 62 | "cell_type": "markdown", 63 | "metadata": {}, 64 | "source": [ 65 | "Unfortunately the [Fish Speech](https://github.com/fishaudio/fish-speech) DualAR backbone has different weight keys despite being vanilla Llama 3 architecture, so we have to rename them:" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": null, 71 | "metadata": {}, 72 | "outputs": [], 73 | "source": [ 74 | "renamed_tensors = {\n", 75 | " key.replace('model.embed_tokens', 'model.embeddings')\n", 76 | " .replace('self_attn', 'attention')\n", 77 | " .replace('post_attention_layernorm', 'attention_norm')\n", 78 | " .replace('input_layernorm', 'ffn_norm')\n", 79 | " .replace('mlp', 'feed_forward')\n", 80 | " .replace('k_proj', 'wk')\n", 81 | " .replace('q_proj', 'wq')\n", 82 | " .replace('v_proj', 'wv')\n", 83 | " .replace('o_proj', 'wo')\n", 84 | " .replace('gate_proj', 'w1')\n", 85 | " .replace('down_proj', 'w2')\n", 86 | " .replace('up_proj', 'w3')\n", 87 | " .split('model.')[1]: tensor \n", 88 | " for key, tensor in tensors.items()\n", 89 | "}\n", 90 | "list(renamed_tensors.keys())" 91 | ] 92 | }, 93 | { 94 | "cell_type": "markdown", 95 | "metadata": {}, 96 | "source": [ 97 | "Following existing literature, we initialize the semantic codebook embedding embeddings from the mean of the existing token embeddings, to lower the initial loss from random init. Empirically this lowers base loss from 140 to 25 at the beginning of training, which though still far above `ln(52000)=10` for the base is good enough." 
98 | ] 99 | }, 100 | { 101 | "cell_type": "code", 102 | "execution_count": null, 103 | "metadata": {}, 104 | "outputs": [], 105 | "source": [ 106 | "import torch\n", 107 | "\n", 108 | "new_tokens = renamed_tensors['embeddings.weight'].mean(dim=0, keepdim=True).repeat(2048, 1)\n", 109 | "# nn.Embedding(2048, 576)\n", 110 | "extended_embeddings = torch.cat([\n", 111 | " renamed_tensors['embeddings.weight'],\n", 112 | " new_tokens\n", 113 | "], dim=0)\n", 114 | "\n", 115 | "renamed_tensors['embeddings.weight'] = extended_embeddings\n", 116 | "renamed_tensors['embeddings.weight'].shape" 117 | ] 118 | }, 119 | { 120 | "cell_type": "code", 121 | "execution_count": null, 122 | "metadata": {}, 123 | "outputs": [], 124 | "source": [ 125 | "import torch\n", 126 | "import shutil\n", 127 | "from pathlib import Path\n", 128 | "\n", 129 | "source_dir = Path(checkpoint_pretrained_dir)\n", 130 | "dest_dir = Path(\"../checkpoints/smoltts_init\")\n", 131 | "\n", 132 | "os.makedirs(dest_dir, exist_ok=True)\n", 133 | "torch.save(renamed_tensors, dest_dir / \"model.pth\")\n", 134 | "\n", 135 | "\n", 136 | "# Ensure the destination directory exists\n", 137 | "dest_dir.mkdir(parents=True, exist_ok=True)\n", 138 | "\n", 139 | "# Copy all .json and .txt files\n", 140 | "for extension in (\"*.json\", \"*.txt\"):\n", 141 | " for file in source_dir.glob(extension):\n", 142 | " shutil.copy(file, dest_dir)" 143 | ] 144 | }, 145 | { 146 | "cell_type": "markdown", 147 | "metadata": {}, 148 | "source": [ 149 | "Fish Speech uses a different config format than HF Transformers, so I'm going to define it by fiat here." 150 | ] 151 | }, 152 | { 153 | "cell_type": "code", 154 | "execution_count": null, 155 | "metadata": {}, 156 | "outputs": [], 157 | "source": [ 158 | "import json\n", 159 | "\n", 160 | "with open(dest_dir / \"config.json\") as f:\n", 161 | " hf_config = json.load(f)\n", 162 | "\n", 163 | "# Mimi codebook dimension\n", 164 | "CODEBOOK_SIZE = 2048\n", 165 | "\n", 166 | "config = {\n", 167 | " \"attention_qkv_bias\": False,\n", 168 | " \"codebook_size\": CODEBOOK_SIZE,\n", 169 | " \"dim\": hf_config[\"hidden_size\"],\n", 170 | " \"dropout\": 0.1,\n", 171 | " \"fast_attention_qkv_bias\": False,\n", 172 | " # TODO: Following Fish Speech, keeping fast layer dimensions the same for now. 
May revisit this later\n", 173 | " \"fast_dim\": hf_config[\"hidden_size\"],\n", 174 | " \"fast_head_dim\": hf_config[\"head_dim\"],\n", 175 | " \"fast_intermediate_size\": hf_config[\"intermediate_size\"],\n", 176 | " \"fast_n_head\": hf_config[\"num_attention_heads\"],\n", 177 | " \"fast_n_local_heads\": hf_config[\"num_key_value_heads\"],\n", 178 | " \"head_dim\": hf_config[\"head_dim\"],\n", 179 | " \"initializer_range\": hf_config[\"initializer_range\"],\n", 180 | " \"intermediate_size\": hf_config[\"intermediate_size\"],\n", 181 | " \"is_reward_model\": False,\n", 182 | " \"max_seq_len\": hf_config[\"max_position_embeddings\"],\n", 183 | " \"model_type\": \"dual_ar\",\n", 184 | " # TODO: Following Fish Speech for now\n", 185 | " \"n_fast_layer\": 4,\n", 186 | " \"n_head\": hf_config[\"num_attention_heads\"],\n", 187 | " \"n_local_heads\": hf_config[\"num_key_value_heads\"],\n", 188 | " \"norm_eps\": hf_config[\"rms_norm_eps\"],\n", 189 | " # Mimi\n", 190 | " \"num_codebooks\": 8,\n", 191 | " \"rope_base\": hf_config[\"rope_theta\"],\n", 192 | " \"scale_codebook_embeddings\": False,\n", 193 | " \"share_codebook_embeddings\": True,\n", 194 | " \"tie_word_embeddings\": hf_config[\"tie_word_embeddings\"],\n", 195 | " \"use_gradient_checkpointing\": True,\n", 196 | " # TODO: handle control tokens\n", 197 | " \"vocab_size\": hf_config[\"vocab_size\"] + CODEBOOK_SIZE\n", 198 | "}\n", 199 | "\n", 200 | "config" 201 | ] 202 | }, 203 | { 204 | "cell_type": "code", 205 | "execution_count": null, 206 | "metadata": {}, 207 | "outputs": [], 208 | "source": [ 209 | "output_path = dest_dir / \"config.json\"\n", 210 | "with output_path.open('w') as f:\n", 211 | " json.dump(config, f, indent=2)" 212 | ] 213 | }, 214 | { 215 | "cell_type": "markdown", 216 | "metadata": {}, 217 | "source": [ 218 | "Our model now must:\n", 219 | "- Randomly initialize the fast transformer\n", 220 | "- Merge attention qkv into a single tensor (to save on kernel launch overhead and improve hardware utilization) \n", 221 | "\n", 222 | "The DualARTransformer modeling code will do this, but we need to load the model once.\n", 223 | "\n", 224 | "TODO: find more principled initialization strategies!" 225 | ] 226 | }, 227 | { 228 | "cell_type": "code", 229 | "execution_count": null, 230 | "metadata": {}, 231 | "outputs": [], 232 | "source": [ 233 | "from dual_ar.model.dual_ar import DualARTransformer\n", 234 | "\n", 235 | "model = DualARTransformer.from_pretrained(\n", 236 | " path=\"../checkpoints/smoltts_init\",\n", 237 | " load_weights=True\n", 238 | ")" 239 | ] 240 | }, 241 | { 242 | "cell_type": "code", 243 | "execution_count": null, 244 | "metadata": {}, 245 | "outputs": [], 246 | "source": [ 247 | "state_dict = model.state_dict()\n", 248 | "torch.save(state_dict, dest_dir / \"model.pth\")" 249 | ] 250 | }, 251 | { 252 | "cell_type": "markdown", 253 | "metadata": {}, 254 | "source": [ 255 | "We're now done with modeling code. Now we need to extend the tokenizer to handle semantic tokens.\n", 256 | "\n", 257 | "TODO: Add control / modality tokens, PAD / EPAD and do ablations!" 
258 | ] 259 | }, 260 | { 261 | "cell_type": "code", 262 | "execution_count": null, 263 | "metadata": {}, 264 | "outputs": [], 265 | "source": [ 266 | "def make_tokenizer():\n", 267 | " tokenizer = AutoTokenizer.from_pretrained(MODEL, use_system_prompt=False)\n", 268 | " semantic_tokens = [f\"<|semantic:{i}|>\" for i in range(0, CODEBOOK_SIZE)]\n", 269 | " additional_special_tokens = [*semantic_tokens]\n", 270 | " tokenizer.add_special_tokens({\n", 271 | " \"additional_special_tokens\": additional_special_tokens\n", 272 | " })\n", 273 | " # Remove inane overly clever chat template\n", 274 | " if MODEL == \"HuggingFaceTB/SmolLM2-135M-Instruct\":\n", 275 | " tokenizer.chat_template = \"{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\\n' }}{% endif %}\"\n", 276 | " \n", 277 | " tokenizer.save_pretrained(dest_dir)\n", 278 | "\n", 279 | "make_tokenizer()" 280 | ] 281 | }, 282 | { 283 | "cell_type": "markdown", 284 | "metadata": {}, 285 | "source": [ 286 | "All done!" 287 | ] 288 | }, 289 | { 290 | "cell_type": "markdown", 291 | "metadata": {}, 292 | "source": [ 293 | "## Optional: test model works" 294 | ] 295 | }, 296 | { 297 | "cell_type": "code", 298 | "execution_count": null, 299 | "metadata": {}, 300 | "outputs": [], 301 | "source": [ 302 | "import torch\n", 303 | "\n", 304 | "device = \"cuda\"\n", 305 | "model = model.to(device)\n", 306 | "model = model.to(torch.bfloat16)" 307 | ] 308 | }, 309 | { 310 | "cell_type": "code", 311 | "execution_count": null, 312 | "metadata": {}, 313 | "outputs": [], 314 | "source": [ 315 | "import os\n", 316 | "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"true\"\n", 317 | "model = torch.compile(model)" 318 | ] 319 | }, 320 | { 321 | "cell_type": "code", 322 | "execution_count": null, 323 | "metadata": {}, 324 | "outputs": [], 325 | "source": [ 326 | "tensor = torch.zeros(1, 9, 1, dtype=torch.int32).to(\"cuda\")\n", 327 | "with torch.no_grad():\n", 328 | " out = model.forward(tensor, None)\n", 329 | " print(out.token_logits)" 330 | ] 331 | } 332 | ], 333 | "metadata": { 334 | "kernelspec": { 335 | "display_name": ".venv", 336 | "language": "python", 337 | "name": "python3" 338 | }, 339 | "language_info": { 340 | "codemirror_mode": { 341 | "name": "ipython", 342 | "version": 3 343 | }, 344 | "file_extension": ".py", 345 | "mimetype": "text/x-python", 346 | "name": "python", 347 | "nbconvert_exporter": "python", 348 | "pygments_lexer": "ipython3", 349 | "version": "3.9.6" 350 | } 351 | }, 352 | "nbformat": 4, 353 | "nbformat_minor": 2 354 | } 355 | -------------------------------------------------------------------------------- /data_pipeline/notebooks/decode_audio.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from transformers import MimiModel, AutoFeatureExtractor\n", 10 | "\n", 11 | "device = \"cpu\"\n", 12 | "feature_extractor = AutoFeatureExtractor.from_pretrained(\"kyutai/mimi\")\n", 13 | "model = MimiModel.from_pretrained(\"kyutai/mimi\")\n", 14 | "model = model.to(device)" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": null, 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "import torchaudio\n", 24 | "import 
torchaudio.transforms as T\n", 25 | "\n", 26 | "def load_and_process_wav(file_path):\n", 27 | " \"\"\"\n", 28 | " Load a WAV file, convert it to mono, resample it to 24kHz, and return as a tensor.\n", 29 | "\n", 30 | " Parameters:\n", 31 | " file_path (str): Path to the WAV file.\n", 32 | "\n", 33 | " Returns:\n", 34 | " torch.Tensor: Processed audio tensor.\n", 35 | " \"\"\"\n", 36 | " # Load the audio file\n", 37 | " waveform, sample_rate = torchaudio.load(file_path)\n", 38 | "\n", 39 | " # Convert to mono if not already\n", 40 | " if waveform.size(0) > 1:\n", 41 | " waveform = torch.mean(waveform, dim=0, keepdim=True)\n", 42 | "\n", 43 | " # Resample to 24kHz if needed\n", 44 | " target_sample_rate = 24000\n", 45 | " if sample_rate != target_sample_rate:\n", 46 | " resampler = T.Resample(orig_freq=sample_rate, new_freq=target_sample_rate)\n", 47 | " waveform = resampler(waveform)\n", 48 | "\n", 49 | " return waveform" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": null, 55 | "metadata": {}, 56 | "outputs": [], 57 | "source": [ 58 | "import os\n", 59 | "\n", 60 | "def run_llama_generate(\n", 61 | " text=\"Can you generate five simple sentences for my child to practice speaking\",\n", 62 | " temp=0.1,\n", 63 | " checkpoint_path=\"../dual-ar/checkpoints/smoltts_scratch/\",\n", 64 | " working_dir=\"../../fish-speech.rs\" # Replace with your desired working directory\n", 65 | "):\n", 66 | " # Store current working directory\n", 67 | " original_dir = os.getcwd()\n", 68 | " \n", 69 | " try:\n", 70 | " # Change to desired working directory\n", 71 | " os.chdir(working_dir)\n", 72 | " \n", 73 | " # Construct the command\n", 74 | " cmd = f'cargo run --release --features cuda --bin llama_generate -- '\\\n", 75 | " f'--text \"{text}\" '\\\n", 76 | " f'--checkpoint {checkpoint_path} '\\\n", 77 | " f'--temp {temp}'\n", 78 | " \n", 79 | " # Execute command\n", 80 | " return os.system(cmd)\n", 81 | " \n", 82 | " finally:\n", 83 | " # Always return to original directory\n", 84 | " os.chdir(original_dir)\n", 85 | "\n", 86 | "# Example usage:\n", 87 | "# run_llama_generate(\n", 88 | "# text=\"Write a short story about a cat\",\n", 89 | "# temp=0.2,\n", 90 | "# working_dir=\"/path/to/your/project\"\n", 91 | "# )" 92 | ] 93 | }, 94 | { 95 | "cell_type": "code", 96 | "execution_count": null, 97 | "metadata": {}, 98 | "outputs": [], 99 | "source": [ 100 | "import numpy as np\n", 101 | "import torch\n", 102 | "from IPython.display import Audio, display\n", 103 | "\n", 104 | "# run_llama_generate(\n", 105 | "# text=\"Here's how Bob talks, here's what language is, now speak like Bob saying this new thing\",\n", 106 | "# temp=0.05\n", 107 | "# )\n", 108 | "# Load and process the data\n", 109 | "test_arr = np.load(\"../../out.npy\")\n", 110 | "test_input = torch.from_numpy(test_arr[:,:200]).to(device).to(torch.long)\n", 111 | "print(test_input.shape)\n", 112 | "\n", 113 | "# Generate audio\n", 114 | "out_pcm = model.decode(test_input)\n", 115 | "\n", 116 | "# Convert to CPU and get numpy array for playback\n", 117 | "audio_data = out_pcm.audio_values[0].detach().to(\"cpu\").numpy()\n", 118 | "\n", 119 | "# Create and display audio widget\n", 120 | "# Note: sample_rate=24000 matches your original save command\n", 121 | "display(Audio(audio_data, rate=24000, autoplay=False))" 122 | ] 123 | }, 124 | { 125 | "cell_type": "code", 126 | "execution_count": null, 127 | "metadata": {}, 128 | "outputs": [], 129 | "source": [ 130 | "test_input[0, 0, :]" 131 | ] 132 | }, 133 | { 134 | "cell_type": 
"code", 135 | "execution_count": null, 136 | "metadata": {}, 137 | "outputs": [], 138 | "source": [ 139 | "import numpy as np\n", 140 | "\n", 141 | "pcm = load_and_process_wav(\"../../fish-speech.rs/voices/nova.wav\")\n", 142 | "codes = model.encode(pcm.to(\"cuda\").unsqueeze(0))\n", 143 | "np.save(\"nova.npy\", codes[\"audio_codes\"].squeeze(0)[:8, :].cpu().numpy())" 144 | ] 145 | }, 146 | { 147 | "cell_type": "code", 148 | "execution_count": null, 149 | "metadata": {}, 150 | "outputs": [], 151 | "source": [ 152 | "codes[\"audio_codes\"].squeeze(0)[:8,:].shape" 153 | ] 154 | } 155 | ], 156 | "metadata": { 157 | "kernelspec": { 158 | "display_name": ".venv", 159 | "language": "python", 160 | "name": "python3" 161 | }, 162 | "language_info": { 163 | "codemirror_mode": { 164 | "name": "ipython", 165 | "version": 3 166 | }, 167 | "file_extension": ".py", 168 | "mimetype": "text/x-python", 169 | "name": "python", 170 | "nbconvert_exporter": "python", 171 | "pygments_lexer": "ipython3", 172 | "version": "3.9.6" 173 | } 174 | }, 175 | "nbformat": 4, 176 | "nbformat_minor": 4 177 | } 178 | -------------------------------------------------------------------------------- /data_pipeline/notebooks/encode_ljspeech.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from datasets import load_dataset, Audio\n", 10 | "\n", 11 | "SAMPLING_RATE=24_000\n", 12 | "# Load the LJ Speech dataset\n", 13 | "dataset = load_dataset(\"MikhailT/lj-speech\")\n", 14 | "dataset = dataset.cast_column(\"audio\", Audio(sampling_rate=SAMPLING_RATE))\n", 15 | "dataset = dataset.with_format(\"torch\")\n", 16 | "len(dataset[\"full\"])" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": null, 22 | "metadata": {}, 23 | "outputs": [], 24 | "source": [ 25 | "from data_pipeline.utils.codec import MimiCodec\n", 26 | "\n", 27 | "mimi_model = MimiCodec()" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": null, 33 | "metadata": {}, 34 | "outputs": [], 35 | "source": [ 36 | "first_items = dataset[\"full\"][0:16]\n", 37 | "audios = [row[\"array\"] for row in first_items[\"audio\"]]\n", 38 | "wavs = mimi_model.encode_batch(audios)" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": null, 44 | "metadata": {}, 45 | "outputs": [], 46 | "source": [ 47 | "[l.shape for l in wavs]" 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "execution_count": null, 53 | "metadata": {}, 54 | "outputs": [], 55 | "source": [ 56 | "from IPython import display\n", 57 | "\n", 58 | "out_pcm = mimi_model.decode(wavs[15])\n", 59 | "\n", 60 | "# Convert to CPU and get numpy array for playback\n", 61 | "audio_data = out_pcm.numpy()\n", 62 | "\n", 63 | "# Create and display audio widget\n", 64 | "# Note: sample_rate=24000 matches your original save command\n", 65 | "display.display(display.Audio(audio_data, rate=24000, autoplay=False))" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": null, 71 | "metadata": {}, 72 | "outputs": [], 73 | "source": [ 74 | "def batch_wav_encoder(batch):\n", 75 | " audios = [audio[\"array\"] for audio in batch[\"audio\"]]\n", 76 | " return {\n", 77 | " \"codes\": mimi_model.encode_batch(audios)\n", 78 | " }\n", 79 | "\n", 80 | "dataset = dataset.map(batch_wav_encoder, batched=True, batch_size=24, remove_columns=[\"audio\"])" 81 | ] 82 | }, 83 | { 84 | "cell_type": "code", 85 | 
"execution_count": null, 86 | "metadata": {}, 87 | "outputs": [], 88 | "source": [ 89 | "dataset.column_names" 90 | ] 91 | }, 92 | { 93 | "cell_type": "code", 94 | "execution_count": null, 95 | "metadata": {}, 96 | "outputs": [], 97 | "source": [ 98 | "dataset = dataset.rename_column(original_column_name=\"normalized_text\", new_column_name=\"text_normalized\")\n", 99 | "dataset.save_to_disk(\"../../datasets/encoded_ljspeech\")" 100 | ] 101 | } 102 | ], 103 | "metadata": { 104 | "kernelspec": { 105 | "display_name": ".venv", 106 | "language": "python", 107 | "name": "python3" 108 | }, 109 | "language_info": { 110 | "codemirror_mode": { 111 | "name": "ipython", 112 | "version": 3 113 | }, 114 | "file_extension": ".py", 115 | "mimetype": "text/x-python", 116 | "name": "python", 117 | "nbconvert_exporter": "python", 118 | "pygments_lexer": "ipython3", 119 | "version": "3.9.21" 120 | } 121 | }, 122 | "nbformat": 4, 123 | "nbformat_minor": 4 124 | } 125 | -------------------------------------------------------------------------------- /data_pipeline/notebooks/test_emilia.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from datasets import load_dataset, Dataset 3 | from data_pipeline.utils.codec import MimiCodec 4 | from dotenv import load_dotenv 5 | import os 6 | import shutil 7 | import torch 8 | from itertools import islice 9 | from torchaudio.transforms import Resample 10 | 11 | DATASET_SAMPLING_RATE = 24_000 12 | CHUNK_SIZE = 5 # Tweak based on RAM 13 | 14 | downsample_16k = Resample(orig_freq=DATASET_SAMPLING_RATE) 15 | 16 | 17 | def chunked(iterable, size): 18 | """Yield successive chunks from iterable of given size.""" 19 | it = iter(iterable) 20 | while chunk := list(islice(it, size)): # Python 3.8+ (walrus op) 21 | yield chunk 22 | 23 | 24 | def parse_args(): 25 | parser = argparse.ArgumentParser(description="Shard processing script") 26 | parser.add_argument( 27 | "--num-shards", 28 | type=int, 29 | default=100, 30 | help="Total number of shards (default: 100)", 31 | ) 32 | parser.add_argument( 33 | "--skip-shards", 34 | type=int, 35 | default=0, 36 | help="Number of shards to skip (default: 0)", 37 | ) 38 | return parser.parse_args() 39 | 40 | 41 | def main(): 42 | load_dotenv() 43 | args = parse_args() 44 | 45 | codec = MimiCodec() 46 | dataset_dir_base = os.path.expanduser("~/local_datasets/emilia_chunks") 47 | 48 | os.makedirs(dataset_dir_base, exist_ok=True) 49 | NUM_SHARDS = args.num_shards 50 | # NUM_SHARDS = 1 51 | SKIP_SHARDS = args.skip_shards 52 | 53 | for idx, chunk in enumerate( 54 | chunked(range(SKIP_SHARDS, NUM_SHARDS + SKIP_SHARDS), CHUNK_SIZE) 55 | ): 56 | print( 57 | f"\n🟢 Processing chunk {idx + 1}/{(NUM_SHARDS // CHUNK_SIZE) + 1}: {chunk}" 58 | ) 59 | 60 | paths = [f"Emilia/EN/EN-B00{i:04d}.tar" for i in chunk] 61 | 62 | print(f"📥 Downloading {len(paths)} files...") 63 | dataset = load_dataset( 64 | "amphion/Emilia-Dataset", 65 | data_files=paths, 66 | split="train", 67 | token=os.getenv("HUGGINGFACE_TOKEN"), 68 | ) 69 | dataset = dataset.with_format("pt") 70 | # dataset = dataset.take(500) 71 | 72 | def encode_batch(batch): 73 | audio = [a["array"] for a in batch["mp3"]] 74 | encoded = codec.encode_batch(audio) 75 | return {"codes": encoded} 76 | 77 | # Process & Save 78 | dataset = dataset.map( 79 | encode_batch, batched=True, batch_size=24, remove_columns=["mp3"] 80 | ) 81 | save_path = os.path.join( 82 | dataset_dir_base, f"shard_{chunk[0]:04d}_{chunk[-1]:04d}" 83 | ) 84 | 
dataset.save_to_disk(save_path) 85 | print(f"💾 Saved chunk {idx + 1} to {save_path}") 86 | 87 | # Nuke cache 88 | dataset_dir = os.path.expanduser( 89 | "~/.cache/huggingface/datasets/amphion___emilia-dataset" 90 | ) 91 | shutil.rmtree(dataset_dir, ignore_errors=True) 92 | hub_dir = os.path.expanduser( 93 | "~/.cache/huggingface/hub/datasets--amphion--Emilia-Dataset" 94 | ) 95 | shutil.rmtree(hub_dir, ignore_errors=True) 96 | print(f"🔥 Cleared cache for chunk {idx + 1}") 97 | 98 | print("✅ ALL CHUNKS PROCESSED.") 99 | 100 | 101 | if __name__ == "__main__": 102 | main() 103 | -------------------------------------------------------------------------------- /data_pipeline/notebooks/test_tokenization.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from datasets import load_from_disk\n", 10 | "from transformers import AutoTokenizer\n", 11 | "\n", 12 | "ds = load_from_disk(\"../../datasets/tokenized_project_gutenberg_bytes_kokoro\")" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": null, 18 | "metadata": {}, 19 | "outputs": [], 20 | "source": [ 21 | "tokenizer = AutoTokenizer.from_pretrained(\"../../inits/smoltts_byte_kokoro\")\n", 22 | "ds[\"train\"][40]" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": null, 28 | "metadata": {}, 29 | "outputs": [], 30 | "source": [ 31 | "tokenizer.decode(ds[\"train\"][40][\"tokens\"][0,:])" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": null, 37 | "metadata": {}, 38 | "outputs": [], 39 | "source": [ 40 | "ds[\"train\"][48][\"tokens\"].shape" 41 | ] 42 | } 43 | ], 44 | "metadata": { 45 | "kernelspec": { 46 | "display_name": ".venv", 47 | "language": "python", 48 | "name": "python3" 49 | }, 50 | "language_info": { 51 | "codemirror_mode": { 52 | "name": "ipython", 53 | "version": 3 54 | }, 55 | "file_extension": ".py", 56 | "mimetype": "text/x-python", 57 | "name": "python", 58 | "nbconvert_exporter": "python", 59 | "pygments_lexer": "ipython3", 60 | "version": "3.9.6" 61 | } 62 | }, 63 | "nbformat": 4, 64 | "nbformat_minor": 2 65 | } 66 | -------------------------------------------------------------------------------- /data_pipeline/notebooks/upload_libritts.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from datasets import load_from_disk, concatenate_datasets\n", 10 | "from transformers import MimiModel, AutoFeatureExtractor\n", 11 | "\n", 12 | "feature_extractor = AutoFeatureExtractor.from_pretrained(\"kyutai/mimi\")\n", 13 | "model = MimiModel.from_pretrained(\"kyutai/mimi\")\n", 14 | "model = model.to(\"cuda\")" 15 | ] 16 | }, 17 | { 18 | "cell_type": "markdown", 19 | "metadata": {}, 20 | "source": [ 21 | "NOTE: dataset creation script is elsewhere.\n", 22 | "\n", 23 | "If you are Jacob Keisling, please create a file `.env` in this repo and put your HuggingFace token under the `HUGGINGFACE_TOKEN=sk-xxxx` variable. Since you are probably not Jacob Keisling, if you want to re-upload, please pick a repo you actually own." 
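For reference, the `.env` file is just one line in this repo (an illustration; substitute your own token value):

```bash
# .env
HUGGINGFACE_TOKEN=sk-xxxx
```

The `load_dotenv()` call in the next cell reads it, and the token is passed to `push_to_hub` via `os.getenv("HUGGINGFACE_TOKEN")`.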
24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": null, 29 | "metadata": {}, 30 | "outputs": [], 31 | "source": [ 32 | "from datasets import load_from_disk\n", 33 | "from dotenv import load_dotenv\n", 34 | "import os\n", 35 | "load_dotenv()\n", 36 | "\n", 37 | "# If you created this from scratch\n", 38 | "dataset = load_from_disk(\"./encoded_libritts\")\n", 39 | "dataset = dataset.with_format(\"torch\")\n", 40 | "dataset = dataset.rename_column(original_column_name=\"sentences\", new_column_name=\"text_normalized\")\n", 41 | "# Uncomment if this is your first time pushing your new dataset\n", 42 | "dataset.push_to_hub(\"jkeisling/libritts-r-mimi\", token=os.getenv(\"HUGGINGFACE_TOKEN\"))" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": null, 48 | "metadata": {}, 49 | "outputs": [], 50 | "source": [ 51 | "import torch\n", 52 | "import torchaudio\n", 53 | "\n", 54 | "codes = dataset['dev.clean'][0]['codes']\n", 55 | "codes" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": null, 61 | "metadata": {}, 62 | "outputs": [], 63 | "source": [ 64 | "codes_input = codes.to('cuda').to(torch.long).unsqueeze(0)\n", 65 | "out_pcm = model.decode(codes_input)\n", 66 | "torchaudio.save(\"out.wav\", out_pcm.audio_values[0].to(\"cpu\"), 24000)" 67 | ] 68 | } 69 | ], 70 | "metadata": { 71 | "kernelspec": { 72 | "display_name": ".venv", 73 | "language": "python", 74 | "name": "python3" 75 | }, 76 | "language_info": { 77 | "codemirror_mode": { 78 | "name": "ipython", 79 | "version": 3 80 | }, 81 | "file_extension": ".py", 82 | "mimetype": "text/x-python", 83 | "name": "python", 84 | "nbconvert_exporter": "python", 85 | "pygments_lexer": "ipython3", 86 | "version": "3.9.6" 87 | } 88 | }, 89 | "nbformat": 4, 90 | "nbformat_minor": 2 91 | } 92 | -------------------------------------------------------------------------------- /data_pipeline/preview/app.py: -------------------------------------------------------------------------------- 1 | from datasets import load_dataset, load_from_disk 2 | from data_pipeline.utils.codec import MimiCodec 3 | import gradio as gr 4 | import json 5 | import os 6 | from pathlib import Path 7 | from pydantic import BaseModel 8 | import torch 9 | from typing import List, Tuple 10 | 11 | 12 | class DatasetConfig(BaseModel): 13 | dataset_path: Path 14 | sample_rate: int = 16000 15 | batch_size: int = 5 16 | split: str 17 | 18 | 19 | def load_config(config_path: str = "decoder_config.json") -> DatasetConfig: 20 | try: 21 | with open(config_path, "r") as f: 22 | config_data = json.load(f) 23 | # Brute force resolve relative path from script location 24 | config_data["dataset_path"] = os.path.abspath( 25 | os.path.join(os.path.dirname(__file__), config_data["dataset_path"]) 26 | ) 27 | return DatasetConfig(**config_data) 28 | except FileNotFoundError: 29 | config = DatasetConfig( 30 | dataset_path=Path("./data/my_dataset.hf"), 31 | sample_rate=24000, 32 | batch_size=5, 33 | split="full", 34 | ) 35 | with open(config_path, "w") as f: 36 | json.dump(config.model_dump(), f, indent=2) 37 | print(f"Created default config at {config_path}") 38 | return config 39 | 40 | 41 | # Rest of your gradio app 42 | # Load config at startup 43 | CONFIG = load_config() 44 | 45 | # # Load dataset in streaming mode at startup 46 | dataset = load_from_disk( 47 | dataset_path=str(CONFIG.dataset_path / CONFIG.split), 48 | keep_in_memory=False, 49 | ).shuffle() 50 | if "sentences" in dataset.column_names: 51 | dataset = 
dataset.rename_column("sentences", "text_normalized") 52 | dataset_iter = iter(dataset) 53 | mimi_model = MimiCodec() 54 | 55 | 56 | def get_random_samples(n: int = 5) -> List[Tuple[str, torch.Tensor]]: 57 | """Grab n random samples from our dataset""" 58 | n = n or CONFIG.batch_size 59 | samples = [] 60 | for _ in range(n): 61 | sample = next(dataset_iter) 62 | samples.append((sample["text_normalized"], torch.tensor(sample["codes"]))) 63 | return samples 64 | 65 | 66 | def decode_and_display(): 67 | """Get random samples and decode them""" 68 | samples = get_random_samples() 69 | outputs = [] 70 | 71 | for text, codes in samples: 72 | # YOUR DECODE LOGIC HERE 73 | # Assuming codes is [8, seqlen] tensor of hierarchical codes 74 | # Return whatever format Gradio Audio widget expects 75 | 76 | # Placeholder for your decode logic 77 | fake_audio = mimi_model.decode(codes).squeeze(0) 78 | 79 | outputs.extend( 80 | [ 81 | text, 82 | gr.Audio( 83 | value=(CONFIG.sample_rate, fake_audio.numpy()), 84 | type="numpy", 85 | ), 86 | ] 87 | ) 88 | 89 | return outputs 90 | 91 | 92 | with gr.Blocks() as demo: 93 | gr.Markdown( 94 | f""" 95 | ## Quick Mimi Validation 96 | Loading from: {CONFIG.dataset_path} 97 | """ 98 | ) 99 | 100 | with gr.Row(): 101 | roll_btn = gr.Button("🎲 Roll Random Samples", variant="primary") 102 | 103 | # Create N rows of Text + Audio pairs 104 | display_rows = [] 105 | for i in range(CONFIG.batch_size): 106 | with gr.Row(): 107 | text = gr.Textbox(label=f"Text {i+1}", lines=2) 108 | audio = gr.Audio(label=f"Audio {i+1}") 109 | display_rows.extend([text, audio]) 110 | 111 | roll_btn.click(fn=decode_and_display, outputs=display_rows) 112 | 113 | # Launch with a decent sized queue for rapid checking 114 | # demo.queue(max_size=10).launch() 115 | demo.launch(server_name="0.0.0.0") 116 | -------------------------------------------------------------------------------- /data_pipeline/preview/decoder_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_path": "../../../Kokoro-82M/project_gutenberg_mimi_kokoro", 3 | "sample_rate": 24000, 4 | "batch_size": 5, 5 | "split": "." 
6 | } -------------------------------------------------------------------------------- /data_pipeline/scripts/audio_tokenizer_configs/emilia_v1.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_id": "jkeisling/emilia_en_mimi", 3 | "speaker": { 4 | "strategy": "omit" 5 | }, 6 | "tokenization": { 7 | "tokenizer_path": "inits/smoltts_byte_kokoro", 8 | "strategy": "bytelevel", 9 | "duplicate_code_0": false 10 | }, 11 | "audio": { 12 | "frame_rate": 12.5, 13 | "max_sample_secs": 17.5 14 | }, 15 | "packing": { 16 | "max_sequence_length": 640, 17 | "max_items_per_pack": 4, 18 | "window_size": 1600 19 | } 20 | } -------------------------------------------------------------------------------- /data_pipeline/scripts/audio_tokenizer_configs/project_gutenberg_v1.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_id": "jkeisling/project-gutenberg-kokoro-2K", 3 | "speaker": { 4 | "strategy": "id_token", 5 | "speaker_names": [ 6 | "default", 7 | "sarah", 8 | "sky", 9 | "adam", 10 | "emma", 11 | "isabella", 12 | "george", 13 | "lewis" 14 | ] 15 | }, 16 | "tokenization": { 17 | "tokenizer_path": "inits/smoltts_byte_kokoro", 18 | "strategy": "bytelevel" 19 | }, 20 | "audio": { 21 | "frame_rate": 12.5, 22 | "max_sample_secs": 15.0 23 | } 24 | } -------------------------------------------------------------------------------- /data_pipeline/scripts/audio_tokenizer_configs/project_gutenberg_v2.1.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_id": "jkeisling/projectgutenberg-kokoro_v1-mimi", 3 | "speaker": { 4 | "strategy": "id_token", 5 | "speaker_names": [ 6 | "af_heart", 7 | "af_bella", 8 | "af_nova", 9 | "af_sky", 10 | "af_sarah", 11 | "am_michael", 12 | "am_puck", 13 | "am_liam", 14 | "bf_emma", 15 | "bf_isabella", 16 | "bm_fable" 17 | ] 18 | }, 19 | "tokenization": { 20 | "tokenizer_path": "inits/smoltts_byte_kokoro", 21 | "strategy": "bytelevel", 22 | "duplicate_code_0": false 23 | }, 24 | "audio": { 25 | "frame_rate": 12.5, 26 | "max_sample_secs": 20.0 27 | }, 28 | "packing": { 29 | "max_sequence_length": 768, 30 | "max_items_per_pack": 5, 31 | "window_size": 1600 32 | } 33 | } -------------------------------------------------------------------------------- /data_pipeline/scripts/audio_tokenizer_configs/project_gutenberg_v2.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_id": "jkeisling/projectgutenberg-kokoro_v1-mimi", 3 | "speaker": { 4 | "strategy": "id_token", 5 | "speaker_names": [ 6 | "af_heart", 7 | "af_bella", 8 | "af_nova", 9 | "af_sky", 10 | "af_sarah", 11 | "am_michael", 12 | "am_puck", 13 | "am_liam", 14 | "bf_emma", 15 | "bf_isabella", 16 | "bm_fable" 17 | ] 18 | }, 19 | "tokenization": { 20 | "tokenizer_path": "inits/smoltts_byte_kokoro", 21 | "strategy": "bytelevel" 22 | }, 23 | "audio": { 24 | "frame_rate": 12.5, 25 | "max_sample_secs": 15.0 26 | }, 27 | "packing": { 28 | "max_sequence_length": 768, 29 | "max_items_per_pack": 5, 30 | "window_size": 1600 31 | } 32 | } -------------------------------------------------------------------------------- /data_pipeline/scripts/bpe_tokenizer_configs/kokoro_bytelevel.json: -------------------------------------------------------------------------------- 1 | { 2 | 3 | } -------------------------------------------------------------------------------- /data_pipeline/scripts/chatml_tokenize_dataset.py: 
-------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | from dotenv import load_dotenv 3 | from datasets import load_dataset, load_from_disk, concatenate_datasets, DatasetDict 4 | import json 5 | import os 6 | from pydantic import BaseModel, Field 7 | import torch 8 | from transformers import AutoTokenizer 9 | from typing import Dict, Optional, List, Literal 10 | 11 | 12 | from data_pipeline.utils.prompt import PromptEncoder, TokenizationConfig 13 | 14 | 15 | class TokenizationStrategy(BaseModel): 16 | tokenizer_path: str 17 | strategy: Literal["bpe", "bytelevel", "phoneme", "hybrid"] 18 | duplicate_code_0: Optional[bool] = True 19 | 20 | 21 | class AudioConfig(BaseModel): 22 | frame_rate: float = Field(default=12.5) 23 | max_sample_secs: float = Field(default=15.0) 24 | 25 | 26 | class SpeakerStrategy(BaseModel): 27 | strategy: Literal["id_token", "fixed", "omit"] 28 | speaker_names: Optional[List[str]] = Field(default=None) 29 | default_sysprompt: Optional[str] = Field(default=None) 30 | 31 | 32 | class PackingStrategy(BaseModel): 33 | max_sequence_length: int = Field(default=768) 34 | max_items_per_pack: int = Field(default=5) 35 | window_size: int = Field(default=1600) 36 | 37 | 38 | class Config(BaseModel): 39 | dataset_id: Optional[str] = Field(default=None) 40 | dataset_path: Optional[str] = Field(default=None) 41 | tokenization: TokenizationStrategy 42 | speaker: SpeakerStrategy 43 | audio: AudioConfig 44 | packing: Optional[PackingStrategy] = Field(default=None) 45 | 46 | 47 | class SyspromptEncoder: 48 | default_sysprompt: Optional[torch.Tensor] = None 49 | speaker_cache: Optional[Dict[str, torch.Tensor]] = None 50 | 51 | def __init__(self, dataset_config: Config, prompt_encoder: PromptEncoder): 52 | self.dataset_config = dataset_config 53 | self.prompt_encoder = PromptEncoder 54 | if dataset_config.speaker.default_sysprompt is not None: 55 | # One single sysprompt 56 | self.default_sysprompt = prompt_encoder.encode_text_turn( 57 | role="system", 58 | content=dataset_config.speaker.default_sysprompt, 59 | add_generation_prompt=False, 60 | ) 61 | elif dataset_config.speaker.speaker_names is not None: 62 | # Precompute speaker prompt cache if we have a known small subset 63 | self.speaker_cache = { 64 | speaker_name: prompt_encoder.encode_text_turn( 65 | role="system", 66 | content=f"<|speaker:{id}|>", 67 | add_generation_prompt=False, 68 | ) 69 | for id, speaker_name in enumerate(dataset_config.speaker.speaker_names) 70 | } 71 | 72 | def get_sysprompt_length(self, speaker_id: str) -> int: 73 | if self.default_sysprompt is not None: 74 | # Fixed 75 | return self.default_sysprompt.size(-1) 76 | elif self.speaker_cache is not None: 77 | # Speaker ID from known set 78 | return self.speaker_cache[speaker_id].size(-1) 79 | else: 80 | # TODO handle arbitrary token length 81 | return 0 82 | 83 | def add_sysprompt( 84 | self, ground_truth: torch.Tensor, speaker_id: str 85 | ) -> torch.Tensor: 86 | if self.dataset_config.speaker.strategy == "omit": 87 | return ground_truth 88 | else: 89 | if self.default_sysprompt is not None: 90 | speaker_entry = self.default_sysprompt 91 | elif self.speaker_cache is not None: 92 | speaker_entry = self.speaker_cache[speaker_id] 93 | else: 94 | raise ValueError( 95 | f"Must have default syprompt or IDs, current strategy: {self.dataset_config.speaker.strategy}" 96 | ) 97 | 98 | return torch.cat([speaker_entry, ground_truth], dim=1) 99 | 100 | 101 | def tts_tokenize_row( 102 | row: Dict, 103 
| prompt_encoder: PromptEncoder, 104 | dataset_config: Config, 105 | ): 106 | """ 107 | NOTE: unlike the notebook, this does NOT handle 108 | - Speaker prompt 109 | - Causal shift 110 | """ 111 | user_line = prompt_encoder.encode_text_turn( 112 | role="user", 113 | content=row["text_normalized"].encode("utf-8").decode("latin-1") 114 | if dataset_config.tokenization.strategy == "bpe" 115 | else row["text_normalized"], 116 | add_generation_prompt=True, 117 | ) 118 | assistant_line = prompt_encoder.encode_vq(row["codes"]) 119 | 120 | ground_truth = torch.cat([user_line, assistant_line], dim=1) 121 | 122 | return { 123 | "ground_truth": ground_truth.clone(), 124 | } 125 | 126 | 127 | def causal_shift_row(row): 128 | tokens = row["ground_truth"][:, :-1].clone() 129 | labels = row["ground_truth"][:, 1:].clone() 130 | 131 | text_only_mask = labels[1:, :] == 0 132 | labels[1:, :][text_only_mask] = -100 133 | return {"tokens": tokens, "labels": labels} 134 | 135 | 136 | def pack_utterances(batch: Dict, sysprompt_encoder: SyspromptEncoder): 137 | # Group utterances by speaker 138 | speakers = {} 139 | 140 | for speaker, tokens in zip(batch["speaker_id"], batch["ground_truth"]): 141 | if speaker not in speakers: 142 | speakers[speaker] = [] 143 | speakers[speaker].append(tokens) 144 | 145 | # Greedy packing per speaker (First-fit decreasing) 146 | for speaker in speakers: 147 | speakers[speaker].sort(key=lambda x: x.size(-1), reverse=True) 148 | 149 | packed_bins = [] 150 | packed_ids = [] 151 | for speaker, utterances in speakers.items(): 152 | sysprompt_length = sysprompt_encoder.get_sysprompt_length(speaker_id=speaker) 153 | bins = [] 154 | for utterance in utterances: 155 | placed = False 156 | for i in range(len(bins)): 157 | if ( 158 | bins[i].size(-1) + utterance.size(-1) + sysprompt_length 159 | <= sysprompt_encoder.dataset_config.packing.max_sequence_length 160 | ): 161 | bins[i] = torch.cat([bins[i], utterance], dim=1) 162 | placed = True 163 | break 164 | if not placed: 165 | bins.append(utterance) 166 | 167 | packed_bins += bins 168 | packed_ids += [speaker] * len(bins) 169 | 170 | packed_bins = [ 171 | sysprompt_encoder.add_sysprompt(seq, speaker_id) 172 | for seq, speaker_id in zip(packed_bins, packed_ids) 173 | ] 174 | 175 | return {"ground_truth": packed_bins, "speaker_id": packed_ids} 176 | 177 | 178 | parser = ArgumentParser( 179 | description="Tokenize Mimi-encoded dataset for final consumption" 180 | ) 181 | parser.add_argument("-c", "--config", type=str, required=True) 182 | parser.add_argument( 183 | "-o", "--out-path", type=str, required=True, help="Local path of dataset output" 184 | ) 185 | parser.add_argument("--shards", type=int) 186 | 187 | 188 | # TODO configure this 189 | NUM_PROC = 12 190 | 191 | 192 | def main(): 193 | args = parser.parse_args() 194 | 195 | with open(args.config) as f: 196 | config_dict = json.load(f) 197 | dataset_config = Config(**config_dict) 198 | 199 | load_dotenv() 200 | if dataset_config.dataset_path: 201 | dataset = load_from_disk(dataset_config.dataset_path) 202 | elif dataset_config.dataset_id: 203 | dataset = load_dataset( 204 | dataset_config.dataset_id, token=os.getenv("HUGGINGFACE_TOKEN") 205 | ) 206 | else: 207 | raise ValueError("Neither dataset_id nor dataset_path specified in config!") 208 | 209 | print("Loaded dataset") 210 | dataset = dataset.with_format("torch") 211 | if "text" in dataset["train"].column_names: 212 | dataset = dataset.rename_column("text", "text_normalized") 213 | if "speaker" in dataset["train"].column_names: 214 
| dataset = dataset.rename_column("speaker", "speaker_id") 215 | 216 | tokenizer = AutoTokenizer.from_pretrained( 217 | dataset_config.tokenization.tokenizer_path 218 | ) 219 | tokenizer.use_default_system_prompt = False 220 | tokenization_config = TokenizationConfig( 221 | duplicate_code_0=dataset_config.tokenization.duplicate_code_0 222 | ) 223 | prompt_encoder = PromptEncoder(tokenizer, tokenization_config) 224 | sysprompt_encoder = SyspromptEncoder( 225 | dataset_config=dataset_config, prompt_encoder=prompt_encoder 226 | ) 227 | 228 | n_shards = args.shards if args.shards is not None else 1 229 | 230 | full_dataset = dataset 231 | completed = [] 232 | for i in range(n_shards): 233 | dataset = full_dataset["train"].shard(n_shards, i) 234 | print(f"Filtering rows above {dataset_config.audio.max_sample_secs}s") 235 | dataset = dataset.filter( 236 | lambda row: row["codes"].size(-1) 237 | <= dataset_config.audio.frame_rate * dataset_config.audio.max_sample_secs, 238 | num_proc=NUM_PROC, 239 | ) 240 | 241 | print("Tokenizing dataset") 242 | dataset = dataset.map( 243 | lambda row: tts_tokenize_row(row, prompt_encoder, dataset_config), 244 | remove_columns="codes", 245 | num_proc=NUM_PROC, 246 | ) 247 | 248 | if dataset_config.packing is not None: 249 | print("Packing sequence") 250 | dataset = dataset.map( 251 | lambda row: pack_utterances(row, sysprompt_encoder), 252 | batched=True, 253 | batch_size=dataset_config.packing.window_size, 254 | num_proc=NUM_PROC, 255 | remove_columns=dataset.column_names, 256 | ) 257 | 258 | completed.append(dataset) 259 | 260 | dataset = concatenate_datasets(completed) 261 | 262 | # print("Adding system prompt") 263 | # dataset = dataset.map(sysprompt_encoder.add_sysprompt, num_proc=NUM_PROC) 264 | # print("Causally shifting tokens, masking text-only") 265 | # dataset = dataset.map( 266 | # causal_shift_row, num_proc=NUM_PROC, remove_columns=["ground_truth"] 267 | # ) 268 | dataset = DatasetDict({"train": dataset}) 269 | 270 | dataset.save_to_disk(args.out_path, max_shard_size="5GB") 271 | 272 | 273 | if __name__ == "__main__": 274 | main() 275 | -------------------------------------------------------------------------------- /data_pipeline/scripts/create_bytelevel_init.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | import json 3 | import os 4 | from pydantic import BaseModel, Field 5 | from tokenizers import Tokenizer, models, decoders 6 | from tokenizers.trainers import BpeTrainer 7 | from transformers import PreTrainedTokenizerFast 8 | 9 | CHATML_TEMPLATE = "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}" 10 | 11 | class Config(BaseModel): 12 | codebook_size: int = Field(default=2048) 13 | 14 | 15 | def get_blank_tokenizer() -> Tokenizer: 16 | tokenizer = Tokenizer(models.BPE()) 17 | trainer = BpeTrainer(vocab_size=256, special_tokens=[]) 18 | 19 | # Create actual bytes 20 | byte_data = [bytes([i]) for i in range(256)] 21 | # Preserve the actual byte values in strings 22 | byte_strings = [b.decode('latin-1') for b in byte_data] 23 | 24 | # "Train" the tokenizer 25 | tokenizer.train_from_iterator(byte_strings, trainer=trainer) 26 | tokenizer.pre_tokenizer = None 27 | tokenizer.normalizer = None 28 | tokenizer.decoder = 
decoders.ByteLevel() 29 | 30 | return tokenizer 31 | 32 | def add_special_tokens(tokenizer: Tokenizer, config: Config) -> Tokenizer: 33 | semantic_tokens = [f"<|semantic:{i}|>" for i in range(config.codebook_size)] 34 | control_tokens = [ 35 | "system", 36 | "user", 37 | "assistant", 38 | "<|british|>", 39 | "<|american|>", 40 | "<|male|>", 41 | "<|female|>", 42 | "<|unknown|>", 43 | "<|endoftext|>", 44 | "<|voice|>", 45 | "<|semantic|>", 46 | "<|pad|>", 47 | "<|epad|>", 48 | "<|im_start|>", 49 | "<|im_end|>", 50 | ] 51 | 52 | speaker_id_tokens = [f"<|speaker:{i}|>" for i in range(64 - len(control_tokens))] 53 | charset = [*control_tokens, *speaker_id_tokens, *semantic_tokens] 54 | 55 | tokenizer.add_special_tokens(charset) 56 | 57 | return tokenizer 58 | 59 | parser = ArgumentParser(description="Create BPE tokenizer for Kokoro-style bounded speaker models") 60 | parser.add_argument( 61 | "--config-file", 62 | type=str, 63 | ) 64 | parser.add_argument( 65 | "--out-dir", "-o", 66 | required=True, 67 | help="Directory where the tokenizer will be saved" 68 | ) 69 | 70 | def main(): 71 | args = parser.parse_args() 72 | if args.config_file is not None: 73 | with open(args.config_file, "r") as f: 74 | config_data = json.load(f) 75 | 76 | config = Config(**config_data) 77 | else: 78 | # Default is fine 79 | config = Config() 80 | 81 | base_tokenizer = get_blank_tokenizer() 82 | tokenizer = add_special_tokens(base_tokenizer, config) 83 | 84 | final_tokenizer = PreTrainedTokenizerFast( 85 | tokenizer_object=tokenizer, 86 | bos_token="<|im_start|>", 87 | eos_token="<|endoftext|>", 88 | unk_token="<|unknown|>", 89 | pad_token="<|pad|>", 90 | chat_template=CHATML_TEMPLATE 91 | ) 92 | 93 | os.makedirs(args.out_dir, exist_ok=True) 94 | final_tokenizer.save_pretrained(args.out_dir) 95 | print("Saving complete!") 96 | 97 | if __name__ == "__main__": 98 | main() -------------------------------------------------------------------------------- /data_pipeline/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EndlessReform/smoltts/410d6524966e2d433ce565e4f14f1fac2f983ff9/data_pipeline/utils/__init__.py -------------------------------------------------------------------------------- /data_pipeline/utils/codec.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.nn.utils.rnn import pad_sequence 4 | from transformers import MimiModel 5 | from typing import List 6 | 7 | SAMPLING_RATE = 24_000 8 | CODEC_HZ = 12.5 9 | 10 | 11 | def get_target_length( 12 | arr: torch.Tensor, sampling_rate=SAMPLING_RATE, codec_hz=CODEC_HZ 13 | ) -> int: 14 | return math.ceil(arr.size(-1) / (sampling_rate / codec_hz)) 15 | 16 | 17 | class MimiCodec: 18 | def __init__(self, model_name: str = "kyutai/mimi", device: str = "cuda"): 19 | model = MimiModel.from_pretrained(model_name) 20 | model = model.to(device) 21 | self.model = model 22 | self.device = device 23 | 24 | def decode(self, codes: torch.Tensor) -> torch.Tensor: 25 | if codes.ndim == 2: 26 | # Add spurious batch dimension 27 | codes = codes.unsqueeze(0) 28 | 29 | with torch.no_grad(): 30 | out_pcm = self.model.decode(codes.to(self.device)) 31 | out_tensor = out_pcm.audio_values[0].detach().to("cpu") 32 | # Trim final frame to prevent random artifacts 33 | return out_tensor[:, :-1] 34 | 35 | def encode(self, audio: torch.Tensor) -> torch.Tensor: 36 | if audio.ndim == 1: 37 | # Single mono audio 38 | audio = 
audio.unsqueeze(0).unsqueeze(0) 39 | elif audio.ndim == 2: 40 | # Dual channel 41 | audio = audio.unsqueeze(0) 42 | else: 43 | raise ValueError( 44 | f"Use batch endpoint to encode audio safely; got {audio.ndim} dims but expected channel, seqlen or seqlen" 45 | ) 46 | 47 | with torch.no_grad(): 48 | encoded = self.model.encode(audio.to(self.device)) 49 | codes = encoded.audio_codes[:, 0:8, :].clone().cpu() 50 | del encoded 51 | torch.cuda.empty_cache() 52 | return codes.squeeze(0) 53 | 54 | def encode_batch(self, audios: List[torch.Tensor]) -> List[torch.Tensor]: 55 | target_lengths = [get_target_length(arr) - 1 for arr in audios] 56 | # Add spurious channel dimension, audio should be mono 57 | padded_batch = pad_sequence(audios, batch_first=True).unsqueeze(1) 58 | padding_mask = (padded_batch != 0).float() 59 | 60 | with torch.no_grad(): 61 | encoder_outputs = self.model.encode( 62 | padded_batch.to(self.device), padding_mask=padding_mask.to(self.device) 63 | ) 64 | codes = encoder_outputs.audio_codes[:, 0:8, :].cpu() 65 | del padded_batch 66 | del encoder_outputs 67 | torch.cuda.empty_cache() 68 | 69 | chunked = list(torch.unbind(codes, dim=0)) 70 | return [t[:, :length] for t, length in zip(chunked, target_lengths)] 71 | -------------------------------------------------------------------------------- /data_pipeline/utils/prompt.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel, Field 2 | import torch 3 | from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast 4 | from typing import Optional, Union 5 | 6 | 7 | class TokenizationConfig(BaseModel): 8 | num_codebooks: int = Field(default=8) 9 | acoustic_delay: int = Field(default=0) 10 | duplicate_code_0: Optional[bool] = Field(default=True) 11 | 12 | 13 | class PromptEncoder: 14 | tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast] 15 | num_codebooks: int 16 | trailing_im_end: torch.Tensor 17 | semantic_offset: int 18 | 19 | def __init__( 20 | self, 21 | tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], 22 | config: TokenizationConfig, 23 | ): 24 | self.tokenizer = tokenizer 25 | self.config = config 26 | self.semantic_offset = tokenizer.encode("<|semantic:0|>")[0] 27 | self.pad_id = tokenizer.encode("<|pad|>")[0] 28 | zero_buffer = [0] * ( 29 | self.config.num_codebooks 30 | if self.config.duplicate_code_0 31 | else self.config.num_codebooks - 1 32 | ) 33 | self.trailing_im_end = torch.tensor( 34 | [ 35 | tokenizer.encode("<|im_end|>") + zero_buffer, 36 | tokenizer.encode("\n") + zero_buffer, 37 | ] 38 | ).T 39 | 40 | def get_lower_zeros(self, length: int) -> torch.Tensor: 41 | return torch.zeros( 42 | self.config.num_codebooks 43 | if self.config.duplicate_code_0 44 | else self.config.num_codebooks - 1, 45 | length, 46 | dtype=torch.long, 47 | ) 48 | 49 | def tokenize_text(self, text: str) -> torch.Tensor: 50 | turn_codes = ( 51 | self.tokenizer(text, return_tensors="pt").unsqueeze(0).to(torch.uint32) 52 | ) 53 | zeros_mask = self.get_lower_zeros(turn_codes.size(-1)) 54 | return torch.concat([turn_codes, zeros_mask], dim=0) 55 | 56 | def encode_text_turn( 57 | self, role: str, content: str, add_generation_prompt: bool = True 58 | ) -> torch.Tensor: 59 | baseline = self.tokenizer.apply_chat_template( 60 | [{"role": role, "content": content}], 61 | add_generation_prompt=add_generation_prompt, 62 | return_tensors="pt", 63 | ) 64 | zeros_mask = self.get_lower_zeros(baseline.size(-1)) 65 | return torch.cat([baseline, zeros_mask], 
dim=0) 66 | 67 | def encode_vq(self, codes: torch.Tensor) -> torch.Tensor: 68 | if codes.ndim != 2: 69 | raise ValueError("Must be single batch") 70 | 71 | semantic_line = (codes[0, :] + self.semantic_offset).unsqueeze(0) 72 | lower_codes = codes if self.config.duplicate_code_0 else codes[1:, :] 73 | 74 | # TODO DO NOT MERGE, WRONG BAD 75 | if self.config.acoustic_delay != 0: 76 | semantic_suffix = torch.tensor( 77 | [self.pad_id] * self.config.acoustic_delay, dtype=torch.uint32 78 | ).unsqueeze(0) 79 | lower_codes_prefix = self.get_lower_zeros(self.config.acoustic_delay) 80 | semantic_line = torch.cat([semantic_line, semantic_suffix], dim=1) 81 | lower_codes = torch.cat([lower_codes_prefix, lower_codes], dim=1) 82 | 83 | vq_block = torch.cat([semantic_line, lower_codes]) 84 | block = torch.cat([vq_block, self.trailing_im_end], dim=1) 85 | return block 86 | 87 | def encode_vq_corrupt(self, codes: torch.Tensor, dropout=0.2) -> torch.Tensor: 88 | """ 89 | NO temporal delays or offsetting. 90 | 91 | Corrupts only the non-semantic codes. 92 | """ 93 | if codes.ndim != 2: 94 | raise ValueError("Must be single batch!") 95 | 96 | semantic_line = (codes[0, :] + self.semantic_offset).unsqueeze(0) 97 | first_residual = codes[0, :].unsqueeze(0) 98 | remaining_codes = codes[1:, :] 99 | 100 | corrupt_mask = torch.rand_like(remaining_codes.float()) < dropout 101 | 102 | # TODO: parameterize for 1024-size codebook 103 | random_codes = torch.randint( 104 | low=1, high=2048, size=remaining_codes.shape, device=remaining_codes.device 105 | ) 106 | 107 | corrupted_remaining_codes = torch.where( 108 | corrupt_mask, random_codes, remaining_codes 109 | ) 110 | vq_block = torch.cat([semantic_line, first_residual, corrupted_remaining_codes]) 111 | block = torch.cat([vq_block, self.trailing_im_end], dim=1) 112 | 113 | return block 114 | -------------------------------------------------------------------------------- /docs/dualar.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EndlessReform/smoltts/410d6524966e2d433ce565e4f14f1fac2f983ff9/docs/dualar.png -------------------------------------------------------------------------------- /docs/examples/bella.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EndlessReform/smoltts/410d6524966e2d433ce565e4f14f1fac2f983ff9/docs/examples/bella.wav -------------------------------------------------------------------------------- /docs/examples/emma.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EndlessReform/smoltts/410d6524966e2d433ce565e4f14f1fac2f983ff9/docs/examples/emma.wav -------------------------------------------------------------------------------- /docs/examples/fable.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EndlessReform/smoltts/410d6524966e2d433ce565e4f14f1fac2f983ff9/docs/examples/fable.wav -------------------------------------------------------------------------------- /docs/examples/fenrir.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EndlessReform/smoltts/410d6524966e2d433ce565e4f14f1fac2f983ff9/docs/examples/fenrir.wav -------------------------------------------------------------------------------- /docs/examples/heart.wav: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/EndlessReform/smoltts/410d6524966e2d433ce565e4f14f1fac2f983ff9/docs/examples/heart.wav -------------------------------------------------------------------------------- /docs/examples/isabella.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EndlessReform/smoltts/410d6524966e2d433ce565e4f14f1fac2f983ff9/docs/examples/isabella.wav -------------------------------------------------------------------------------- /docs/examples/liam.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EndlessReform/smoltts/410d6524966e2d433ce565e4f14f1fac2f983ff9/docs/examples/liam.wav -------------------------------------------------------------------------------- /docs/examples/michael.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EndlessReform/smoltts/410d6524966e2d433ce565e4f14f1fac2f983ff9/docs/examples/michael.wav -------------------------------------------------------------------------------- /docs/examples/nova.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EndlessReform/smoltts/410d6524966e2d433ce565e4f14f1fac2f983ff9/docs/examples/nova.wav -------------------------------------------------------------------------------- /docs/examples/sarah.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EndlessReform/smoltts/410d6524966e2d433ce565e4f14f1fac2f983ff9/docs/examples/sarah.wav -------------------------------------------------------------------------------- /docs/examples/sky.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EndlessReform/smoltts/410d6524966e2d433ce565e4f14f1fac2f983ff9/docs/examples/sky.wav -------------------------------------------------------------------------------- /mlx_inference/README.md: -------------------------------------------------------------------------------- 1 | # smoltts-mlx 2 | 3 | ## Installation 4 | 5 | Requires a working Python installation and an Apple Silicon Mac. 6 | 7 | ```bash 8 | pip install smoltts-mlx 9 | ``` 10 | 11 | Or if you have [`uv`](https://docs.astral.sh/uv/) (hint hint), simply use [uvx](https://docs.astral.sh/uv/guides/tools/): 12 | 13 | ```bash 14 | uvx --from smoltts_mlx smoltts-server 15 | ``` 16 | 17 | ## Server 18 | 19 | ### Startup 20 | 21 | From the CLI, run: 22 | 23 | ```bash 24 | smoltts-server 25 | ``` 26 | 27 | Options: 28 | 29 | - `--port` (optional): Port to listen on (default: 8000) 30 | - `--config` (optional): Path to a JSON settings file. (See Configuration below for the spec) 31 | 32 | ### Supported voices 33 | 34 | As of February 2025, we support these voices from Kokoro: 35 | 36 | - **American:** heart (default), bella, nova, sky, sarah, michael, fenrir, liam 37 | - **British:** emma, isabella, fable 38 | 39 | Voice cloning is currently not supported, but coming soon! 40 | 41 | Unfortunately, GitHub doesn't support audio previews, but check out `docs/examples` for samples.
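Putting the startup flags above together, for example, to launch on a non-default port with the bundled settings file (assuming the command is run from the `mlx_inference/` directory, where `default_settings.json` lives; adjust the path otherwise):

```bash
smoltts-server --port 8080 --config default_settings.json
```

The settings file follows the JSON spec shown under Configuration below.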
42 | 43 | ### ElevenLabs endpoints 44 | 45 | We support the following two ElevenLabs endpoints (more to come): 46 | 47 | - [`/v1/text-to-speech/$ID`](https://elevenlabs.io/docs/api-reference/text-to-speech/convert) with MP3, WAV, and PCM output 48 | - [`/v1/text-to-speech/$ID/stream`](https://elevenlabs.io/docs/api-reference/text-to-speech/convert-as-stream) (PCM-only for now) 49 | 50 | Here's an example with the Python SDK: 51 | 52 | ```python 53 | from elevenlabs.client import ElevenLabs 54 | 55 | client = ElevenLabs( 56 | # or wherever you're running this on 57 | base_url="http://localhost:8000", 58 | ) 59 | 60 | request_gen = client.text_to_speech.convert( 61 | voice_id="0", 62 | output_format="mp3_44100_128", 63 | text="You can turn on latency optimizations at some cost of quality. The best possible final latency varies by model.", 64 | ) 65 | ``` 66 | 67 | ### OpenAI endpoints 68 | 69 | We support `/v1/audio/speech` (MP3 and WAV). 70 | 71 | Here's an example with the [OpenAI Python SDK](https://platform.openai.com/docs/guides/text-to-speech#quickstart): 72 | 73 | ```python 74 | from pathlib import Path 75 | from openai import OpenAI 76 | 77 | client = OpenAI( 78 | base_url="http://localhost:8000" 79 | ) 80 | 81 | speech_file_path = Path(__file__).parent / "speech.mp3" 82 | response = client.audio.speech.create( 83 | model="tts-1", 84 | voice="alloy", 85 | input="Today is a wonderful day to build something people love!", 86 | ) 87 | response.stream_to_file(speech_file_path) 88 | ``` 89 | 90 | ### Configuration 91 | 92 | Default settings are stored by default at `~/Library/Cache/smoltts`. 93 | 94 | You can also specify a JSON file with `--config`. 95 | 96 | ```json 97 | { 98 | // "checkpoint_dir": "../inits/foobar/" 99 | "model_id": "jkeisling/smoltts_v0", 100 | "generation": { 101 | "default_temp": 0.0, 102 | "default_fast_temp": 0.5, 103 | "min_p": 0.1 104 | }, 105 | "model_type": { 106 | "family": "dual_ar", 107 | "codec": "mimi", 108 | "version": null 109 | } 110 | } 111 | ``` 112 | 113 | ## Library 114 | 115 | ### Basic Usage 116 | 117 | ```python 118 | from smoltts_mlx import SmolTTS 119 | from IPython.display import Audio 120 | 121 | # Initialize model (downloads weights automatically) 122 | model = SmolTTS() 123 | 124 | # Basic generation to numpy PCM array 125 | pcm = model("Hello world!") 126 | Audio(pcm, rate=model.sampling_rate) 127 | 128 | # Streaming generation for real-time audio 129 | for pcm_chunk in model.stream("This is a longer piece of text to stream."): 130 | # Yields 80ms PCM frames as they're generated 131 | process_audio(pcm_chunk) 132 | ``` 133 | 134 | ### Voice Selection 135 | 136 | ```python 137 | # Use a specific voice 138 | pcm = model("Hello!", voice="af_bella") 139 | 140 | # Create a custom voice from reference audio 141 | speaker_prompt = model.create_speaker( 142 | system_prompt="<|speaker:0|>", 143 | samples=[{ 144 | "text": "This is a sample sentence.", 145 | "audio": reference_audio # Numpy array of PCM data 146 | }] 147 | ) 148 | 149 | # Generate with custom voice 150 | pcm = model( 151 | "Using a custom voice created from reference audio.", 152 | speaker_prompt=speaker_prompt 153 | ) 154 | ``` 155 | 156 | ### Working with Audio 157 | 158 | The model works with raw PCM audio at 24kHz sample rate. 
For format conversion: 159 | 160 | ```python 161 | import soundfile as sf 162 | 163 | # Save to WAV 164 | sf.write("output.wav", pcm, model.sampling_rate) 165 | 166 | # Load reference audio 167 | audio, sr = sf.read("reference.wav") 168 | if sr != model.sampling_rate: 169 | # Resample if needed 170 | audio = resampy.resample(audio, sr, model.sampling_rate) 171 | ``` 172 | 173 | ### Advanced Configuration 174 | 175 | ```python 176 | # Use custom model weights 177 | model = SmolTTS( 178 | model_id="path/to/custom/model", 179 | checkpoint_dir="/path/to/local/weights" 180 | ) 181 | 182 | # Access underlying components 183 | mimi_codes = model.codec.encode(pcm) # Work with Mimi tokens directly 184 | ``` 185 | 186 | ### Performance Notes 187 | 188 | - First generation may be slower due to model loading and warmup 189 | - Use streaming for longer texts to begin playback before full generation 190 | 191 | ### Requirements 192 | 193 | - Apple Silicon Mac (M1/M2/M3) 194 | - Python 3.9 or later 195 | 196 | ## Developing locally 197 | 198 | Please use [uv](https://docs.astral.sh/uv/). 199 | 200 | From the root of this repo: 201 | 202 | ```bash 203 | uv sync --all-packages 204 | ``` 205 | -------------------------------------------------------------------------------- /mlx_inference/default_settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_id": "jkeisling/smoltts_v0", 3 | "generation": { 4 | "default_temp": 0.4, 5 | "default_fast_temp": 0.5, 6 | "min_p": 0.10 7 | }, 8 | "model_type": { 9 | "family": "dual_ar", 10 | "codec": "mimi", 11 | "version": null 12 | } 13 | } -------------------------------------------------------------------------------- /mlx_inference/pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "smoltts_mlx" 3 | version = "0.1.1" 4 | description = "MLX inference for autoregressive speech models" 5 | readme = "README.md" 6 | requires-python = ">=3.9" 7 | dependencies = [ 8 | "fastapi>=0.115.6", 9 | "huggingface-hub>=0.27.1", 10 | "mlx==0.22.0", 11 | "numpy>=2.0.2", 12 | "pydantic>=2.10.5", 13 | "pydub>=0.25.1", 14 | "scipy>=1.13.1", 15 | "soundfile>=0.13.1", 16 | "tokenizers>=0.21.0", 17 | "tqdm>=4.67.1", 18 | "uvicorn[standard]>=0.34.0", 19 | ] 20 | 21 | [project.scripts] 22 | smoltts-server = "smoltts_mlx.scripts.server:main" 23 | 24 | [build-system] 25 | requires = ["setuptools>=61"] 26 | build-backend = "setuptools.build_meta" 27 | 28 | [tool.setuptools.packages.find] 29 | where = ["src"] 30 | include = ["smoltts_mlx*"] 31 | -------------------------------------------------------------------------------- /mlx_inference/src/smoltts_mlx/__init__.py: -------------------------------------------------------------------------------- 1 | from huggingface_hub import snapshot_download 2 | import mlx.core as mx 3 | import numpy as np 4 | from pathlib import Path 5 | from tokenizers import Tokenizer 6 | from tqdm import tqdm 7 | from typing import List, Optional 8 | 9 | from smoltts_mlx.codec.mimi import load_mimi 10 | from smoltts_mlx.lm.config import ModelType 11 | from smoltts_mlx.lm.utils.prompt import PromptEncoder 12 | from smoltts_mlx.lm.rq_transformer import ( 13 | RQTransformer, 14 | RQTransformerModelArgs, 15 | TokenConfig, 16 | ) 17 | from smoltts_mlx.lm.cache import make_prompt_cache 18 | from smoltts_mlx.lm.generate import ( 19 | generate_blocking, 20 | SingleBatchGenerator, 21 | GenerationSettings, 22 | ) 23 | 24 | 25 | class SmolTTS: 26 | def 
__init__( 27 | self, 28 | model_id="jkeisling/smoltts_v0", 29 | checkpoint_dir: Optional[str] = None, 30 | ): 31 | checkpoint_dir = Path( 32 | checkpoint_dir 33 | if checkpoint_dir is not None 34 | else snapshot_download(model_id) 35 | ) 36 | config = RQTransformerModelArgs.from_json_file( 37 | str(checkpoint_dir / "config.json") 38 | ) 39 | # TODO support other configs once changes are made 40 | model_type = ModelType.smoltts_v0() 41 | 42 | tokenizer = Tokenizer.from_file(str(checkpoint_dir / "tokenizer.json")) 43 | token_config = TokenConfig.from_tokenizer( 44 | model=model_type, tokenizer=tokenizer, config=config 45 | ) 46 | 47 | model = RQTransformer(config, token_config, model_type) 48 | model_path = str(checkpoint_dir / "model.safetensors") 49 | model.load_weights(model_path, strict=True) 50 | mx.eval(model.parameters()) 51 | model.eval() 52 | 53 | prompt_encoder = PromptEncoder.from_model(tokenizer, model) 54 | codec = load_mimi() 55 | 56 | self.lm = model 57 | self.prompt_encoder = prompt_encoder 58 | self.codec = codec 59 | 60 | # TODO load speakers here 61 | # TODO make this configurable 62 | self.sampling_rate = 24_000 63 | 64 | def __call__( 65 | self, 66 | input: str, 67 | voice: Optional[str] = "heart", 68 | speaker: Optional[mx.array] = None, 69 | ) -> np.ndarray: 70 | """ 71 | Returns flattened PCM array 72 | """ 73 | prompt = self._get_prompt( 74 | input, voice if voice is not None else "heart", sysprompt=speaker 75 | ) 76 | # TODO make this configurable 77 | gen = generate_blocking(self.lm, prompt, GenerationSettings()) 78 | out = self.codec.decode(gen) 79 | mx.metal.clear_cache() 80 | 81 | return np.array(out).flatten() 82 | 83 | def stream(self, input: str, voice: Optional[str] = "heart"): 84 | prompt = self._get_prompt(input, voice if voice is not None else "0") 85 | frame_gen = SingleBatchGenerator(self.lm, prompt, GenerationSettings()) 86 | mimi_cache = make_prompt_cache(self.codec.decoder_transformer) 87 | 88 | for frame in tqdm(frame_gen): 89 | audio_tokens = frame.vq_tensor[:, 1:, :] 90 | pcm_chunk = self.codec.decode_step(audio_tokens, mimi_cache) 91 | audio_data = np.array(pcm_chunk).flatten() 92 | yield audio_data 93 | 94 | self.codec.decoder.reset() 95 | mx.metal.clear_cache() 96 | 97 | def create_speaker( 98 | self, samples: List[dict], system_prompt: Optional[str] = None 99 | ) -> mx.array: 100 | turns = [] 101 | for sample in samples: 102 | if "audio" not in sample or "text" not in sample: 103 | raise ValueError( 104 | f"Sample must contain both 'text' and 'audio' but got {sample.keys()}" 105 | ) 106 | user_prompt = self.prompt_encoder.encode_text_turn("user", sample["text"]) 107 | encoded_audio = self.codec.encode(mx.array(sample["audio"])) 108 | codes = self.prompt_encoder.encode_vq(encoded_audio.squeeze(0)[:8, :]) 109 | turns.append(user_prompt) 110 | turns.append(codes) 111 | 112 | if system_prompt is not None: 113 | turns = [ 114 | self.prompt_encoder.encode_text_turn("system", system_prompt), 115 | *turns, 116 | ] 117 | 118 | return mx.concat(turns, axis=1) 119 | 120 | def _get_prompt(self, input: str, voice: str, sysprompt=None): 121 | # TODO remove this after voices are configurable 122 | voice_map = { 123 | k: v 124 | for v, k in enumerate( 125 | [ 126 | "heart", 127 | "bella", 128 | "nova", 129 | "sky", 130 | "sarah", 131 | "michael", 132 | "fenrir", 133 | "liam", 134 | "emma", 135 | "isabella", 136 | "fable", 137 | ] 138 | ) 139 | } 140 | voice_id = voice_map.get(voice, 0) 141 | 142 | if sysprompt is None: 143 | sysprompt = 
self.prompt_encoder.encode_text_turn( 144 | "system", f"<|speaker:{voice_id}|>" 145 | ) 146 | user_prompt = self.prompt_encoder.encode_text_turn("user", input) 147 | assistant_prefix = self.prompt_encoder.encode_text_turn("assistant") 148 | prompt = mx.concat([sysprompt, user_prompt, assistant_prefix], axis=1)[ 149 | mx.newaxis, :, : 150 | ] 151 | return prompt 152 | -------------------------------------------------------------------------------- /mlx_inference/src/smoltts_mlx/codec/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EndlessReform/smoltts/410d6524966e2d433ce565e4f14f1fac2f983ff9/mlx_inference/src/smoltts_mlx/codec/__init__.py -------------------------------------------------------------------------------- /mlx_inference/src/smoltts_mlx/codec/conv.py: -------------------------------------------------------------------------------- 1 | import math 2 | import mlx.core as mx 3 | import mlx.nn as nn 4 | from pydantic import BaseModel, Field 5 | from typing import List, Optional, Tuple 6 | 7 | 8 | class SeanetConfig(BaseModel): 9 | dimension: int = Field(default=512) 10 | channels: int = 1 11 | n_filters: int = 64 12 | n_residual_layers: int = 1 13 | compress: int = 2 14 | dilation_base: int = 2 15 | disable_norm_outer_blocks: int = 0 16 | kernel_size: int = 7 17 | residual_kernel_size: int = 3 18 | last_kernel_size: int = 3 19 | ratios: List[int] = [8, 6, 5, 4] 20 | trim_right_ratio: float = 1.0 21 | sampling_rate: float = 24_000.0 22 | upsample_groups: int = 512 23 | 24 | 25 | @mx.compile 26 | def causal_pad1d( 27 | x: mx.array, paddings: Tuple[int, int], mode: str = "zero", value: float = 0.0 28 | ) -> mx.array: 29 | if x.ndim < 2: 30 | raise ValueError( 31 | "Input tensor must have at least 2 dimensions (seq_len, channels)." 
32 | ) 33 | 34 | padding_left, padding_right = paddings 35 | 36 | # Create a padding tuple that pads only the second-to-last dimension (seqlen) 37 | pad_tuple = [(0, 0)] * x.ndim 38 | pad_tuple[-2] = (padding_left + padding_right, 0) 39 | pad_tuple = tuple(pad_tuple) 40 | 41 | if mode != "reflect": 42 | out = mx.pad(x, pad_tuple, mode, value) 43 | return out 44 | 45 | # Handle reflect mode with possible extra padding 46 | length = x.shape[-2] 47 | max_pad = max(padding_left, padding_right) 48 | extra_pad = 0 49 | 50 | if length <= max_pad: 51 | extra_pad = max_pad - length + 1 52 | # Apply extra padding to the seqlen dimension 53 | x_pad = [(0, 0)] * x.ndim 54 | x_pad[-2] = (0, extra_pad) 55 | x = mx.pad(x, tuple(x_pad), "constant") 56 | 57 | padded = mx.pad(x, pad_tuple, mode, value) 58 | 59 | if extra_pad > 0: 60 | # Slice to remove the extra padding added for reflection 61 | slices = [slice(None)] * x.ndim 62 | slices[-2] = slice(None, padded.shape[-2] - extra_pad) 63 | padded = padded[tuple(slices)] 64 | 65 | return padded 66 | 67 | 68 | class MimiConv1d(nn.Module): 69 | def __init__( 70 | self, 71 | config: SeanetConfig, 72 | in_channels: int, 73 | out_channels: int, 74 | kernel_size: int, 75 | stride: int = 1, 76 | dilation: int = 1, 77 | groups: int = 1, 78 | bias: bool = True, 79 | pad_mode: Optional[str] = None, 80 | ): 81 | super().__init__() 82 | self.pad_mode = pad_mode if pad_mode is not None else "constant" 83 | self.conv = nn.Conv1d( 84 | in_channels, 85 | out_channels, 86 | kernel_size, 87 | stride, 88 | dilation=dilation, 89 | groups=groups, 90 | bias=bias, 91 | ) 92 | self.stride = stride 93 | self.dilation = dilation 94 | effective_kernel_size = (kernel_size - 1) * dilation + 1 95 | self.kernel_size = effective_kernel_size 96 | self.padding_total = effective_kernel_size - stride 97 | 98 | self.padding_right = self.padding_total // 2 99 | self.padding_left = self.padding_total - self.padding_right 100 | 101 | self._stream_prev_in: Optional[mx.array] = None 102 | self._left_pad_applied = False 103 | 104 | def reset_state(self): 105 | """ 106 | Clears any leftover input from previous streaming steps. 107 | Call this before processing a brand new stream. 108 | """ 109 | self._stream_prev_in = None 110 | self._left_pad_applied = False 111 | 112 | def _get_extra_padding_for_conv1d(self, x: mx.array) -> int: 113 | length = x.shape[-2] # Use the seqlen dimension 114 | n_frames = (length - self.kernel_size + self.padding_total) / self.stride + 1 115 | n_frames = math.ceil(n_frames) - 1 116 | ideal_length = n_frames * self.stride + self.kernel_size - self.padding_total 117 | 118 | return ideal_length - length 119 | 120 | def __call__(self, x: mx.array) -> mx.array: 121 | extra_padding = self._get_extra_padding_for_conv1d(x) 122 | x = causal_pad1d( 123 | x, 124 | (self.padding_left, self.padding_right + extra_padding), 125 | self.pad_mode, 126 | ) 127 | x = self.conv(x) 128 | return x 129 | 130 | def step(self, x: mx.array) -> Optional[mx.array]: 131 | """ 132 | Streaming forward pass: processes chunk x and merges with leftover output 133 | from the previous step to avoid edge artifacts. 
134 | """ 135 | effective_k_size = (self.kernel_size - 1) * self.dilation + 1 136 | if not self._left_pad_applied: 137 | self._left_pad_applied = True 138 | padding_total = effective_k_size - self.stride 139 | x = mx.pad(x, [(0, 0), (padding_total, 0), (0, 0)]) 140 | 141 | # Restore previous input 142 | if self._stream_prev_in is not None: 143 | x_long = mx.concat([self._stream_prev_in, x], axis=-2) 144 | else: 145 | x_long = x 146 | 147 | num_frames = int( 148 | max(x_long.shape[-2] + self.stride - effective_k_size, 0) / self.stride 149 | ) 150 | if num_frames == 0: 151 | return None 152 | 153 | offset = num_frames * self.stride 154 | self._stream_prev_in = x_long[:, offset:, :] 155 | 156 | in_length = (num_frames - 1) * self.stride + effective_k_size 157 | xs = x_long[:, :in_length, :] 158 | return self.conv(xs) 159 | 160 | 161 | class MimiConvTranspose1d(nn.Module): 162 | def __init__( 163 | self, 164 | config: SeanetConfig, 165 | in_channels: int, 166 | out_channels: int, 167 | kernel_size: int, 168 | stride: int = 1, 169 | bias=True, 170 | ): 171 | super().__init__() 172 | self.trim_right_ratio = config.trim_right_ratio 173 | self.conv = nn.ConvTranspose1d( 174 | in_channels, 175 | out_channels, 176 | kernel_size, 177 | stride, 178 | bias=bias, 179 | ) 180 | 181 | self.stride = stride 182 | self.kernel_size = kernel_size 183 | padding_total = kernel_size - stride 184 | self.padding_right = math.ceil(padding_total * self.trim_right_ratio) 185 | self.padding_left = padding_total - self.padding_right 186 | 187 | self._stream_prev_out: Optional[mx.array] = None 188 | 189 | def reset_state(self): 190 | """ 191 | Clears leftover output from previous streaming steps. 192 | """ 193 | self._stream_prev_out = None 194 | 195 | def __call__(self, x: mx.array): 196 | x = self.conv(x) 197 | end = x.shape[-2] - self.padding_right 198 | x = x[:, self.padding_left : end, :] 199 | return x 200 | 201 | def step(self, x: mx.array) -> mx.array: 202 | """ 203 | Adapted from rustymimi rather than the HF Transformers port (which has no streaming support). 204 | I freely admit I do not understand this, but yet it works. 
205 | """ 206 | ys = self.conv(x) 207 | ot = ys.shape[-2] 208 | if self._stream_prev_out is not None: 209 | prev_len = self._stream_prev_out.shape[-2] 210 | # Remove the bias to avoid it happening multiple times 211 | if self.conv.bias is not None: 212 | self._stream_prev_out -= self.conv.bias 213 | ys1 = ys[:, :prev_len, :] + self._stream_prev_out 214 | ys2 = ys[:, prev_len:, :] 215 | ys = mx.concat([ys1, ys2], axis=-2) 216 | invalid_steps = self.kernel_size - self.stride 217 | split_point = ot - invalid_steps 218 | ys, prev_ys = ys[:, :split_point, :], ys[:, split_point:, :] 219 | self._stream_prev_out = prev_ys 220 | return ys 221 | 222 | 223 | class MeaninglessConvPassthrough(nn.Module): 224 | def __init__(self, in_channels: int, out_channels: int, kernel_size: int): 225 | """ 226 | This is necessary to load in the weights without using the MLX nn wrapper 227 | """ 228 | super().__init__() 229 | self.weight = mx.zeros([in_channels, kernel_size, out_channels]) 230 | 231 | 232 | class GroupedConvTranspose1d(nn.Module): 233 | def __init__( 234 | self, 235 | config: SeanetConfig, 236 | in_channels: int, 237 | out_channels: int, 238 | kernel_size: int, 239 | stride: int = 1, 240 | groups: int = 1, 241 | bias=False, 242 | ): 243 | super().__init__() 244 | self.trim_right_ratio = config.trim_right_ratio 245 | channels_per_group = out_channels 246 | self.conv = MeaninglessConvPassthrough( 247 | in_channels, 248 | out_channels 249 | // channels_per_group, # This becomes 1 when groups == in_channels 250 | kernel_size, 251 | ) 252 | padding_total = kernel_size - stride 253 | # Due to the dilation 254 | self.padding_right = math.ceil(padding_total * self.trim_right_ratio) 255 | self.padding_left = padding_total - self.padding_right 256 | self.groups = groups 257 | self.kernel_size = kernel_size 258 | self.stride = stride 259 | self.in_channels = in_channels 260 | self._conv_weight = None 261 | 262 | @property 263 | def conv_weight(self): 264 | if self._conv_weight is None: 265 | # Torch collapses the groups on the OUTPUT, but MLX collapses on the INPUT 266 | self._conv_weight = self.conv.weight.reshape( 267 | self.groups, self.kernel_size, self.in_channels // self.groups 268 | ) 269 | return self._conv_weight 270 | 271 | def __call__(self, x: mx.array): 272 | x = mx.conv_transpose1d( 273 | x, 274 | self.conv_weight, 275 | padding=0, 276 | stride=self.stride, 277 | groups=self.groups, 278 | ) 279 | 280 | end = x.shape[-2] - self.padding_right 281 | x = x[:, self.padding_left : end, :] 282 | return x 283 | -------------------------------------------------------------------------------- /mlx_inference/src/smoltts_mlx/codec/mimi.py: -------------------------------------------------------------------------------- 1 | from huggingface_hub import hf_hub_download 2 | import math 3 | import mlx.core as mx 4 | import mlx.nn as nn 5 | import numpy as np 6 | from pydantic import BaseModel 7 | from typing import Any, List, Optional 8 | 9 | from smoltts_mlx.codec.rvq import RVQConfig, MimiSplitResidualVectorQuantizer 10 | from smoltts_mlx.codec.conv import ( 11 | SeanetConfig, 12 | MimiConv1d, 13 | GroupedConvTranspose1d, 14 | ) 15 | from smoltts_mlx.codec.seanet import MimiEncoder, MimiDecoder 16 | from smoltts_mlx.codec.transformer import MimiTransformerConfig, MimiTransformer 17 | 18 | 19 | class MimiConfig(BaseModel): 20 | seanet: SeanetConfig 21 | transformer: MimiTransformerConfig 22 | rvq: RVQConfig 23 | 24 | 25 | def get_encodec_frame_rate(config: MimiConfig): 26 | hop_length = 
np.prod(config.seanet.ratios) 27 | return math.ceil(config.seanet.sampling_rate / hop_length) 28 | 29 | 30 | class MimiModel(nn.Module): 31 | def __init__(self, config: MimiConfig): 32 | super().__init__() 33 | self.config = config 34 | 35 | self.encoder = MimiEncoder(config.seanet) 36 | self.encoder_transformer = MimiTransformer(config.transformer) 37 | encodec_frame_rate = get_encodec_frame_rate(config) 38 | 39 | self.downsample = MimiConv1d( 40 | config.seanet, 41 | config.seanet.dimension, 42 | config.seanet.dimension, 43 | kernel_size=2 * int(encodec_frame_rate / config.rvq.frame_rate), 44 | stride=2, 45 | bias=False, 46 | pad_mode="edge", 47 | ) 48 | kernel_size = 2 * int(encodec_frame_rate / config.rvq.frame_rate) 49 | self.upsample = GroupedConvTranspose1d( 50 | config.seanet, 51 | config.seanet.dimension, 52 | config.seanet.dimension, 53 | kernel_size=kernel_size, 54 | stride=2, 55 | bias=False, 56 | groups=512, 57 | ) 58 | 59 | self.decoder_transformer = MimiTransformer(config.transformer) 60 | self.decoder = MimiDecoder(config.seanet) 61 | 62 | self.quantizer = MimiSplitResidualVectorQuantizer(config.rvq) 63 | 64 | def encode(self, x: mx.array): 65 | # Deliberately not implementing streaming encode for now 66 | x = mx.swapaxes(x, 1, 2) 67 | embedded = self.encoder(x) 68 | transformed = self.encoder_transformer(embedded) 69 | downsampled = self.downsample(transformed) 70 | codes = self.quantizer.encode(downsampled) 71 | return mx.swapaxes(codes, 0, 1) 72 | 73 | def _decode_frame( 74 | self, codes: mx.array, cache: Optional[List[Any]], is_step=False 75 | ) -> mx.array: 76 | embeddings = self.quantizer.decode(codes) 77 | embeddings = self.upsample(embeddings) 78 | decoder_outputs = self.decoder_transformer(embeddings, cache=cache) 79 | embeddings = decoder_outputs 80 | with mx.stream(mx.gpu): 81 | if is_step: 82 | outputs = self.decoder.step(embeddings) 83 | else: 84 | outputs = self.decoder(embeddings) 85 | out = mx.swapaxes(outputs, 1, 2) 86 | return out 87 | 88 | def decode( 89 | self, 90 | audio_codes: mx.array, 91 | cache: Optional[List[Any]] = None, 92 | padding_mask: Optional[mx.array] = None, 93 | ): 94 | audio_values = self._decode_frame(audio_codes, cache) 95 | 96 | if padding_mask is not None and padding_mask.shape[-1] < audio_values.shape[-1]: 97 | audio_values = audio_values[:, :, : padding_mask.shape[-1]] 98 | 99 | return audio_values 100 | 101 | def decode_step(self, codes: mx.array, cache: Optional[List[Any]]) -> mx.array: 102 | audio_values = self._decode_frame(codes, cache, is_step=True) 103 | 104 | return audio_values 105 | 106 | 107 | def load_mimi(format: str = "fp32") -> MimiModel: 108 | config = MimiConfig( 109 | seanet=SeanetConfig(), transformer=MimiTransformerConfig(), rvq=RVQConfig() 110 | ) 111 | model_path = hf_hub_download("kyutai/mimi", "model.safetensors") 112 | model = MimiModel(config) 113 | state_dict = mx.load(model_path) 114 | 115 | # Yes, this is dumb. 116 | # The all-knowing maintainers of MLX decided to serialize conv1ds as NHWC instaed of NCHW, 117 | # despite the entire API surface being designed to mimic pytorch, because it's faster on apple silicon, and then 118 | # "helpfully" leaked that abstraction onto me. 119 | # But the convtrans1d is DIFFERENT yet again. 
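    # Concretely, assuming the standard layouts: PyTorch serializes Conv1d weights as
    # (out_channels, in_channels, kernel_size), while MLX's nn.Conv1d expects
    # (out_channels, kernel_size, in_channels), hence the transpose(0, 2, 1) below.
    # PyTorch ConvTranspose1d weights are (in_channels, out_channels, kernel_size) and
    # MLX wants (out_channels, kernel_size, in_channels), hence transpose(1, 2, 0).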
120 | def is_convtrans1d(key) -> bool: 121 | return ( 122 | # Decoder only 123 | "decoder" in key 124 | and key.endswith(".conv.weight") 125 | and "block" not in key 126 | # Layer 0 is regular 127 | and "0" not in key 128 | # Final layer is regular 129 | and "14" not in key 130 | ) 131 | 132 | def is_conv1d(key): 133 | return ( 134 | key.endswith(".conv.weight") 135 | # RVQ proj 136 | or "quantizer.input_proj" in key 137 | or "quantizer.output_proj" in key 138 | and key != "upsample.conv.weight" 139 | ) 140 | 141 | converted_dict = { 142 | k: v.transpose(1, 2, 0) 143 | if is_convtrans1d(k) 144 | else v.transpose(0, 2, 1) 145 | if is_conv1d(k) 146 | else v 147 | for k, v in state_dict.items() 148 | } 149 | dtype = mx.bfloat16 if format == "bf16" else mx.float32 150 | weight_list = [(k, v.astype(dtype)) for k, v in converted_dict.items()] 151 | 152 | model.load_weights(weight_list, strict=True) 153 | mx.eval(model.parameters()) 154 | model.eval() 155 | 156 | return model 157 | -------------------------------------------------------------------------------- /mlx_inference/src/smoltts_mlx/codec/rvq.py: -------------------------------------------------------------------------------- 1 | import mlx.core as mx 2 | import mlx.nn as nn 3 | from pydantic import BaseModel 4 | from typing import Optional 5 | 6 | 7 | class RVQConfig(BaseModel): 8 | codebook_size: int = 2048 9 | codebook_dim: int = 256 10 | num_quantizers: int = 32 11 | num_semantic_quantizers: int = 1 12 | frame_rate: float = 12.5 13 | hidden_dim: int = 512 14 | 15 | 16 | @mx.compile 17 | def cdist(x: mx.array, y: mx.array): 18 | x1_square_norms = mx.sum(x**2, axis=-1, keepdims=True) 19 | x2_square_norms = mx.swapaxes(mx.sum(y**2, axis=-1, keepdims=True), 2, 1) 20 | 21 | dot_products = mx.matmul(x, mx.swapaxes(y, 2, 1)) 22 | dists_sq = x1_square_norms + x2_square_norms - 2 * dot_products 23 | return dists_sq.sqrt() 24 | 25 | 26 | class MimiEuclideanCodebook(nn.Module): 27 | """ 28 | Codebook with Euclidean distance. 
29 | """ 30 | 31 | def __init__(self, config: RVQConfig, epsilon: float = 1e-5): 32 | super().__init__() 33 | self.embed_sum = mx.zeros([config.codebook_size, config.codebook_dim]) 34 | self.codebook_size = config.codebook_size 35 | self.epsilon = epsilon 36 | # This does nothing at inference, but we need it so the keys line up 37 | self.initialized = mx.array([True], dtype=mx.float32) 38 | self.cluster_usage = mx.ones(config.codebook_size) 39 | self._embed = None 40 | 41 | @property 42 | def embed(self) -> mx.array: 43 | if self._embed is None: 44 | self._embed = ( 45 | self.embed_sum 46 | / mx.maximum(self.cluster_usage, self.epsilon)[:, mx.newaxis] 47 | ) 48 | return self._embed 49 | 50 | def decode(self, embed_ind) -> mx.array: 51 | quantize = self.embed[embed_ind, :] 52 | return quantize 53 | 54 | def quantize(self, x: mx.array) -> mx.array: 55 | dists = cdist(x[mx.newaxis, :], self.embed[mx.newaxis, :])[0] 56 | embed_ind = dists.argmin(axis=-1) 57 | return embed_ind 58 | 59 | def encode(self, x: mx.array) -> mx.array: 60 | shape = x.shape 61 | x = x.reshape((-1, shape[-1])) 62 | embed_ind = self.quantize(x) 63 | embed_ind = embed_ind.reshape(shape[:-1]) 64 | return embed_ind 65 | 66 | 67 | class MimiVectorQuantization(nn.Module): 68 | def __init__(self, config: RVQConfig): 69 | super().__init__() 70 | self.codebook = MimiEuclideanCodebook(config) 71 | 72 | def encode(self, x: mx.array) -> mx.array: 73 | embed_in = self.codebook.encode(x) 74 | return embed_in 75 | 76 | def decode(self, embed_ind: mx.array) -> mx.array: 77 | quantize = self.codebook.decode(embed_ind) 78 | out = mx.transpose(quantize, (0, 2, 1)) 79 | return out 80 | 81 | 82 | class MimiResidualVectorQuantizer(nn.Module): 83 | def __init__(self, config: RVQConfig, num_quantizers: Optional[int] = None): 84 | super().__init__() 85 | self.num_quantizers = ( 86 | num_quantizers if num_quantizers is not None else config.num_quantizers 87 | ) 88 | self.layers = [ 89 | MimiVectorQuantization(config) for _ in range(self.num_quantizers) 90 | ] 91 | 92 | self.input_proj = nn.Conv1d( 93 | config.hidden_dim, config.codebook_dim, 1, bias=False 94 | ) 95 | self.output_proj = nn.Conv1d( 96 | config.codebook_dim, config.hidden_dim, 1, bias=False 97 | ) 98 | 99 | def encode( 100 | self, embeddings: mx.array, num_quantizers: Optional[int] = None 101 | ) -> mx.array: 102 | embeddings = self.input_proj(embeddings) 103 | num_quantizers = ( 104 | num_quantizers if num_quantizers is not None else self.num_quantizers 105 | ) 106 | 107 | residual = embeddings 108 | all_indices = [] 109 | for layer in self.layers[:num_quantizers]: 110 | indices = layer.encode(residual) 111 | quantized = layer.decode(indices) 112 | residual = residual - mx.swapaxes(quantized, 1, 2) 113 | all_indices.append(indices) 114 | 115 | out_indices = mx.stack(all_indices) 116 | return out_indices 117 | 118 | def decode(self, codes: mx.array) -> mx.array: 119 | quantized_out = mx.array(0.0) 120 | codes = mx.swapaxes(codes, 0, 1) 121 | for i, indices in enumerate(codes): 122 | layer = self.layers[i] 123 | quantized = layer.decode(indices) 124 | quantized_out = quantized_out + quantized 125 | 126 | # (bsz, dim, seqlen) to dim first 127 | quantized_out = mx.swapaxes(quantized_out, 1, 2) 128 | if self.output_proj is not None: 129 | quantized_out = self.output_proj(quantized_out) 130 | 131 | return quantized_out 132 | 133 | 134 | class MimiSplitResidualVectorQuantizer(nn.Module): 135 | def __init__(self, config: RVQConfig): 136 | super().__init__() 137 | self.codebook_size = 
config.codebook_size 138 | self.frame_rate = config.frame_rate 139 | self.max_num_quantizers = config.num_quantizers 140 | 141 | self.num_semantic_quantizers = config.num_semantic_quantizers 142 | self.num_acoustic_quantizers = ( 143 | config.num_quantizers - config.num_semantic_quantizers 144 | ) 145 | 146 | self.semantic_residual_vector_quantizer = MimiResidualVectorQuantizer( 147 | config, self.num_semantic_quantizers 148 | ) 149 | self.acoustic_residual_vector_quantizer = MimiResidualVectorQuantizer( 150 | config, self.num_acoustic_quantizers 151 | ) 152 | 153 | def encode( 154 | self, embeddings: mx.array, num_quantizers: Optional[int] = None 155 | ) -> mx.array: 156 | num_quantizers = ( 157 | self.max_num_quantizers if num_quantizers is None else num_quantizers 158 | ) 159 | 160 | if num_quantizers > self.max_num_quantizers: 161 | raise ValueError( 162 | f"The number of quantizers (i.e codebooks) asked should be lower than the total number of quantizers {self.max_num_quantizers}, but is currently {num_quantizers}." 163 | ) 164 | 165 | if num_quantizers < self.num_semantic_quantizers: 166 | raise ValueError( 167 | f"The number of quantizers (i.e codebooks) asked should be higher than the number of semantic quantizers {self.num_semantic_quantizers}, but is currently {num_quantizers}." 168 | ) 169 | 170 | codes = self.semantic_residual_vector_quantizer.encode(embeddings) 171 | if num_quantizers > self.num_semantic_quantizers: 172 | acoustic_codes = self.acoustic_residual_vector_quantizer.encode( 173 | embeddings, num_quantizers=num_quantizers - self.num_semantic_quantizers 174 | ) 175 | codes = mx.concat([codes, acoustic_codes], axis=0) 176 | 177 | return codes 178 | 179 | def decode(self, codes: mx.array): 180 | quantized_out = self.semantic_residual_vector_quantizer.decode( 181 | codes[:, : self.num_semantic_quantizers] 182 | ) 183 | quantized_out += self.acoustic_residual_vector_quantizer.decode( 184 | codes[:, self.num_semantic_quantizers :] 185 | ) 186 | return quantized_out 187 | -------------------------------------------------------------------------------- /mlx_inference/src/smoltts_mlx/codec/seanet.py: -------------------------------------------------------------------------------- 1 | import mlx.core as mx 2 | import mlx.nn as nn 3 | from typing import List, Optional 4 | 5 | from smoltts_mlx.codec.conv import SeanetConfig, MimiConv1d, MimiConvTranspose1d 6 | 7 | 8 | class MimiResnetBlock(nn.Module): 9 | def __init__(self, config: SeanetConfig, dim: int, dilations: List[int]): 10 | super().__init__() 11 | assert len(dilations) == 2, ( 12 | "Number of kernel sizes should match number of dilations" 13 | ) 14 | kernel_sizes = (config.residual_kernel_size, 1) 15 | hidden = dim // config.compress 16 | block = [] 17 | for i, (kernel_size, dilation) in enumerate(zip(kernel_sizes, dilations)): 18 | in_chs = dim if i == 0 else hidden 19 | out_chs = dim if i == len(kernel_sizes) - 1 else hidden 20 | block += [nn.ELU()] 21 | block += [ 22 | MimiConv1d(config, in_chs, out_chs, kernel_size, dilation=dilation) 23 | ] 24 | 25 | self.block = block 26 | 27 | def __call__(self, x: mx.array) -> mx.array: 28 | residual = x 29 | for layer in self.block: 30 | x = layer(x) 31 | return residual + x 32 | 33 | def step(self, x: mx.array) -> Optional[mx.array]: 34 | residual = x 35 | for layer in self.block: 36 | if callable(getattr(layer, "step", None)): 37 | step = layer.step(x) 38 | if step is not None: 39 | x = step 40 | else: 41 | return None 42 | else: 43 | x = layer(x) 44 | return residual + x 45 
| 46 | def reset_state(self): 47 | for layer in self.block: 48 | if callable(getattr(layer, "reset_state", None)): 49 | layer.reset_state() 50 | 51 | 52 | class MimiEncoder(nn.Module): 53 | """ 54 | SEANet encoder as used by Mimi. 55 | """ 56 | 57 | def __init__(self, config: SeanetConfig): 58 | super().__init__() 59 | model = [ 60 | MimiConv1d(config, config.channels, config.n_filters, config.kernel_size) 61 | ] 62 | scaling = 1 63 | 64 | for ratio in reversed(config.ratios): 65 | current_scale = scaling * config.n_filters 66 | for j in range(config.n_residual_layers): 67 | model += [ 68 | MimiResnetBlock(config, current_scale, [config.dilation_base**j, 1]) 69 | ] 70 | model += [nn.ELU()] 71 | model += [ 72 | MimiConv1d( 73 | config, 74 | current_scale, 75 | current_scale * 2, 76 | kernel_size=ratio * 2, 77 | stride=ratio, 78 | ) 79 | ] 80 | scaling *= 2 81 | 82 | model += [nn.ELU()] 83 | model += [ 84 | MimiConv1d( 85 | config, 86 | scaling * config.n_filters, 87 | config.dimension, 88 | config.last_kernel_size, 89 | ) 90 | ] 91 | self.layers = model 92 | 93 | def __call__(self, x: mx.array) -> mx.array: 94 | for layer in self.layers: 95 | x = layer(x) 96 | return x 97 | 98 | 99 | class MimiDecoder(nn.Module): 100 | """ 101 | SEANet decoder as used by Mimi 102 | """ 103 | 104 | def __init__(self, config: SeanetConfig): 105 | super().__init__() 106 | scaling = int(2 ** len(config.ratios)) 107 | model = [ 108 | MimiConv1d( 109 | config, config.dimension, scaling * config.n_filters, config.kernel_size 110 | ) 111 | ] 112 | 113 | for ratio in config.ratios: 114 | current_scale = scaling * config.n_filters 115 | model += [nn.ELU()] 116 | model += [ 117 | MimiConvTranspose1d( 118 | config, 119 | current_scale, 120 | current_scale // 2, 121 | kernel_size=ratio * 2, 122 | stride=ratio, 123 | ) 124 | ] 125 | for j in range(config.n_residual_layers): 126 | model += [ 127 | MimiResnetBlock( 128 | config, current_scale // 2, [config.dilation_base**j, 1] 129 | ) 130 | ] 131 | scaling //= 2 132 | 133 | model += [nn.ELU()] 134 | model += [ 135 | MimiConv1d( 136 | config, config.n_filters, config.channels, config.last_kernel_size 137 | ) 138 | ] 139 | self.layers = model 140 | 141 | def __call__(self, x: mx.array) -> mx.array: 142 | for i, layer in enumerate(self.layers): 143 | x = layer(x) 144 | return x 145 | 146 | def step(self, x: mx.array) -> Optional[mx.array]: 147 | for i, layer in enumerate(self.layers): 148 | if callable(getattr(layer, "step", None)): 149 | step = layer.step(x) 150 | if step is not None: 151 | x = step 152 | else: 153 | return None 154 | else: 155 | x = layer(x) 156 | return x 157 | 158 | def reset(self): 159 | for layer in self.layers: 160 | if callable(getattr(layer, "reset_state", "None")): 161 | layer.reset_state() 162 | -------------------------------------------------------------------------------- /mlx_inference/src/smoltts_mlx/codec/transformer.py: -------------------------------------------------------------------------------- 1 | import mlx.core as mx 2 | import mlx.nn as nn 3 | from pydantic import BaseModel, Field 4 | from typing import Any, List, Optional 5 | 6 | 7 | from smoltts_mlx.lm.rq_transformer import create_attention_mask 8 | 9 | 10 | class MimiTransformerConfig(BaseModel): 11 | # NOTE: Leaving out norm, RoPE call, ffn, attn bias because this is pointless, it's standard LLaMA 12 | d_model: int = Field(default=512) 13 | num_heads: int = Field(default=8) 14 | """ 15 | Mimi v1 is NOT GQA so I'm deliberately setting this one only 16 | """ 17 | head_dim: int = 
Field(default=64) 18 | num_layers: int = Field(default=8) 19 | causal: bool = Field(default=True) 20 | norm_first: bool = Field(default=True) 21 | layer_scale: Optional[float] = Field(default=0.01) 22 | context: int = Field(default=250) 23 | conv_kernel_size: int = Field(default=5) 24 | use_conv_bias: bool = Field(default=True) 25 | use_conv_block: bool = Field(default=False) 26 | max_period: int = Field(default=10_000) 27 | 28 | dim_feedforward: int = Field(default=2048) 29 | kv_repeat: int = Field(default=1) 30 | max_seq_len: int = Field(default=8192) 31 | rope_theta: float = Field(default=10_000) 32 | 33 | 34 | class MimiMLP(nn.Module): 35 | def __init__(self, config: MimiTransformerConfig): 36 | super().__init__() 37 | self.fc1 = nn.Linear(config.d_model, config.dim_feedforward, bias=False) 38 | self.fc2 = nn.Linear(config.dim_feedforward, config.d_model, bias=False) 39 | 40 | def __call__(self, x: mx.array) -> mx.array: 41 | x = self.fc1(x) 42 | x = nn.gelu(x) 43 | x = self.fc2(x) 44 | return x 45 | 46 | 47 | class MimiAttention(nn.Module): 48 | def __init__(self, config: MimiTransformerConfig): 49 | self.q_proj = nn.Linear( 50 | config.d_model, config.num_heads * config.head_dim, bias=False 51 | ) 52 | self.k_proj = nn.Linear( 53 | config.d_model, config.num_heads * config.head_dim, bias=False 54 | ) 55 | self.v_proj = nn.Linear( 56 | config.d_model, config.num_heads * config.head_dim, bias=False 57 | ) 58 | self.o_proj = nn.Linear( 59 | config.num_heads * config.head_dim, config.d_model, bias=False 60 | ) 61 | self.rope = nn.RoPE( 62 | int(config.d_model / config.num_heads), 63 | traditional=False, 64 | base=config.rope_theta, 65 | ) 66 | 67 | self.scaling = config.head_dim**-0.5 68 | self.n_head = config.num_heads 69 | self.head_dim = config.head_dim 70 | 71 | def __call__( 72 | self, x: mx.array, mask: Optional[mx.array] = None, cache: Optional[Any] = None 73 | ) -> mx.array: 74 | bsz, seqlen, _ = x.shape 75 | q = self.q_proj(x) 76 | k = self.k_proj(x) 77 | v = self.v_proj(x) 78 | 79 | q = q.reshape((bsz, seqlen, self.n_head, self.head_dim)) 80 | k = k.reshape((bsz, seqlen, self.n_head, self.head_dim)) 81 | v = v.reshape((bsz, seqlen, self.n_head, self.head_dim)) 82 | 83 | q, k, v = map(lambda x: x.transpose(0, 2, 1, 3), (q, k, v)) 84 | if cache is not None: 85 | q = self.rope(q, offset=cache.offset) 86 | k = self.rope(k, offset=cache.offset) 87 | k, v = cache.update_and_fetch(k, v) 88 | else: 89 | q = self.rope(q) 90 | k = self.rope(k) 91 | 92 | output = mx.fast.scaled_dot_product_attention( 93 | q=q, k=k, v=v, scale=self.scaling, mask=mask 94 | ) 95 | output = output.transpose(0, 2, 1, 3).reshape(bsz, seqlen, -1) 96 | return self.o_proj(output) 97 | 98 | 99 | class MimiLayerScale(nn.Module): 100 | def __init__(self, config: MimiTransformerConfig): 101 | super().__init__() 102 | self.scale = mx.full((config.d_model,), config.layer_scale) 103 | 104 | def __call__(self, x: mx.array) -> mx.array: 105 | # TODO check this 106 | return x * self.scale 107 | 108 | 109 | class MimiTransformerLayer(nn.Module): 110 | def __init__(self, config: MimiTransformerConfig): 111 | self.mlp = MimiMLP(config) 112 | # Leaving out eps since the MLX default is identical and the weights won't change 113 | self.input_layernorm = nn.LayerNorm(config.d_model) 114 | self.post_attention_layernorm = nn.LayerNorm(config.d_model) 115 | self.self_attn = MimiAttention(config) 116 | 117 | self.self_attn_layer_scale = MimiLayerScale(config) 118 | self.mlp_layer_scale = MimiLayerScale(config) 119 | 120 | def 
__call__( 121 | self, x: mx.array, mask: Optional[mx.array] = None, cache: Optional[Any] = None 122 | ) -> mx.array: 123 | residual = x 124 | hidden_states = self.self_attn(self.input_layernorm(x), mask=mask, cache=cache) 125 | h = residual + self.self_attn_layer_scale(hidden_states) 126 | 127 | residual = h 128 | hidden_states = self.post_attention_layernorm(h) 129 | hidden_states = self.mlp(hidden_states) 130 | h = residual + self.mlp_layer_scale(hidden_states) 131 | return h 132 | 133 | 134 | class MimiTransformer(nn.Module): 135 | def __init__(self, config: MimiTransformerConfig): 136 | super().__init__() 137 | self.layers = [MimiTransformerLayer(config) for _ in range(config.num_layers)] 138 | self.config = config 139 | 140 | def __call__( 141 | self, 142 | x: mx.array, 143 | cache: Optional[List[Any]] = None, 144 | ): 145 | mask = create_attention_mask(x, cache) if x.shape[1] > 1 else None 146 | 147 | for layer, layer_cache in zip(self.layers, cache or [None] * len(self.layers)): 148 | x = layer(x, mask=mask, cache=layer_cache) 149 | 150 | return x 151 | -------------------------------------------------------------------------------- /mlx_inference/src/smoltts_mlx/io/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EndlessReform/smoltts/410d6524966e2d433ce565e4f14f1fac2f983ff9/mlx_inference/src/smoltts_mlx/io/__init__.py -------------------------------------------------------------------------------- /mlx_inference/src/smoltts_mlx/io/wav.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def pcm_to_wav_bytes(pcm_data: np.ndarray, sample_rate: int = 24000) -> bytes: 5 | """Convert raw PCM data to WAV bytes. 
6 | 7 | Args: 8 | pcm_data: Floating point PCM data in range [-1, 1], will be flattened 9 | sample_rate: Sample rate in Hz (default: 24000) 10 | 11 | Returns: 12 | Complete WAV file as bytes 13 | """ 14 | # Flatten to 1D array first 15 | pcm_data = pcm_data.flatten() 16 | 17 | header = bytearray() 18 | # RIFF chunk 19 | header.extend(b"RIFF") 20 | header.extend((len(pcm_data) * 2 + 36).to_bytes(4, "little")) # Total size minus 8 21 | header.extend(b"WAVE") 22 | # fmt chunk 23 | header.extend(b"fmt ") 24 | header.extend((16).to_bytes(4, "little")) # fmt chunk size 25 | header.extend((1).to_bytes(2, "little")) # PCM format 26 | header.extend((1).to_bytes(2, "little")) # Mono 27 | header.extend(sample_rate.to_bytes(4, "little")) 28 | header.extend((sample_rate * 2).to_bytes(4, "little")) # Bytes per second 29 | header.extend((2).to_bytes(2, "little")) # Block align 30 | header.extend((16).to_bytes(2, "little")) # Bits per sample 31 | # data chunk 32 | header.extend(b"data") 33 | header.extend((len(pcm_data) * 2).to_bytes(4, "little")) # Data size 34 | 35 | # Convert PCM to 16-bit samples 36 | wav_data = (pcm_data * 32767).astype(np.int16).tobytes() 37 | return bytes(header) + wav_data 38 | -------------------------------------------------------------------------------- /mlx_inference/src/smoltts_mlx/lm/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EndlessReform/smoltts/410d6524966e2d433ce565e4f14f1fac2f983ff9/mlx_inference/src/smoltts_mlx/lm/__init__.py -------------------------------------------------------------------------------- /mlx_inference/src/smoltts_mlx/lm/cache.py: -------------------------------------------------------------------------------- 1 | import mlx.core as mx 2 | import mlx.nn as nn 3 | from typing import Any, List 4 | 5 | 6 | class KVCache: 7 | def __init__(self): 8 | self.keys = None 9 | self.values = None 10 | self.offset = 0 # Keep tracking offset for interface compatibility 11 | 12 | def update_and_fetch(self, keys: mx.array, values: mx.array): 13 | """Update cache with new keys/values and return full concatenated cache.""" 14 | if self.keys is None: 15 | self.keys = keys 16 | self.values = values 17 | else: 18 | self.keys = mx.concatenate([self.keys, keys], axis=2) 19 | self.values = mx.concatenate([self.values, values], axis=2) 20 | 21 | self.offset += keys.shape[2] 22 | return self.keys, self.values 23 | 24 | 25 | def make_prompt_cache(model: nn.Module, is_fast: bool = False) -> List[Any]: 26 | """Construct the model's cache for use during generation.""" 27 | if hasattr(model, "make_cache"): 28 | return model.make_cache() 29 | 30 | if is_fast: 31 | return [KVCache() for _ in range(len(model.fast_layers))] 32 | else: 33 | return [KVCache() for _ in range(len(model.layers))] 34 | -------------------------------------------------------------------------------- /mlx_inference/src/smoltts_mlx/lm/config.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel 2 | from typing import Literal, Optional 3 | 4 | 5 | class ModelType(BaseModel): 6 | family: Literal["fish", "dual_ar"] 7 | version: Optional[Literal["1.5", "1.4", "1.2"]] 8 | codec: Literal["mimi", "1.4", "1.2"] 9 | 10 | @classmethod 11 | def smoltts_v0(cls): 12 | return cls(family="dual_ar", version=None, codec="mimi") 13 | -------------------------------------------------------------------------------- /mlx_inference/src/smoltts_mlx/lm/generate.py: 
-------------------------------------------------------------------------------- 1 | import mlx.core as mx 2 | from pydantic import BaseModel, Field 3 | import time 4 | from tqdm import tqdm 5 | from typing import Any, Optional, List 6 | 7 | from smoltts_mlx.lm.rq_transformer import RQTransformer 8 | from smoltts_mlx.lm.cache import make_prompt_cache, KVCache 9 | from smoltts_mlx.lm.utils.samplers import min_p_sampling 10 | 11 | 12 | class GenerationSettings(BaseModel): 13 | default_temp: float = Field(default=0.7) 14 | default_fast_temp: Optional[float] = Field(default=0.7) 15 | min_p: Optional[float] = Field(default=None) 16 | max_new_tokens: int = Field(default=1024) 17 | 18 | 19 | class VQToken(BaseModel): 20 | semantic_code: int 21 | audio_codes: Optional[Any] 22 | vq_tensor: Any 23 | 24 | 25 | class SingleBatchGenerator: 26 | model: RQTransformer 27 | input_pos: int 28 | generation_settings: GenerationSettings 29 | prompt: Optional[mx.array] 30 | previous_codes: Optional[List[int]] 31 | audio_only: bool 32 | cache: List[KVCache] 33 | 34 | def __init__( 35 | self, 36 | model: RQTransformer, 37 | prompt: mx.array, 38 | generation_settings: GenerationSettings, 39 | audio_only: bool = True, 40 | ): 41 | self.model = model 42 | # TODO handle KV cache 43 | self.input_pos = 0 44 | # TODO handle this 45 | self.max_new_tokens = ( 46 | generation_settings.max_new_tokens 47 | if generation_settings.max_new_tokens is not None 48 | else model.config.max_seq_len 49 | ) 50 | self.generation_settings = generation_settings 51 | self.audio_only = audio_only 52 | self.prompt = prompt 53 | self.previous_codes = None 54 | self.slow_cache = make_prompt_cache(model) 55 | 56 | def __iter__(self): 57 | return self 58 | 59 | def __next__(self): 60 | if self.input_pos > self.max_new_tokens: 61 | raise StopIteration 62 | elif self.prompt is None: 63 | # Previous iteration told us to stop 64 | raise StopIteration 65 | 66 | x = self.prompt 67 | prompt_length = x.shape[-1] 68 | 69 | x = x if x.ndim == 3 else x[mx.newaxis, :, :] 70 | 71 | logits, hidden_states = self.model.forward_generate( 72 | self.prompt, self.slow_cache 73 | ) 74 | mx.eval(logits, hidden_states) 75 | logits = logits if logits.ndim == 3 else logits[mx.newaxis, :, :] 76 | # slow_logits = ( 77 | # constrain_logits_to_audio( 78 | # logits, self.model.model_type, self.model.token_config 79 | # ) 80 | # if ( 81 | # is_modern := self.audio_only 82 | # and self.model.model_type.family == "dual_ar" 83 | # or self.model.model_type.version == "1.5" 84 | # ) 85 | # else logits 86 | # ) 87 | slow_logits = logits 88 | if self.generation_settings.default_temp == 0.0: 89 | token_ids = mx.argmax(slow_logits, axis=-1) 90 | elif self.generation_settings.min_p is not None: 91 | token_ids = min_p_sampling( 92 | slow_logits, 93 | min_p=self.generation_settings.min_p, 94 | temperature=self.generation_settings.default_temp, 95 | ) 96 | else: 97 | # TODO improve sampling, I just want SOME output 98 | slow_logits = slow_logits / self.generation_settings.default_temp 99 | token_ids = mx.random.categorical(slow_logits) 100 | 101 | # slow_token_id = ( 102 | # rescale_semantic_tokens( 103 | # token_ids, self.model.model_type, self.model.token_config 104 | # )[0] 105 | # if is_modern 106 | # else token_ids[0] 107 | # ) 108 | slow_token_id = token_ids.flatten()[0] 109 | 110 | codes = [] 111 | x = hidden_states[mx.newaxis, :, :] 112 | fast_cache = make_prompt_cache(self.model, is_fast=True) 113 | for i in range(0, self.model.max_fast_seqlen): 114 | fast_logits = 
self.model.forward_generate_fast(x, i, cache=fast_cache) 115 | mx.eval(fast_logits) 116 | 117 | # TODO handle sampling, esp. if it sounds terrible 118 | if ( 119 | fast_temp := self.generation_settings.default_fast_temp 120 | ) is not None and fast_temp > 0: 121 | if self.generation_settings.min_p is not None: 122 | next_token_tensor = min_p_sampling( 123 | fast_logits.squeeze(0), 124 | min_p=self.generation_settings.min_p, 125 | temperature=self.generation_settings.default_fast_temp, 126 | ) 127 | next_token_tensor = next_token_tensor[mx.newaxis, :] 128 | else: 129 | fast_logits = fast_logits / fast_temp 130 | next_token_tensor = mx.random.categorical(fast_logits) 131 | else: 132 | next_token_tensor = mx.argmax(fast_logits, axis=-1) 133 | 134 | # model GETS higher 135 | code = next_token_tensor.flatten()[0] 136 | if self.model.config.depthwise_wte: 137 | offset = i if self.model.config.duplicate_code_0 else i + 1 138 | next_token_tensor += max(0, offset * self.model.config.codebook_size) 139 | 140 | x = self.model.fast_embeddings(next_token_tensor) 141 | codes.append(code) 142 | 143 | codes_tensor = mx.array([slow_token_id, *codes], dtype=mx.uint32)[ 144 | mx.newaxis, :, mx.newaxis 145 | ] 146 | if ( 147 | slow_token_id >= self.model.token_config.semantic_start_id 148 | and self.model.token_config.semantic_end_id is not None 149 | and slow_token_id <= self.model.token_config.semantic_end_id 150 | ): 151 | audio_code = slow_token_id - self.model.token_config.semantic_start_id 152 | codes_arr = ( 153 | codes if self.model.config.duplicate_code_0 else [audio_code, *codes] 154 | ) 155 | audio_tensor = mx.array(codes_arr, dtype=mx.uint32)[ 156 | mx.newaxis, :, mx.newaxis 157 | ] 158 | else: 159 | audio_tensor = None 160 | 161 | self.input_pos += prompt_length if self.input_pos is None else 1 162 | self.prompt = ( 163 | None 164 | if self.audio_only and slow_token_id == self.model.token_config.im_end_id 165 | else codes_tensor 166 | ) 167 | return VQToken( 168 | semantic_code=slow_token_id.tolist(), 169 | audio_codes=audio_tensor, 170 | vq_tensor=codes_tensor, 171 | ) 172 | 173 | 174 | def generate_blocking( 175 | model: RQTransformer, 176 | prompt: mx.array, 177 | generation_settings: GenerationSettings, 178 | audio_only: bool = True, 179 | ) -> mx.array: 180 | prompt_size = prompt.shape[-1] 181 | token_generator = SingleBatchGenerator( 182 | model, 183 | prompt, 184 | generation_settings, 185 | audio_only, 186 | ) 187 | prefill_start_time = time.time() 188 | first_vq_token = next(token_generator) 189 | prefill_end_time = time.time() 190 | prefill_ms = (prefill_end_time - prefill_start_time) * 1000 191 | print( 192 | f"{prefill_ms:3f}ms prompt processing: {prompt_size} tokens ({prompt_size / (prefill_end_time - prefill_start_time):3f} tokens/s)" 193 | ) 194 | 195 | previous_vq_codes = ( 196 | [first_vq_token.audio_codes] if audio_only else [first_vq_token.vq_tensor] 197 | ) 198 | 199 | decode_start_time = time.time() 200 | for maybe_vq_token in tqdm(token_generator): 201 | if audio_only: 202 | if maybe_vq_token.audio_codes is not None: 203 | previous_vq_codes.append(maybe_vq_token.audio_codes) 204 | else: 205 | previous_vq_codes.append(maybe_vq_token.vq_tensor) 206 | decode_end_time = time.time() 207 | decode_duration = decode_end_time - decode_start_time 208 | 209 | out_tokens = mx.concat(previous_vq_codes, axis=-1) 210 | out_len = len(previous_vq_codes) - 1 211 | frame_rate = 12.5 if model.model_type.family == "dual_ar" else 21.535 212 | print( 213 | f"Generated in {decode_duration:.2f}s 
({(out_len / decode_duration):.2f} tokens/s, {((decode_duration * 1000) / out_len):.2f}ms/token), {(out_len / frame_rate) / decode_duration:.2f}x realtime" 214 | ) 215 | mx.eval(out_tokens) 216 | return out_tokens 217 | -------------------------------------------------------------------------------- /mlx_inference/src/smoltts_mlx/lm/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EndlessReform/smoltts/410d6524966e2d433ce565e4f14f1fac2f983ff9/mlx_inference/src/smoltts_mlx/lm/utils/__init__.py -------------------------------------------------------------------------------- /mlx_inference/src/smoltts_mlx/lm/utils/constraints.py: -------------------------------------------------------------------------------- 1 | import mlx.core as mx 2 | from typing import List 3 | 4 | from smoltts_mlx.lm.config import ModelType 5 | from smoltts_mlx.lm.rq_transformer import TokenConfig 6 | 7 | 8 | def constrain_logits_to_audio( 9 | x: mx.array, model_type: ModelType, token_config: TokenConfig 10 | ) -> mx.array: 11 | if model_type.family == "dual_ar" or model_type.version == "1.5": 12 | # Base layer uses <|semantic:n|> range up top 13 | if token_config.im_end_id == token_config.semantic_start_id - 1: 14 | # Saves us an indexop 15 | return x[:, :, token_config.im_end_id :] 16 | else: 17 | im_end_prob = x[:, :, token_config.im_end_id] 18 | semantic_token_range = x[:, :, token_config.semantic_start_id :] 19 | # TODO this is probably wrong 20 | return mx.concat([im_end_prob, semantic_token_range], -1) 21 | else: 22 | return x 23 | 24 | 25 | def rescale_semantic_tokens( 26 | tokens: List[int], model_type: ModelType, token_config: TokenConfig 27 | ): 28 | token_range_is_contiguous = ( 29 | token_config.im_end_id == token_config.semantic_start_id - 1 30 | ) 31 | 32 | def rescale_token(token: int) -> int: 33 | if token_range_is_contiguous: 34 | return token + token_config.im_end_id 35 | elif token == 0: 36 | return token_config.im_end_id 37 | else: 38 | return token - 1 + token_config.semantic_start_id 39 | 40 | if model_type.family == "dual_ar" or model_type.version == "1.5": 41 | return [rescale_token(t) for t in tokens] 42 | else: 43 | return tokens 44 | -------------------------------------------------------------------------------- /mlx_inference/src/smoltts_mlx/lm/utils/prompt.py: -------------------------------------------------------------------------------- 1 | import mlx.core as mx 2 | import tokenizers 3 | from tokenizers import Tokenizer 4 | from typing import Optional 5 | 6 | from smoltts_mlx.lm.config import ModelType 7 | from smoltts_mlx.lm.rq_transformer import RQTransformer 8 | 9 | 10 | class PromptEncoder: 11 | tokenizer: Tokenizer 12 | depth: int 13 | model_type: ModelType 14 | 15 | def __init__( 16 | self, 17 | tokenizer: Tokenizer, 18 | model_type: ModelType, 19 | semantic_offset: int, 20 | num_codebooks: int = 8, 21 | duplicate_code_0: bool = True, 22 | ): 23 | self.tokenizer = tokenizer 24 | self.model_type = model_type 25 | self.depth = num_codebooks if duplicate_code_0 else num_codebooks - 1 26 | self.semantic_offset = semantic_offset 27 | 28 | @classmethod 29 | def from_model(cls, tokenizer: Tokenizer, model: RQTransformer): 30 | return cls( 31 | tokenizer, 32 | num_codebooks=model.config.num_codebooks, 33 | model_type=model.model_type, 34 | semantic_offset=model.token_config.semantic_start_id, 35 | duplicate_code_0=dc0 36 | if (dc0 := model.config.duplicate_code_0) is not None 37 | else True, 38 | ) 39 | 
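    # A rough sketch of the layout the methods below produce (assuming
    # duplicate_code_0=True, so depth == num_codebooks): each turn is encoded as a
    # (1 + depth, seq_len) grid whose first row holds the BPE text tokens and whose
    # remaining rows hold audio codebook values. For text-only turns the codebook
    # rows are zeros; for audio turns, encode_vq shifts codebook 0 into the text
    # vocabulary via semantic_offset and appends the tokens for "<|im_end|>\n" as a
    # final column block.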
40 | def tokenize_text(self, text: str) -> mx.array: 41 | turn_codes: tokenizers.Encoding = self.tokenizer.encode( 42 | text, add_special_tokens=True 43 | ) 44 | tokens = mx.array(turn_codes.ids, dtype=mx.uint32)[mx.newaxis, :] 45 | zeros = mx.zeros([self.depth, tokens.shape[-1]], dtype=mx.uint32) 46 | return mx.concat([tokens, zeros], axis=0) 47 | 48 | def encode_text_turn(self, role: str, content: Optional[str] = None) -> mx.array: 49 | content_suffix = f"{content}<|im_end|>" if content is not None else "" 50 | turn_string = f"<|im_start|>{role}\n{content_suffix}" 51 | return self.tokenize_text(turn_string) 52 | 53 | def encode_vq(self, codes: mx.array) -> mx.array: 54 | if codes.ndim != 2: 55 | raise ValueError("Must be single batch") 56 | 57 | semantic_line = (codes[0, :] + self.semantic_offset)[mx.newaxis, :] 58 | lower_start = codes.shape[0] - self.depth 59 | lower_codes = codes[lower_start:, :] 60 | vq_block = mx.concat([semantic_line, lower_codes]) 61 | im_end = self.tokenize_text("<|im_end|>\n") 62 | block = mx.concat([vq_block, im_end], axis=1) 63 | return block 64 | -------------------------------------------------------------------------------- /mlx_inference/src/smoltts_mlx/lm/utils/samplers.py: -------------------------------------------------------------------------------- 1 | import math 2 | import mlx.core as mx 3 | from functools import partial 4 | 5 | 6 | # From MLX examples 7 | @partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state) 8 | def min_p_sampling( 9 | logprobs: mx.array, 10 | min_p: float, 11 | min_tokens_to_keep: int = 1, 12 | temperature=1.0, 13 | ) -> mx.array: 14 | if not (0 <= min_p <= 1.0): 15 | raise ValueError( 16 | f"`min_p` has to be a float in the [0, 1] interval, but is {min_p}" 17 | ) 18 | 19 | logprobs = logprobs * (1 / temperature) 20 | # Sort indices in decreasing order 21 | sorted_indices = mx.argsort(-logprobs).squeeze(0) 22 | sorted_logprobs = logprobs[..., sorted_indices] 23 | # Get top probability 24 | top_logprobs = logprobs[..., sorted_indices] 25 | # Calculate min-p threshold 26 | scaled_min_p = top_logprobs + math.log(min_p) 27 | # Mask tokens below threshold 28 | tokens_to_remove = sorted_logprobs < scaled_min_p 29 | tokens_to_remove[..., :min_tokens_to_keep] = False 30 | # Create filtered token pool 31 | selected_logprobs = mx.where(tokens_to_remove, -float("inf"), sorted_logprobs) 32 | # Sample and return token 33 | sorted_token = mx.random.categorical(selected_logprobs) 34 | return sorted_indices[sorted_token] 35 | -------------------------------------------------------------------------------- /mlx_inference/src/smoltts_mlx/scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EndlessReform/smoltts/410d6524966e2d433ce565e4f14f1fac2f983ff9/mlx_inference/src/smoltts_mlx/scripts/__init__.py -------------------------------------------------------------------------------- /mlx_inference/src/smoltts_mlx/scripts/server.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from contextlib import asynccontextmanager 3 | from fastapi import FastAPI 4 | from fastapi.responses import FileResponse 5 | import uvicorn 6 | import time 7 | 8 | from smoltts_mlx import SmolTTS 9 | from smoltts_mlx.server.tts_core import TTSCore 10 | from smoltts_mlx.server.routes import openai, elevenlabs 11 | from smoltts_mlx.server.settings import ServerSettings 12 | 13 | 14 | parser = argparse.ArgumentParser() 15 | 
parser.add_argument("--config", type=str, help="Path to config file") 16 | 17 | 18 | @asynccontextmanager 19 | async def lifespan(app: FastAPI): 20 | settings = app.state.settings 21 | checkpoint_dir = settings.get_checkpoint_dir() 22 | 23 | load_start_time = time.time() 24 | print("Loading model...") 25 | model = SmolTTS(checkpoint_dir=checkpoint_dir) 26 | app.state.tts_core = TTSCore( 27 | model=model, 28 | settings=settings, 29 | ) 30 | 31 | load_end_time = time.time() 32 | print(f"Loaded model and config in {load_end_time - load_start_time:.3f} seconds") 33 | 34 | yield 35 | print("shutting down") 36 | 37 | 38 | app = FastAPI(lifespan=lifespan) 39 | app.include_router(openai.router) 40 | app.include_router(elevenlabs.router) 41 | 42 | 43 | @app.get("/") 44 | async def root(): 45 | return FileResponse("static/index.html") 46 | 47 | 48 | def main(): 49 | parser = argparse.ArgumentParser() 50 | parser.add_argument("--config", type=str, help="Path to config file") 51 | parser.add_argument("--port", type=int, help="Port to run on on (default: 8000)") 52 | args = parser.parse_args() 53 | 54 | settings = ServerSettings.get_settings(args.config) 55 | app.state.settings = settings 56 | 57 | port = args.port if args.port is not None else 8000 58 | 59 | uvicorn.run(app, host="0.0.0.0", port=port) 60 | 61 | 62 | if __name__ == "__main__": 63 | main() 64 | -------------------------------------------------------------------------------- /mlx_inference/src/smoltts_mlx/server/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EndlessReform/smoltts/410d6524966e2d433ce565e4f14f1fac2f983ff9/mlx_inference/src/smoltts_mlx/server/__init__.py -------------------------------------------------------------------------------- /mlx_inference/src/smoltts_mlx/server/routes/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EndlessReform/smoltts/410d6524966e2d433ce565e4f14f1fac2f983ff9/mlx_inference/src/smoltts_mlx/server/routes/__init__.py -------------------------------------------------------------------------------- /mlx_inference/src/smoltts_mlx/server/routes/elevenlabs.py: -------------------------------------------------------------------------------- 1 | from fastapi import APIRouter, Query, Request, Response 2 | from fastapi.responses import StreamingResponse 3 | from pydantic import BaseModel, Field 4 | from typing import Optional, Literal 5 | 6 | router = APIRouter(prefix="/v1", tags=["ElevenLabs"]) 7 | 8 | 9 | class CreateSpeechRequest(BaseModel): 10 | text: str 11 | model_id: Optional[str] = Field(default=None) 12 | 13 | 14 | @router.post("/text-to-speech/{voice_id}") 15 | async def text_to_speech_blocking( 16 | voice_id: str, 17 | item: CreateSpeechRequest, 18 | http_request: Request, 19 | output_format: Optional[str] = Query( 20 | None, description="Desired output format. 
No MP3 support" 21 | ), 22 | ): 23 | core = http_request.app.state.tts_core 24 | content, media_type = core.generate_audio( 25 | input_text=item.text, 26 | voice=voice_id, 27 | response_format=output_format, 28 | ) 29 | 30 | return Response( 31 | content=content, 32 | media_type=media_type, # or audio/l16 for 16-bit PCM 33 | headers={ 34 | "Content-Disposition": f'attachment; filename="elevenlabs_speech.{output_format.split("_")[0]}"', 35 | "X-Sample-Rate": output_format.split("_")[1] 36 | if output_format is not None 37 | else "24000", 38 | }, 39 | ) 40 | 41 | 42 | @router.post("/text-to-speech/{voice_id}/stream") 43 | async def stream_tts( 44 | voice_id: str, 45 | item: CreateSpeechRequest, 46 | http_request: Request, 47 | output_format: Literal["pcm_24000"] = "pcm_24000", 48 | ): 49 | core = http_request.app.state.tts_core 50 | 51 | def generate(): 52 | for audio_chunk in core.stream_audio(item.text, voice=voice_id): 53 | yield audio_chunk 54 | 55 | return StreamingResponse( 56 | generate(), 57 | media_type="audio/mpeg" if output_format.startswith("mp3_") else "audio/wav", 58 | headers={ 59 | "Content-Disposition": f'attachment; filename="speech.{output_format.split("_")[0]}"', 60 | # "X-Sample-Rate": output_format.split("_")[1], 61 | "X-Sample-Rate": "24000", 62 | }, 63 | ) 64 | -------------------------------------------------------------------------------- /mlx_inference/src/smoltts_mlx/server/routes/openai.py: -------------------------------------------------------------------------------- 1 | from fastapi import APIRouter, Request, Response 2 | from pydantic import BaseModel, Field 3 | from typing import Literal, Union 4 | 5 | 6 | class SpeechRequest(BaseModel): 7 | model: str = Field(default="tts-1-hd") 8 | input: str 9 | voice: Union[str, int] = Field(default="alloy") 10 | response_format: Literal["wav"] = Field(default="wav") 11 | 12 | 13 | router = APIRouter(prefix="/v1", tags=["OpenAI"]) 14 | 15 | 16 | @router.post("/audio/speech") 17 | async def openai_speech(item: SpeechRequest, http_request: Request): 18 | core = http_request.app.state.tts_core 19 | audio_data, media_type = core.generate_audio( 20 | input_text=item.input, 21 | voice=item.voice, 22 | response_format=item.response_format + "_24000", 23 | ) 24 | return Response( 25 | audio_data, 26 | media_type=media_type, 27 | headers={"Content-Disposition": 'attachment; filename="speech.wav"'}, 28 | ) 29 | -------------------------------------------------------------------------------- /mlx_inference/src/smoltts_mlx/server/settings.py: -------------------------------------------------------------------------------- 1 | from huggingface_hub import snapshot_download 2 | import json 3 | import os 4 | from pathlib import Path 5 | from pydantic import BaseModel, model_validator 6 | from typing import Optional 7 | 8 | from smoltts_mlx.lm.config import ModelType 9 | from smoltts_mlx.lm.generate import GenerationSettings 10 | 11 | 12 | class ServerSettings(BaseModel): 13 | model_id: Optional[str] = None 14 | checkpoint_dir: Optional[str] = None 15 | generation: GenerationSettings 16 | model_type: ModelType 17 | 18 | @model_validator(mode="after") 19 | def validate_model_source(self): 20 | if self.model_id is not None and self.checkpoint_dir is not None: 21 | raise ValueError("Cannot specify both model_id and checkpoint_dir") 22 | if self.model_id is None and self.checkpoint_dir is None: 23 | raise ValueError("Must specify either model_id or checkpoint_dir") 24 | 25 | return self 26 | 27 | @classmethod 28 | def get_settings(cls, 
config_path: Optional[str] = None) -> "ServerSettings": 29 | """Get settings from config file or create default in cache dir.""" 30 | default_settings = { 31 | "model_id": "jkeisling/smoltts_v0", 32 | "model_type": {"family": "dual_ar", "codec": "mimi", "version": None}, 33 | "generation": { 34 | "default_temp": 0.5, 35 | "default_fast_temp": 0.0, 36 | "min_p": 0.10, 37 | "max_new_tokens": 1024, 38 | }, 39 | } 40 | 41 | if config_path: 42 | with open(config_path) as f: 43 | return cls(**json.loads(f.read())) 44 | # Use macOS cache dir 45 | config_dir = Path(os.path.expanduser("~/Library/Caches/smolltts/settings")) 46 | config_path = config_dir / "config.json" 47 | 48 | config_dir.mkdir(parents=True, exist_ok=True) 49 | if not config_path.exists(): 50 | with open(config_path, "w") as f: 51 | json.dump(default_settings, f, indent=2) 52 | return cls(**default_settings) 53 | 54 | with open(config_path) as f: 55 | return cls(**json.loads(f.read())) 56 | 57 | def get_checkpoint_dir(self) -> Path: 58 | if self.checkpoint_dir is not None: 59 | return Path(self.checkpoint_dir) 60 | else: 61 | # guaranteed to exist by validator above 62 | hf_repo_path = snapshot_download(self.model_id) 63 | return Path(hf_repo_path) 64 | -------------------------------------------------------------------------------- /mlx_inference/src/smoltts_mlx/server/tts_core.py: -------------------------------------------------------------------------------- 1 | import io 2 | import mlx.core as mx 3 | import numpy as np 4 | from pydub import AudioSegment 5 | from scipy import signal 6 | import soundfile as sf 7 | import time 8 | from typing import Union 9 | from tqdm import tqdm 10 | 11 | from smoltts_mlx import SmolTTS 12 | from smoltts_mlx.io.wav import pcm_to_wav_bytes 13 | 14 | 15 | class TTSCore: 16 | def __init__(self, model: SmolTTS, settings): 17 | self.model = model 18 | self.settings = settings 19 | 20 | def resolve_speaker_id(self, voice: Union[str, int]) -> int: 21 | # TODO: Fix speaker cache 22 | if isinstance(voice, int): 23 | return voice 24 | elif isinstance(voice, str) and voice.isnumeric(): 25 | return int(voice) 26 | return 0 27 | 28 | def generate_audio( 29 | self, input_text: str, voice: Union[str, int], response_format: str = "wav" 30 | ): 31 | pcm_data = self.model(input_text, str(voice)) 32 | 33 | start_time = time.time() 34 | audio_data, media_type = self.format_audio_chunk( 35 | pcm_data.flatten(), response_format 36 | ) 37 | end_time = time.time() 38 | print(f"Took {end_time - start_time:.2f}s to transcode") 39 | mx.metal.clear_cache() 40 | 41 | return audio_data, media_type 42 | 43 | def stream_audio(self, input_text: str, voice: Union[str, int]): 44 | for pcm_chunk in tqdm(self.model.stream(input_text, str(voice))): 45 | if pcm_chunk is not None: 46 | audio_data = pcm_chunk.tobytes() 47 | yield audio_data 48 | 49 | def format_audio_chunk( 50 | self, pcm_data: np.ndarray, output_format: str = "pcm_24000" 51 | ) -> tuple[bytes, str]: 52 | """Format a chunk of PCM data into the requested format. 
53 | Returns (formatted_bytes, media_type)""" 54 | sample_rate = int(output_format.split("_")[1]) 55 | pcm_data = pcm_data.flatten() 56 | 57 | # Resample if needed 58 | if sample_rate != 24000: 59 | num_samples = int(len(pcm_data) * sample_rate / 24000) 60 | pcm_data = signal.resample(pcm_data, num_samples) 61 | 62 | # Convert to 16-bit PCM first 63 | mem_buf = io.BytesIO() 64 | sf.write(mem_buf, pcm_data, sample_rate, format="raw", subtype="PCM_16") 65 | pcm_bytes = bytes(mem_buf.getbuffer()) 66 | 67 | if output_format.startswith("pcm_"): 68 | return pcm_bytes, "audio/x-pcm" 69 | elif output_format.startswith("wav_"): 70 | wav_bytes = pcm_to_wav_bytes(pcm_data=pcm_data, sample_rate=sample_rate) 71 | return wav_bytes, "audio/wav" 72 | elif output_format.startswith("mp3_"): 73 | bitrate = output_format.split("_")[-1] 74 | audio = AudioSegment( 75 | data=pcm_bytes, 76 | sample_width=2, 77 | frame_rate=sample_rate, 78 | channels=1, 79 | ) 80 | out_buf = io.BytesIO() 81 | audio.export(out_buf, format="mp3", bitrate=f"{bitrate}k") 82 | return out_buf.getvalue(), "audio/mpeg" 83 | else: 84 | raise NotImplementedError(f"Format {output_format} not yet supported") 85 | -------------------------------------------------------------------------------- /mlx_inference/static/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | PCM TTS Player 6 | 11 | 12 | 13 |

PCM TTS Player
[Static demo page; the HTML markup is not recoverable here. The page POSTs the request below to the streaming endpoint and plays back the returned PCM audio.]
POST http://localhost:8000/v1/text-to-speech/1/stream?output_format=pcm_24000
{ "text": "Further optimization is chasing a dead end. It's time to finally bite the bullet and understand flow matching." }
26 | 27 | 28 | 81 | 82 | -------------------------------------------------------------------------------- /mlx_inference/tests/compare_npy.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import mlx.core as mx 3 | import numpy as np 4 | 5 | parser = argparse.ArgumentParser( 6 | description="Compare two .npy files for numerical similarity" 7 | ) 8 | parser.add_argument("file1", type=str, help="Path to first .npy file") 9 | parser.add_argument("file2", type=str, help="Path to second .npy file") 10 | 11 | 12 | def main(): 13 | args = parser.parse_args() 14 | 15 | arr1_np = np.load(args.file1) 16 | arr2_np = np.load(args.file2) 17 | 18 | arr1 = mx.array(arr1_np) 19 | arr2 = mx.array(arr2_np) 20 | 21 | if arr1.shape != arr2.shape: 22 | print(f"Shape mismatch: {arr1.shape} vs {arr2.shape}") 23 | return 24 | 25 | # Check numerical similarity 26 | is_close = mx.allclose(arr1, arr2, rtol=1e-3, atol=1e-3) 27 | mx.eval(is_close) 28 | max_diff = mx.abs(arr1 - arr2).max() 29 | if is_close: 30 | print("Arrays match within tolerance") 31 | else: 32 | # If they don't match, might be helpful to see the max difference 33 | print("Arrays differ.") 34 | print(f"Max absolute difference: {max_diff}") 35 | 36 | 37 | if __name__ == "__main__": 38 | main() 39 | -------------------------------------------------------------------------------- /mlx_inference/tests/sky.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EndlessReform/smoltts/410d6524966e2d433ce565e4f14f1fac2f983ff9/mlx_inference/tests/sky.wav -------------------------------------------------------------------------------- /mlx_inference/tests/test_decoder.py: -------------------------------------------------------------------------------- 1 | import mlx.core as mx 2 | from datasets import load_dataset 3 | import numpy as np 4 | 5 | from smoltts_mlx.codec.mimi import load_mimi 6 | from smoltts_mlx.io.wav import pcm_to_wav_bytes 7 | from smoltts_mlx.lm.cache import make_prompt_cache 8 | 9 | 10 | def main(): 11 | dataset = load_dataset("jkeisling/libritts-r-mimi") 12 | dataset = dataset.with_format("numpy") 13 | arr = mx.array(dataset["dev.clean"][10]["codes"]) 14 | test_input = arr[mx.newaxis, :, :] 15 | 16 | model = load_mimi() 17 | print("Model loaded") 18 | 19 | # start_time = time.time() 20 | 21 | # dont worry about 1: from the full TTS, it's audio-only here 22 | quantized = model.quantizer.decode(test_input) 23 | 24 | embeddings = model.upsample(quantized) 25 | transformed = model.decoder_transformer(embeddings) 26 | mx.eval(transformed) 27 | # decoded = model.decode(test_input, None) 28 | # mx.eval(decoded) 29 | print(f"TRANSFORMED: {transformed.shape}") 30 | 31 | all_pcm_conv_frame = [] 32 | # upsample doubles the frame rate, so we need to match the actual streaming 33 | for frame in mx.split( 34 | transformed, axis=-2, indices_or_sections=transformed.shape[-2] // 2 35 | ): 36 | step = model.decoder.step(frame) 37 | if step is not None: 38 | # out = model.decoder(frame) 39 | all_pcm_conv_frame.append(mx.swapaxes(step, 1, 2)) 40 | else: 41 | print("Skipping step") 42 | 43 | model.decoder.reset() 44 | 45 | decoded = mx.concat(all_pcm_conv_frame, axis=-1) 46 | 47 | # end_time = time.time() 48 | # elapsed_time = end_time - start_time 49 | 50 | print("Done") 51 | print(f"Decoded shape: {decoded.shape}") 52 | # print(f"Elapsed time: {(elapsed_time * 1000):.3f} ms") 53 | wav_bytes = pcm_to_wav_bytes(np.array(decoded)) 54 | with 
open("output_conv.wav", "wb") as f: 55 | f.write(wav_bytes) 56 | 57 | print("Testing TRANSFORMER + RVQ + UPSAMPLE") 58 | frames = mx.split(test_input, axis=-1, indices_or_sections=test_input.shape[-1]) 59 | print(frames[0].shape) 60 | quantized_incremental = [model.quantizer.decode(arr) for arr in frames] 61 | upsampled_embeddings = [model.upsample(x) for x in quantized_incremental] 62 | # embeddings = model.upsample(quantized_incremental) 63 | cache = make_prompt_cache(model.decoder_transformer) 64 | all_transformer_out = [] 65 | # for frame in mx.split( 66 | # embeddings, axis=-2, indices_or_sections=embeddings.shape[-2] // 2 67 | # ): 68 | for frame in upsampled_embeddings: 69 | emb = model.decoder_transformer(frame, cache=cache) 70 | all_transformer_out.append(emb) 71 | transformed_incremental = mx.concat(all_transformer_out, axis=-2) 72 | decoded_t = mx.swapaxes(model.decoder(transformed_incremental), 1, 2) 73 | 74 | # reference = np.load("final.npy") 75 | wav_bytes = pcm_to_wav_bytes(np.array(decoded_t)) 76 | with open("output_t_incremental.wav", "wb") as f: 77 | f.write(wav_bytes) 78 | 79 | 80 | if __name__ == "__main__": 81 | main() 82 | -------------------------------------------------------------------------------- /mlx_inference/tests/test_encoder.py: -------------------------------------------------------------------------------- 1 | import mlx.core as mx 2 | # import numpy as np 3 | 4 | from smoltts_mlx.codec.mimi import load_mimi 5 | # from smoltts_mlx.io.wav import pcm_to_wav_bytes 6 | # from smoltts_mlx.lm.cache import make_prompt_cache 7 | 8 | 9 | def main(): 10 | # arr = mx.array(dataset["dev.clean"][10]["codes"]) 11 | # test_input = arr[mx.newaxis, :, :] 12 | # 2s of audio 13 | test_input = mx.zeros(48_000)[mx.newaxis, :, mx.newaxis] 14 | 15 | model = load_mimi() 16 | print("Model loaded") 17 | 18 | # dont worry about 1: from the full TTS, it's audio-only here 19 | embedded = model.encoder(test_input) 20 | transformed = model.encoder_transformer(embedded) 21 | downsampled = model.downsample(transformed) 22 | codes = model.quantizer.encode(downsampled) 23 | mx.save("out.npy", mx.swapaxes(codes, 0, 1)) 24 | 25 | 26 | if __name__ == "__main__": 27 | main() 28 | -------------------------------------------------------------------------------- /mlx_inference/tests/test_generate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import mlx.core as mx 3 | from pathlib import Path 4 | import time 5 | from tokenizers import Tokenizer 6 | 7 | from smoltts_mlx.lm.rq_transformer import ( 8 | RQTransformerModelArgs, 9 | RQTransformer, 10 | TokenConfig, 11 | ) 12 | from smoltts_mlx.lm.config import ModelType 13 | from smoltts_mlx.lm.generate import SingleBatchGenerator 14 | 15 | parser = argparse.ArgumentParser( 16 | description="A simple one-off CLI generator for DualAR models" 17 | ) 18 | parser.add_argument("--text", type=str, default="Hello world!") 19 | parser.add_argument("--speaker", type=int, default=0) 20 | parser.add_argument("--checkpoint", type=str, default="./inits/smoltts_byte_reference") 21 | 22 | 23 | def main(): 24 | args = parser.parse_args() 25 | checkpoint_dir = Path(args.checkpoint) 26 | model_type = ModelType(family="dual_ar", version=None, codec="mimi") 27 | 28 | load_start_time = time.time() 29 | config = RQTransformerModelArgs.from_json_file(str(checkpoint_dir / "config.json")) 30 | tokenizer = Tokenizer.from_file(str(checkpoint_dir / "tokenizer.json")) 31 | token_config = TokenConfig.from_tokenizer( 32 
| model=model_type, tokenizer=tokenizer, config=config 33 | ) 34 | 35 | model = RQTransformer(config, token_config, model_type) 36 | model_path = str(checkpoint_dir / "model.safetensors") 37 | model.load_weights(model_path, strict=True) 38 | # model = model.apply(lambda p: p.astype(mx.float32)) 39 | mx.eval(model.parameters()) 40 | model.eval() 41 | load_end_time = time.time() 42 | print(f"Loaded model and config in {load_end_time - load_start_time:.3f} seconds") 43 | 44 | # Initialize cache 45 | prompt = mx.zeros([1, 9, 32], mx.uint32) 46 | trace_file = "mlx_trace.gputrace" 47 | mx.metal.start_capture(trace_file) 48 | # prompt_encoder = PromptEncoder.from_model(tokenizer, model) 49 | # sysprompt = prompt_encoder.encode_text_turn("system", f"<|speaker:{args.speaker}|>") 50 | # user_prompt = prompt_encoder.encode_text_turn("user", args.text) 51 | # assistant_prefix = prompt_encoder.encode_text_turn("assistant") 52 | # print([p.shape for p in [sysprompt, user_prompt, assistant_prefix]]) 53 | # prompt = mx.concat([sysprompt, user_prompt, assistant_prefix], axis=1)[ 54 | # mx.newaxis, :, : 55 | # ] 56 | generator = SingleBatchGenerator(model, prompt, audio_only=True) 57 | next(generator) 58 | next(generator) 59 | 60 | mx.metal.stop_capture() 61 | 62 | 63 | if __name__ == "__main__": 64 | main() 65 | -------------------------------------------------------------------------------- /modeling/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EndlessReform/smoltts/410d6524966e2d433ce565e4f14f1fac2f983ff9/modeling/__init__.py -------------------------------------------------------------------------------- /modeling/model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EndlessReform/smoltts/410d6524966e2d433ce565e4f14f1fac2f983ff9/modeling/model/__init__.py -------------------------------------------------------------------------------- /modeling/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EndlessReform/smoltts/410d6524966e2d433ce565e4f14f1fac2f983ff9/modeling/utils/__init__.py -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "smoltts" 3 | version = "0.1.0" 4 | description = "A repo for training DualAR speech models" 5 | readme = "README.md" 6 | authors = [{ name = "Jacob Keisling", email = "jacob@keisling.me" }] 7 | requires-python = ">=3.9" 8 | dependencies = [ 9 | "datasets>=3.2.0", 10 | "einops>=0.8.0", 11 | "huggingface-hub>=0.27.1", 12 | "librosa>=0.10.2.post1", 13 | "python-dotenv>=1.0.1", 14 | "safetensors>=0.5.2", 15 | "setuptools>=75.8.0", 16 | "soundfile>=0.13.0", 17 | "torch>=2.5.1", 18 | "torchaudio>=2.5.1", 19 | "transformers>=4.48.0", 20 | "wandb>=0.19.4", 21 | ] 22 | license = { file = "LICENSE" } 23 | 24 | [tool.setuptools] 25 | packages = ["data_pipeline", "modeling", "train", "mlx_inference"] 26 | 27 | [tool.pyright] 28 | venvPath = "." 
29 | venv = ".venv" 30 | 31 | [tool.uv.workspace] 32 | members = ["mlx_inference", "train"] 33 | 34 | [dependency-groups] 35 | dev = ["ipykernel>=6.29.5"] 36 | -------------------------------------------------------------------------------- /sample_model_sizes/smoltts_byte_150m.json: -------------------------------------------------------------------------------- 1 | { 2 | "attention_qkv_bias": false, 3 | "codebook_size": 2048, 4 | "dim": 768, 5 | "dropout": 0.1, 6 | "fast_attention_qkv_bias": false, 7 | "fast_dim": 768, 8 | "fast_head_dim": 64, 9 | "fast_intermediate_size": 3072, 10 | "fast_n_head": 12, 11 | "fast_n_local_heads": 4, 12 | "head_dim": 64, 13 | "initializer_range": 0.041666666666666664, 14 | "intermediate_size": 3072, 15 | "is_reward_model": false, 16 | "max_seq_len": 2048, 17 | "model_type": "dual_ar", 18 | "n_fast_layer": 4, 19 | "n_head": 12, 20 | "n_layer": 10, 21 | "n_local_heads": 4, 22 | "depthwise_wte": true, 23 | "depthwise_output": true, 24 | "norm_eps": 1e-5, 25 | "num_codebooks": 8, 26 | "rope_base": 100000, 27 | "scale_codebook_embeddings": false, 28 | "share_codebook_embeddings": true, 29 | "tie_word_embeddings": true, 30 | "use_gradient_checkpointing": true, 31 | "vocab_size": 2368 32 | } -------------------------------------------------------------------------------- /sample_model_sizes/smoltts_byte_70m.json: -------------------------------------------------------------------------------- 1 | { 2 | "attention_qkv_bias": false, 3 | "codebook_size": 2048, 4 | "dim": 576, 5 | "dropout": 0.1, 6 | "fast_attention_qkv_bias": false, 7 | "fast_dim": 576, 8 | "fast_head_dim": 64, 9 | "fast_intermediate_size": 1536, 10 | "fast_n_head": 9, 11 | "fast_n_local_heads": 3, 12 | "head_dim": 64, 13 | "initializer_range": 0.041666666666666664, 14 | "intermediate_size": 1536, 15 | "is_reward_model": false, 16 | "max_seq_len": 2048, 17 | "model_type": "dual_ar", 18 | "n_fast_layer": 4, 19 | "n_head": 9, 20 | "n_layer": 10, 21 | "n_local_heads": 3, 22 | "depthwise_wte": true, 23 | "depthwise_output": true, 24 | "norm_eps": 1e-05, 25 | "num_codebooks": 8, 26 | "rope_base": 100000, 27 | "scale_codebook_embeddings": false, 28 | "share_codebook_embeddings": true, 29 | "tie_word_embeddings": true, 30 | "use_gradient_checkpointing": true, 31 | "vocab_size": 2368 32 | } -------------------------------------------------------------------------------- /train/README.md: -------------------------------------------------------------------------------- 1 | # Pretraining 2 | 3 | ## Config format 4 | 5 | Create your own run config in the `../config/yourmodel` folder. 6 | The options should be pretty self-explanatory; a sample config is sketched at the end of this README. 7 | 8 | ## Starting a run 9 | 10 | Training is currently only tested on Linux with CUDA. For example: 11 | 12 | ```bash 13 | uv run main.py --config ../config/kokoro_v1/scaleup.json 14 | ``` 15 | 16 | Artifacts will be saved to the `../checkpoints/` folder under a run ID. 17 | 18 | ## Extracting `model.safetensors` from checkpoint 19 | 20 | ```bash 21 | # Replace with whatever run you want 22 | uv run convert_safetensors.py ../checkpoints/your-run-id-here/step_somestep.pt 23 | ``` 24 | 25 | This will save `model.safetensors` to this folder. 26 | 27 | This will be improved.
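## Sample config

The run config is a flat JSON file whose keys mirror `TrainingConfig` in `train/config.py`. A minimal sketch follows; the names and paths are illustrative placeholders, and the remaining values simply restate the `TrainingConfig` defaults, so adjust them for your own dataset and hardware:

```json
{
  "project_name": "my_tts_run",
  "dataset_path": "../datasets/my_tokenized_dataset",
  "init_folder": "../inits/smoltts_byte_reference",
  "checkpoint_path": "../checkpoints",
  "batch_size": 8,
  "accumulate_steps": 1,
  "max_epochs": 10,
  "learning_rate": 0.0001,
  "lr_start": 0.001,
  "lr_warmup_steps": 3000,
  "val_every_n_steps": 100,
  "save_every_n_steps": 500,
  "max_sequence_length": 896,
  "use_bf16": true,
  "use_wandb": false,
  "use_pretrained": true
}
```

Any key you leave out falls back to its default in `TrainingConfig`; only `dataset_path` and `init_folder` have no default and must be provided.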
28 | -------------------------------------------------------------------------------- /train/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EndlessReform/smoltts/410d6524966e2d433ce565e4f14f1fac2f983ff9/train/__init__.py -------------------------------------------------------------------------------- /train/config.py: -------------------------------------------------------------------------------- 1 | import json 2 | from pydantic import BaseModel 3 | from typing import Tuple, Optional 4 | 5 | 6 | class TrainingConfig(BaseModel): 7 | # Core paths and identifiers 8 | project_name: str = "ljspeech_train" 9 | checkpoint_path: str = "checkpoints" 10 | model_path: str = "pretrained_model" 11 | dataset_path: str 12 | init_folder: str 13 | 14 | # Training params 15 | batch_size: int = 8 16 | max_epochs: int = 10 17 | num_workers: int = 4 18 | gradient_clip: float = 1.0 19 | accumulate_steps: int = 1 20 | 21 | # Optimizer settings 22 | learning_rate: float = 1e-4 23 | lr_start: float = 1e-3 24 | lr_warmup_steps: int = 3000 25 | weight_decay: float = 0.0 26 | betas: Tuple[float, float] = (0.9, 0.95) 27 | eps: float = 1e-5 28 | 29 | # Validation & Checkpointing 30 | val_every_n_steps: int = 100 31 | save_every_n_steps: int = 500 32 | 33 | # Model/Data params 34 | max_sequence_length: int = 896 # Much smaller than original 4096 for LJSpeech 35 | use_bf16: bool = True 36 | use_wandb: bool = False 37 | use_pretrained: bool = True 38 | 39 | 40 | def load_config(path: str) -> TrainingConfig: 41 | with open(path) as f: 42 | config_dict = json.load(f) 43 | return TrainingConfig(**config_dict) 44 | -------------------------------------------------------------------------------- /train/convert_safetensors.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from safetensors.torch import save_file 3 | import torch 4 | 5 | 6 | def main(): 7 | data = torch.load(sys.argv[1]) 8 | model = data["model_state_dict"] 9 | model = {key.replace("_orig_mod.", ""): value for key, value in model.items()} 10 | if model["fast_output.weight"].ndim == 3: 11 | print(f"Flattening 3D output projection {model['fast_output.weight'].shape}") 12 | w = model["fast_output.weight"] # [codebooks, hidden_dim, codebook_size] 13 | model["fast_output.weight"] = model["fast_output.weight"] = ( 14 | w.permute(1, 0, 2).reshape(768, -1).T.contiguous() 15 | ) 16 | save_file(model, "model.safetensors") 17 | 18 | 19 | if __name__ == "__main__": 20 | main() 21 | -------------------------------------------------------------------------------- /train/data.py: -------------------------------------------------------------------------------- 1 | from datasets import load_from_disk, Dataset 2 | from typing import Tuple 3 | import torch 4 | 5 | 6 | def load_splits(path: str, max_sequence_len: int = 768) -> Tuple[Dataset, Dataset]: 7 | """ 8 | Returns (train, val) datasets 9 | """ 10 | TEST_SIZE = 10_000 11 | print(f"Loading dataset from {path}") 12 | dataset = load_from_disk(path) 13 | dataset = dataset.with_format("torch") 14 | if isinstance(dataset, Dataset): 15 | dataset = dataset.train_test_split(test_size=TEST_SIZE) 16 | print(f"Keys: {dataset.keys()}") 17 | if "full" in (splits := list(dataset.keys())): 18 | dataset = dataset["full"].shuffle() 19 | split_dataset = dataset.train_test_split(test_size=TEST_SIZE) 20 | train_dataset = split_dataset["train"] 21 | val_dataset = split_dataset["test"] 22 | elif "val" in splits: 23 
| train_dataset = dataset["train"].shuffle(42) 24 | val_dataset = dataset["val"] 25 | elif "test" in splits: 26 | train_dataset = dataset["train"].shuffle(42) 27 | val_dataset = dataset["test"] 28 | else: 29 | dataset.shuffle(42) 30 | split_dataset = dataset["train"].train_test_split(test_size=TEST_SIZE) 31 | train_dataset = split_dataset["train"] 32 | val_dataset = split_dataset["test"] 33 | 34 | print(train_dataset.column_names) 35 | print(val_dataset.column_names) 36 | return train_dataset, val_dataset 37 | 38 | 39 | def collate_fn( 40 | batch, semantic_pad_id: int, duplicate_code_0: bool = True, codebook_size: int = 8 41 | ): 42 | """ 43 | batch is a list of dicts: each dict has "tokens" shape [9, T], 44 | and "labels" shape [9, T]. 45 | We pad them into [B, 9, T_max]. 46 | """ 47 | # TODO handle >8 codebooks 48 | height = codebook_size + (1 if duplicate_code_0 else 0) 49 | max_input_len = max(item["ground_truth"].shape[1] - 1 for item in batch) 50 | 51 | B = len(batch) 52 | # We'll create padded arrays: 53 | tokens = torch.full((B, height, max_input_len), 0, dtype=torch.long) # 2=some 54 | tokens[:, 0, :] = semantic_pad_id 55 | labels = torch.full( 56 | (B, height, max_input_len), -100, dtype=torch.long 57 | ) # default is ignore_index 58 | 59 | pad_mask = torch.ones(B, max_input_len) 60 | 61 | for i, item in enumerate(batch): 62 | seq_len = item["ground_truth"].shape[1] - 1 63 | tokens[i, :, :seq_len] = item["ground_truth"][:, :-1].clone() 64 | 65 | label = item["ground_truth"][:, 1:] 66 | text_only_mask = label[1:, :] == 0 67 | label[1:, :][text_only_mask] = -100 68 | labels[i, :, :seq_len] = label 69 | 70 | pad_mask[i, :seq_len] = False 71 | 72 | return {"tokens": tokens, "labels": labels, "pad_mask": pad_mask} 73 | -------------------------------------------------------------------------------- /train/main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import torch 4 | from train.config import load_config 5 | from train.data import load_splits 6 | from train.optim import setup_training 7 | from train.state import CheckpointManager 8 | from train.trainer import train 9 | 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument( 12 | "--checkpoint", 13 | type=str, 14 | help="Path to checkpoint file to resume from", 15 | ) 16 | parser.add_argument( 17 | "--config", 18 | type=str, 19 | help="Path to config file to resume from", 20 | ) 21 | args = parser.parse_args() 22 | 23 | 24 | def main(): 25 | # Requiring config now 26 | config = load_config(args.config) 27 | train_ds, val_ds = load_splits(config.dataset_path) 28 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 29 | 30 | # Initialize state management 31 | checkpoint_manager = CheckpointManager(config.checkpoint_path) 32 | 33 | # Load or create model and training state 34 | if args.checkpoint: 35 | state = checkpoint_manager.load_checkpoint(args.checkpoint, config, device) 36 | else: 37 | state = checkpoint_manager.init_model(config, device) 38 | 39 | # Setup optimizer and scheduler 40 | optimizer, scheduler = setup_training(state.model, config, state.global_step) 41 | 42 | # Environment setup 43 | os.environ["TOKENIZERS_PARALLELISM"] = "true" 44 | print(f"Dropout: {state.model.layers[0].attention.dropout}") 45 | 46 | # Start training 47 | train( 48 | state.model, 49 | train_ds, 50 | val_ds, 51 | config, 52 | device, 53 | optimizer, 54 | scheduler, 55 | checkpoint_manager, 56 | state.start_epoch, 57 | state.global_step, 58 | ) 
59 | 60 | 61 | if __name__ == "__main__": 62 | main() 63 | -------------------------------------------------------------------------------- /train/optim.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple 2 | from torch.optim import AdamW 3 | from torch.optim.lr_scheduler import LambdaLR 4 | import torch.nn as nn 5 | from train.config import TrainingConfig 6 | 7 | 8 | def partition_params( 9 | model: nn.Module, 10 | ) -> Tuple[List[nn.Parameter], List[nn.Parameter]]: 11 | """Split params into decay/no-decay groups""" 12 | weight_decay_params, no_decay_params = [], [] 13 | for name, param in sorted(model.named_parameters()): 14 | if param.requires_grad: # Only include trainable params 15 | if ".bias" in name or "norm.weight" in name or ".embeddings." in name: 16 | no_decay_params.append(param) 17 | else: 18 | weight_decay_params.append(param) 19 | 20 | print(f"Weight decay params: {len(weight_decay_params)}") 21 | print(f"No decay params: {len(no_decay_params)}") 22 | return weight_decay_params, no_decay_params 23 | 24 | 25 | def create_optimizer( 26 | model: nn.Module, 27 | config: TrainingConfig, 28 | ) -> AdamW: 29 | """Create optimizer with proper weight decay grouping""" 30 | weight_decay_params, no_decay_params = partition_params(model) 31 | 32 | return AdamW( 33 | [ 34 | {"params": weight_decay_params, "weight_decay": config.weight_decay}, 35 | {"params": no_decay_params, "weight_decay": 0.0}, 36 | ], 37 | lr=config.learning_rate, 38 | betas=config.betas, 39 | eps=config.eps, 40 | ) 41 | 42 | 43 | def get_lr_scheduler( 44 | optimizer: AdamW, config: TrainingConfig, warmup_start_step: int = 0 45 | ) -> LambdaLR: 46 | """Creates scheduler with linear warmup""" 47 | 48 | def lr_lambda(current_step: int): 49 | relative_step = current_step - warmup_start_step 50 | if relative_step < config.lr_warmup_steps: 51 | # Linear decay from lr_start to learning_rate 52 | progress = float(relative_step) / float(max(1, config.lr_warmup_steps)) 53 | return config.lr_start / config.learning_rate * (1.0 - progress) + progress 54 | return 1.0 # Return to base learning_rate 55 | 56 | return LambdaLR(optimizer, lr_lambda) 57 | 58 | 59 | def setup_training( 60 | model: nn.Module, 61 | config: TrainingConfig, 62 | global_step: int = 0, 63 | ) -> Tuple[AdamW, LambdaLR]: 64 | """One-shot creation of optimizer and scheduler""" 65 | optimizer = create_optimizer(model, config) 66 | scheduler = get_lr_scheduler(optimizer, config, global_step) 67 | 68 | if global_step > 0: 69 | # Initialize scheduler with current step 70 | scheduler.last_epoch = ( 71 | global_step - 1 72 | ) # -1 because scheduler steps once on first call 73 | 74 | return optimizer, scheduler 75 | -------------------------------------------------------------------------------- /train/pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "train" 3 | version = "0.1.0" 4 | description = "Add your description here" 5 | readme = "README.md" 6 | requires-python = ">=3.9" 7 | dependencies = [] 8 | -------------------------------------------------------------------------------- /train/state.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | from pathlib import Path 3 | from typing import NamedTuple, Optional 4 | import torch 5 | from torch.optim import Optimizer 6 | from torch.optim.lr_scheduler import LRScheduler 7 | from modeling.model.rq_transformer 
import RQTransformer 8 | from train.config import TrainingConfig 9 | 10 | 11 | class TrainingState(NamedTuple): 12 | model: RQTransformer 13 | optimizer: Optional[Optimizer] 14 | scheduler: Optional[LRScheduler] 15 | start_epoch: int 16 | global_step: int 17 | 18 | 19 | class CheckpointManager: 20 | def __init__(self, base_directory: str, keep_last_n: int = 5): 21 | self.base_dir = Path(base_directory) 22 | timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") 23 | self.run_dir = self.base_dir / f"run_{timestamp}" 24 | self.run_dir.mkdir(parents=True, exist_ok=True) 25 | self.keep_last_n = keep_last_n 26 | print(f"Checkpoint directory for this run: {self.run_dir}") 27 | 28 | def load_checkpoint( 29 | self, checkpoint_path: str, config: TrainingConfig, device: torch.device 30 | ) -> TrainingState: 31 | """Load a checkpoint and return initialized training state""" 32 | checkpoint = torch.load(checkpoint_path, map_location=device) 33 | checkpoint_config = TrainingConfig(**checkpoint["config"]) 34 | 35 | # Check for config mismatches that would affect optimizer/scheduler 36 | optimizer_keys = ["learning_rate", "weight_decay", "betas", "eps"] 37 | scheduler_keys = ["lr_start", "lr_warmup_steps"] 38 | 39 | optimizer_changed = any( 40 | getattr(config, key) != getattr(checkpoint_config, key) 41 | for key in optimizer_keys 42 | ) 43 | scheduler_changed = any( 44 | getattr(config, key) != getattr(checkpoint_config, key) 45 | for key in scheduler_keys 46 | ) 47 | 48 | if optimizer_changed or scheduler_changed: 49 | print("Detected changes in optimization parameters:") 50 | if optimizer_changed: 51 | print("Optimizer changes:") 52 | for key in optimizer_keys: 53 | old_val = getattr(checkpoint_config, key) 54 | new_val = getattr(config, key) 55 | if old_val != new_val: 56 | print(f" {key}: {old_val} -> {new_val}") 57 | if scheduler_changed: 58 | print("Scheduler changes:") 59 | for key in scheduler_keys: 60 | old_val = getattr(checkpoint_config, key) 61 | new_val = getattr(config, key) 62 | if old_val != new_val: 63 | print(f" {key}: {old_val} -> {new_val}") 64 | print("Will reinitialize optimizer and scheduler with new settings") 65 | 66 | # Load model with original architecture but override weights 67 | model = RQTransformer.from_pretrained( 68 | config.init_folder, 69 | load_weights=False, 70 | ) 71 | model_state_dict = { 72 | k.replace("_orig_mod.", ""): v 73 | for k, v in checkpoint["model_state_dict"].items() 74 | } 75 | model.load_state_dict(model_state_dict) 76 | 77 | # Move model to device and dtype 78 | model = model.to(device) 79 | model = model.to(torch.bfloat16) 80 | 81 | return TrainingState( 82 | model=model, 83 | optimizer=None, # Caller will reinit if needed 84 | scheduler=None, # Caller will reinit if needed 85 | start_epoch=checkpoint["epoch"], 86 | global_step=checkpoint["global_step"], 87 | ) 88 | 89 | def init_model(self, config: TrainingConfig, device: torch.device) -> TrainingState: 90 | """Create a fresh model and training state""" 91 | model = RQTransformer.from_pretrained( 92 | config.init_folder, load_weights=config.use_pretrained 93 | ) 94 | num_params = sum(p.numel() for p in model.parameters()) 95 | print(f"Total number of parameters: {num_params}") 96 | 97 | model = model.to(device) 98 | model = model.to(torch.bfloat16) 99 | 100 | print(f"Model max_seq_len: {model.max_seq_len}") 101 | 102 | return TrainingState( 103 | model=model, optimizer=None, scheduler=None, start_epoch=0, global_step=0 104 | ) 105 | 106 | def save(self, state: TrainingState, config: TrainingConfig) 
-> None: 107 | """Save current training state""" 108 | if state.global_step == 0: 109 | print("Skipping step 0") 110 | return None 111 | 112 | state_dict = { 113 | k.replace("_orig_mod.", ""): v for k, v in state.model.state_dict().items() 114 | } 115 | 116 | checkpoint_path = self.run_dir / f"step_{state.global_step:06d}.pt" 117 | torch.save( 118 | { 119 | "epoch": state.start_epoch, 120 | "global_step": state.global_step, 121 | "model_state_dict": state_dict, 122 | "optimizer_state_dict": ( 123 | state.optimizer.state_dict() if state.optimizer else None 124 | ), 125 | "scheduler_state_dict": ( 126 | state.scheduler.state_dict() if state.scheduler else None 127 | ), 128 | "config": config.model_dump(), 129 | }, 130 | checkpoint_path, 131 | ) 132 | 133 | self._cleanup_old_checkpoints() 134 | 135 | def _cleanup_old_checkpoints(self): 136 | """Remove old checkpoints from current run""" 137 | files = sorted(self.run_dir.glob("step_*.pt")) 138 | if len(files) > self.keep_last_n: 139 | for f in files[: -self.keep_last_n]: 140 | f.unlink() 141 | -------------------------------------------------------------------------------- /train/trainer.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from einops import rearrange 3 | from functools import partial 4 | import torch 5 | import torch.nn.functional as F 6 | from torch.utils.data import DataLoader, Dataset 7 | from tqdm import tqdm 8 | from typing import List 9 | import wandb 10 | 11 | from modeling.model.rq_transformer import RQTransformer 12 | from train.config import TrainingConfig 13 | from train.data import collate_fn 14 | from train.state import TrainingState, CheckpointManager 15 | 16 | 17 | @dataclass 18 | class TrainStepOutput: 19 | loss: torch.Tensor 20 | base_loss: float 21 | semantic_loss: float 22 | lr: float 23 | 24 | 25 | def compute_losses( 26 | outputs, labels: torch.Tensor, per_codebook_loss: bool = False 27 | ) -> tuple[torch.Tensor, torch.Tensor, List[float]]: 28 | """Compute base and semantic losses, plus individual codebook losses""" 29 | # Base loss computation remains the same 30 | base_loss = F.cross_entropy( 31 | outputs.token_logits.view(-1, outputs.token_logits.size(-1)), 32 | labels[:, 0, :].reshape(-1), 33 | ignore_index=-100, 34 | ) 35 | 36 | # Compute individual codebook losses 37 | n_codebooks = labels.shape[1] - 1 # Subtract 1 for the base tokens 38 | if per_codebook_loss: 39 | codebook_losses = [] 40 | 41 | for i in range(n_codebooks): 42 | # Reshape logits and labels for current codebook 43 | current_logits = outputs.codebook_logits[:, :, i, :] # [batch, seq, vocab] 44 | current_labels = labels[:, i + 1, :] # [batch, seq] 45 | 46 | loss = F.cross_entropy( 47 | current_logits.reshape(-1, current_logits.size(-1)), 48 | current_labels.reshape(-1), 49 | ignore_index=-100, 50 | ) 51 | codebook_losses.append(loss.item()) 52 | else: 53 | codebook_losses = [] 54 | 55 | # Compute total semantic loss (same as before, just using einops) 56 | codebook_logits = rearrange(outputs.codebook_logits, "b s n d -> (b s n) d") 57 | codebook_labels = rearrange(labels[:, 1:, :], "b n s -> (b s n)") 58 | semantic_loss = F.cross_entropy(codebook_logits, codebook_labels, ignore_index=-100) 59 | 60 | return base_loss, semantic_loss, codebook_losses 61 | 62 | 63 | def train_step( 64 | model: torch.nn.Module, 65 | batch: dict, 66 | device: torch.device, 67 | accumulate_steps: int = 1, # New parameter for loss scaling 68 | ) -> TrainStepOutput: 69 | """ 70 | 
Executes a forward pass and backward pass (with loss scaling for gradient accumulation). 71 | Does NOT perform optimizer.step() or gradient clipping here. 72 | """ 73 | tokens = batch["tokens"].to(device) 74 | labels = batch["labels"].to(device) 75 | pad_mask = batch["pad_mask"].to(device) 76 | 77 | outputs = model(inp=tokens, key_padding_mask=pad_mask) 78 | base_loss, semantic_loss, _ = compute_losses(outputs, labels) 79 | 80 | # Compute total loss and scale it for accumulation 81 | total_loss = base_loss + semantic_loss 82 | scaled_loss = total_loss / accumulate_steps 83 | scaled_loss.backward() 84 | 85 | # Return the unscaled losses for logging purposes 86 | return TrainStepOutput( 87 | loss=total_loss, 88 | base_loss=base_loss.item(), 89 | semantic_loss=semantic_loss.item(), 90 | lr=0.0, # Will be updated in the training loop after optimizer.step() 91 | ) 92 | 93 | 94 | def validate( 95 | model: torch.nn.Module, val_loader: DataLoader, device: torch.device 96 | ) -> dict: 97 | """Run validation""" 98 | model.eval() 99 | total_loss = total_base_loss = total_semantic_loss = num_batches = 0 100 | total_codebook_losses = None 101 | 102 | with torch.no_grad(): 103 | for batch in val_loader: 104 | tokens = batch["tokens"].to(device) 105 | labels = batch["labels"].to(device) 106 | 107 | outputs = model(tokens) 108 | base_loss, semantic_loss, codebook_losses = compute_losses(outputs, labels) 109 | loss = base_loss + semantic_loss 110 | 111 | # Initialize total_codebook_losses on first batch 112 | if total_codebook_losses is None: 113 | total_codebook_losses = [0.0] * len(codebook_losses) 114 | 115 | # Accumulate losses 116 | total_loss += loss.item() 117 | total_base_loss += base_loss.item() 118 | total_semantic_loss += semantic_loss.item() 119 | total_codebook_losses = [ 120 | total + current 121 | for total, current in zip(total_codebook_losses, codebook_losses) 122 | ] 123 | num_batches += 1 124 | 125 | del tokens, labels, outputs, base_loss, semantic_loss, loss 126 | torch.cuda.empty_cache() 127 | 128 | model.train() 129 | return { 130 | "loss": total_loss / num_batches, 131 | "base_loss": total_base_loss / num_batches, 132 | "semantic_loss": total_semantic_loss / num_batches, 133 | "codebook_losses": [loss / num_batches for loss in total_codebook_losses], 134 | } 135 | 136 | 137 | def create_dataloaders( 138 | train_ds: Dataset, 139 | val_ds: Dataset, 140 | config: TrainingConfig, 141 | pad_id: int, 142 | duplicate_code_0: bool = True, 143 | ) -> tuple[DataLoader, DataLoader]: 144 | pad_collate_fn = partial( 145 | collate_fn, semantic_pad_id=pad_id, duplicate_code_0=duplicate_code_0 146 | ) 147 | 148 | """Create train and validation dataloaders""" 149 | train_loader = DataLoader( 150 | train_ds, 151 | batch_size=config.batch_size, 152 | shuffle=True, 153 | collate_fn=pad_collate_fn, 154 | num_workers=config.num_workers, 155 | pin_memory=True, 156 | ) 157 | 158 | val_loader = DataLoader( 159 | val_ds, 160 | batch_size=config.batch_size, 161 | shuffle=False, 162 | collate_fn=pad_collate_fn, 163 | num_workers=config.num_workers, 164 | pin_memory=True, 165 | ) 166 | return train_loader, val_loader 167 | 168 | 169 | def train( 170 | model: RQTransformer, 171 | train_ds: Dataset, 172 | val_ds: Dataset, 173 | config: TrainingConfig, 174 | device: torch.device, 175 | optimizer: torch.optim.Optimizer, 176 | scheduler: torch.optim.lr_scheduler.LRScheduler, 177 | checkpoint_manager: CheckpointManager, 178 | start_epoch: int = 0, 179 | global_step: int = 0, 180 | ): 181 | pad_id = 
model.tokenizer.pad_token_id 182 | if config.use_wandb: 183 | wandb.init(project=config.project_name, resume="allow") 184 | wandb.config.update(config.model_dump()) 185 | 186 | train_loader, val_loader = create_dataloaders( 187 | train_ds, 188 | val_ds, 189 | config, 190 | pad_id=pad_id, 191 | duplicate_code_0=model.config.duplicate_code_0, 192 | ) 193 | 194 | # Initialize accumulation counter and zero the gradients initially. 195 | accumulation_counter = 0 196 | optimizer.zero_grad() 197 | 198 | for epoch in range(start_epoch, config.max_epochs): 199 | model.train() 200 | progress_bar = tqdm(train_loader, desc=f"Epoch {epoch}") 201 | 202 | for batch in progress_bar: 203 | # Forward and backward pass (loss scaled inside train_step) 204 | step_output = train_step( 205 | model, batch, device=device, accumulate_steps=config.accumulate_steps 206 | ) 207 | 208 | accumulation_counter += 1 209 | 210 | # Only perform an optimizer step when enough gradients have accumulated. 211 | if accumulation_counter == config.accumulate_steps: 212 | # Optionally clip gradients 213 | if config.gradient_clip > 0: 214 | torch.nn.utils.clip_grad_norm_( 215 | model.parameters(), config.gradient_clip 216 | ) 217 | optimizer.step() 218 | scheduler.step() 219 | optimizer.zero_grad() 220 | accumulation_counter = 0 221 | torch.cuda.empty_cache() 222 | 223 | # Get current learning rate (even if not updated, it stays the same) 224 | current_lr = scheduler.get_last_lr()[0] 225 | 226 | if config.use_wandb: 227 | metrics = { 228 | "train/loss": float(step_output.loss), 229 | "train/base_loss": float(step_output.base_loss), 230 | "train/semantic_loss": float(step_output.semantic_loss), 231 | "train/learning_rate": current_lr, 232 | "epoch": epoch, 233 | } 234 | wandb.log(metrics, step=global_step) 235 | 236 | progress_bar.set_postfix( 237 | loss=f"lm={step_output.base_loss:.4f},codes={step_output.semantic_loss:.4f}", 238 | lr=f"{current_lr:.2e}", 239 | ) 240 | 241 | if ( 242 | global_step % config.val_every_n_steps == 0 243 | and config.use_wandb 244 | and global_step != 0 245 | ): 246 | val_metrics = validate(model, val_loader, device) 247 | wandb.log( 248 | { 249 | "val/loss": float(val_metrics["loss"]), 250 | "val/base_loss": float(val_metrics["base_loss"]), 251 | "val/semantic_loss": float(val_metrics["semantic_loss"]), 252 | **{ 253 | f"val/codebook_{i + 1}_loss": loss 254 | for i, loss in enumerate(val_metrics["codebook_losses"]) 255 | }, 256 | }, 257 | step=global_step, 258 | ) 259 | 260 | if global_step % config.save_every_n_steps == 0: 261 | checkpoint_manager.save( 262 | TrainingState( 263 | model=model, 264 | optimizer=optimizer, 265 | scheduler=scheduler, 266 | start_epoch=epoch, 267 | global_step=global_step, 268 | ), 269 | config, 270 | ) 271 | 272 | global_step += 1 273 | 274 | checkpoint_manager.save( 275 | TrainingState( 276 | model=model, 277 | optimizer=optimizer, 278 | scheduler=scheduler, 279 | start_epoch=epoch, 280 | global_step=global_step, 281 | ), 282 | config, 283 | ) 284 | --------------------------------------------------------------------------------