├── .gitignore ├── LICENSE ├── README.md ├── cli ├── SparkTTS.py └── inference.py ├── example ├── infer.sh ├── prompt_audio.wav └── results │ └── 20250225113521.wav ├── requirements.txt ├── runtime └── triton_trtllm │ ├── Dockerfile.server │ ├── README.md │ ├── client_grpc.py │ ├── client_http.py │ ├── docker-compose.yml │ ├── model_repo │ ├── audio_tokenizer │ │ ├── 1 │ │ │ └── model.py │ │ └── config.pbtxt │ ├── spark_tts │ │ ├── 1 │ │ │ └── model.py │ │ └── config.pbtxt │ ├── tensorrt_llm │ │ ├── 1 │ │ │ └── .gitkeep │ │ └── config.pbtxt │ └── vocoder │ │ ├── 1 │ │ └── model.py │ │ └── config.pbtxt │ ├── run.sh │ └── scripts │ ├── convert_checkpoint.py │ └── fill_template.py ├── sparktts ├── models │ ├── audio_tokenizer.py │ └── bicodec.py ├── modules │ ├── blocks │ │ ├── layers.py │ │ ├── samper.py │ │ └── vocos.py │ ├── encoder_decoder │ │ ├── feat_decoder.py │ │ ├── feat_encoder.py │ │ └── wave_generator.py │ ├── fsq │ │ ├── finite_scalar_quantization.py │ │ └── residual_fsq.py │ ├── speaker │ │ ├── ecapa_tdnn.py │ │ ├── perceiver_encoder.py │ │ ├── pooling_layers.py │ │ └── speaker_encoder.py │ └── vq │ │ └── factorized_vector_quantize.py └── utils │ ├── __init__.py │ ├── audio.py │ ├── file.py │ ├── parse_options.sh │ └── token_parser.py ├── src ├── demos │ ├── trump │ │ └── trump_en.wav │ ├── zhongli │ │ └── zhongli_en.wav │ ├── 余承东 │ │ └── yuchengdong_zh.wav │ ├── 刘德华 │ │ └── dehua_zh.wav │ ├── 哪吒 │ │ └── nezha_zh.wav │ ├── 徐志胜 │ │ └── zhisheng_zh.wav │ ├── 李靖 │ │ └── lijing_zh.wav │ ├── 杨澜 │ │ └── yanglan_zh.wav │ ├── 马云 │ │ └── mayun_zh.wav │ └── 鲁豫 │ │ └── luyu_zh.wav ├── figures │ ├── gradio_TTS.png │ ├── gradio_control.png │ ├── infer_control.png │ └── infer_voice_cloning.png └── logo │ ├── HKUST.jpg │ ├── NPU.jpg │ ├── NTU.jpg │ ├── SJU.jpg │ ├── SparkAudio.jpg │ ├── SparkAudio2.jpg │ ├── SparkTTS.jpg │ ├── SparkTTS.png │ ├── mobvoi.jpg │ └── mobvoi.png └── webui.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | pretrained_models/ 6 | results/ 7 | demo/ 8 | # C extensions 9 | *.so 10 | .gradio/ 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | share/python-wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | webui_test.py 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | cover/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | .pybuilder/ 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # IPython 85 | profile_default/ 86 | ipython_config.py 87 | 88 | # pyenv 89 | # For a library or package, you might want to ignore these files since the code is 90 | # intended to run in multiple environments; otherwise, check them in: 91 | # .python-version 92 | 93 | # pipenv 94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 97 | # install all needed dependencies. 98 | #Pipfile.lock 99 | 100 | # UV 101 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 102 | # This is especially recommended for binary packages to ensure reproducibility, and is more 103 | # commonly ignored for libraries. 104 | #uv.lock 105 | 106 | # poetry 107 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 108 | # This is especially recommended for binary packages to ensure reproducibility, and is more 109 | # commonly ignored for libraries. 110 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 111 | #poetry.lock 112 | 113 | # pdm 114 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 115 | #pdm.lock 116 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 117 | # in version control. 118 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 119 | .pdm.toml 120 | .pdm-python 121 | .pdm-build/ 122 | 123 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 124 | __pypackages__/ 125 | 126 | # Celery stuff 127 | celerybeat-schedule 128 | celerybeat.pid 129 | 130 | # SageMath parsed files 131 | *.sage.py 132 | 133 | # Environments 134 | .env 135 | .venv 136 | env/ 137 | venv/ 138 | ENV/ 139 | env.bak/ 140 | venv.bak/ 141 | 142 | # Spyder project settings 143 | .spyderproject 144 | .spyproject 145 | 146 | # Rope project settings 147 | .ropeproject 148 | 149 | # mkdocs documentation 150 | /site 151 | 152 | # mypy 153 | .mypy_cache/ 154 | .dmypy.json 155 | dmypy.json 156 | 157 | # Pyre type checker 158 | .pyre/ 159 | 160 | # pytype static type analyzer 161 | .pytype/ 162 | 163 | # Cython debug symbols 164 | cython_debug/ 165 | 166 | # PyCharm 167 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 168 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 169 | # and can be added to the global gitignore or merged into this file. 
For a more nuclear 170 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 171 | #.idea/ 172 | 173 | # PyPI configuration file 174 | .pypirc 175 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
Spark-TTS

Official PyTorch code for inference of
Spark-TTS: An Efficient LLM-Based Text-to-Speech Model with Single-Stream Decoupled Speech Tokens

[Spark-TTS logo]

[Institution logos]

[Badges: paper · Hugging Face · python · license]
32 | 33 | 34 | ## Spark-TTS 🔥 35 | 36 | ### Overview 37 | 38 | Spark-TTS is an advanced text-to-speech system that uses the power of large language models (LLM) for highly accurate and natural-sounding voice synthesis. It is designed to be efficient, flexible, and powerful for both research and production use. 39 | 40 | ### Key Features 41 | 42 | - **Simplicity and Efficiency**: Built entirely on Qwen2.5, Spark-TTS eliminates the need for additional generation models like flow matching. Instead of relying on separate models to generate acoustic features, it directly reconstructs audio from the code predicted by the LLM. This approach streamlines the process, improving efficiency and reducing complexity. 43 | - **High-Quality Voice Cloning**: Supports zero-shot voice cloning, which means it can replicate a speaker's voice even without specific training data for that voice. This is ideal for cross-lingual and code-switching scenarios, allowing for seamless transitions between languages and voices without requiring separate training for each one. 44 | - **Bilingual Support**: Supports both Chinese and English, and is capable of zero-shot voice cloning for cross-lingual and code-switching scenarios, enabling the model to synthesize speech in multiple languages with high naturalness and accuracy. 45 | - **Controllable Speech Generation**: Supports creating virtual speakers by adjusting parameters such as gender, pitch, and speaking rate. 46 | 47 | --- 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 |
[Figures: Inference Overview of Voice Cloning; Inference Overview of Controlled Generation]
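Both inference modes illustrated above can also be driven directly from Python once the model has been downloaded (see the Install section below). The minimal sketch that follows simply wraps the `SparkTTS` class from `cli/SparkTTS.py`, mirroring what `cli/inference.py` does; the example texts, attribute values, and output filenames are illustrative placeholders, not part of the repository.

```python
import torch
import soundfile as sf

from cli.SparkTTS import SparkTTS

# Pick a device the way cli/inference.py does: CUDA if available, otherwise CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = SparkTTS("pretrained_models/Spark-TTS-0.5B", device)

# Voice cloning: condition on a prompt recording and (optionally) its transcript.
wav = model.inference(
    "Text to synthesize in the prompt speaker's voice.",
    prompt_speech_path="example/prompt_audio.wav",
    prompt_text="Transcript of the prompt audio.",
)
sf.write("cloned.wav", wav, samplerate=16000)  # the released model operates at 16 kHz

# Controlled generation: create a virtual speaker from attribute labels instead of a prompt.
wav = model.inference(
    "Text to synthesize with a newly created voice.",
    gender="female",   # "female" | "male"
    pitch="moderate",  # "very_low" | "low" | "moderate" | "high" | "very_high"
    speed="moderate",
)
sf.write("created.wav", wav, samplerate=16000)
```

When `gender` is set, the prompt audio is ignored and the voice is built purely from the attribute tokens (see `process_prompt_control` in `cli/SparkTTS.py`); otherwise the global tokens extracted from the prompt audio determine the timbre of the generated speech.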
57 | 58 | 59 | ## 🚀 News 60 | 61 | - **[2025-03-04]** Our paper on this project has been published! You can read it here: [Spark-TTS](https://arxiv.org/pdf/2503.01710). 62 | 63 | - **[2025-03-12]** Nvidia Triton Inference Serving is now supported. See the Runtime section below for more details. 64 | 65 | 66 | ## Install 67 | **Clone and Install** 68 | 69 | Here are instructions for installing on Linux. If you're on Windows, please refer to the [Windows Installation Guide](https://github.com/SparkAudio/Spark-TTS/issues/5). 70 | *(Thanks to [@AcTePuKc](https://github.com/AcTePuKc) for the detailed Windows instructions!)* 71 | 72 | 73 | - Clone the repo 74 | ``` sh 75 | git clone https://github.com/SparkAudio/Spark-TTS.git 76 | cd Spark-TTS 77 | ``` 78 | 79 | - Install Conda: please see https://docs.conda.io/en/latest/miniconda.html 80 | - Create Conda env: 81 | 82 | ``` sh 83 | conda create -n sparktts -y python=3.12 84 | conda activate sparktts 85 | pip install -r requirements.txt 86 | # If you are in mainland China, you can set the mirror as follows: 87 | pip install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com 88 | ``` 89 | 90 | **Model Download** 91 | 92 | Download via python: 93 | ```python 94 | from huggingface_hub import snapshot_download 95 | 96 | snapshot_download("SparkAudio/Spark-TTS-0.5B", local_dir="pretrained_models/Spark-TTS-0.5B") 97 | ``` 98 | 99 | Download via git clone: 100 | ```sh 101 | mkdir -p pretrained_models 102 | 103 | # Make sure you have git-lfs installed (https://git-lfs.com) 104 | git lfs install 105 | 106 | git clone https://huggingface.co/SparkAudio/Spark-TTS-0.5B pretrained_models/Spark-TTS-0.5B 107 | ``` 108 | 109 | **Basic Usage** 110 | 111 | You can simply run the demo with the following commands: 112 | ``` sh 113 | cd example 114 | bash infer.sh 115 | ``` 116 | 117 | Alternatively, you can directly execute the following command in the command line to perform inference: 118 | 119 | ``` sh 120 | python -m cli.inference \ 121 | --text "text to synthesis." \ 122 | --device 0 \ 123 | --save_dir "path/to/save/audio" \ 124 | --model_dir pretrained_models/Spark-TTS-0.5B \ 125 | --prompt_text "transcript of the prompt audio" \ 126 | --prompt_speech_path "path/to/prompt_audio" 127 | ``` 128 | 129 | **Web UI Usage** 130 | 131 | You can start the UI interface by running `python webui.py --device 0`, which allows you to perform Voice Cloning and Voice Creation. Voice Cloning supports uploading reference audio or directly recording the audio. 132 | 133 | 134 | | **Voice Cloning** | **Voice Creation** | 135 | |:-------------------:|:-------------------:| 136 | | ![Image 1](src/figures/gradio_TTS.png) | ![Image 2](src/figures/gradio_control.png) | 137 | 138 | 139 | **Optional Methods** 140 | 141 | For additional CLI and Web UI methods, including alternative implementations and extended functionalities, you can refer to: 142 | 143 | - [CLI and UI by AcTePuKc](https://github.com/SparkAudio/Spark-TTS/issues/10) 144 | 145 | 146 | ## Runtime 147 | 148 | **Nvidia Triton Inference Serving** 149 | 150 | We now provide a reference for deploying Spark-TTS with Nvidia Triton and TensorRT-LLM. 
The table below presents benchmark results on a single L20 GPU, using 26 different prompt_audio/target_text pairs (totalling 169 seconds of audio): 151 | 152 | | Model | Note | Concurrency | Avg Latency | RTF | 153 | |-------|-----------|-----------------------|---------|--| 154 | | Spark-TTS-0.5B | [Code Commit](https://github.com/SparkAudio/Spark-TTS/tree/4d769ff782a868524f29e0be851ca64f8b22ebf1/runtime/triton_trtllm) | 1 | 876.24 ms | 0.1362| 155 | | Spark-TTS-0.5B | [Code Commit](https://github.com/SparkAudio/Spark-TTS/tree/4d769ff782a868524f29e0be851ca64f8b22ebf1/runtime/triton_trtllm) | 2 | 920.97 ms | 0.0737| 156 | | Spark-TTS-0.5B | [Code Commit](https://github.com/SparkAudio/Spark-TTS/tree/4d769ff782a868524f29e0be851ca64f8b22ebf1/runtime/triton_trtllm) | 4 | 1611.51 ms | 0.0704| 157 | 158 | 159 | Please see the detailed instructions in [runtime/triton_trtllm/README.md](runtime/triton_trtllm/README.md ) for more information. 160 | 161 | 162 | ## **Demos** 163 | 164 | Here are some demos generated by Spark-TTS using zero-shot voice cloning. For more demos, visit our [demo page](https://sparkaudio.github.io/spark-tts/). 165 | 166 | --- 167 | 168 | 169 | 170 | 174 | 178 | 179 | 180 | 181 | 186 | 191 | 192 |
| **Donald Trump** | **Zhongli (Genshin Impact)** |
|:---:|:---:|
| [Donald Trump](https://github.com/user-attachments/assets/fb225780-d9fe-44b2-9b2e-54390cb3d8fd) | [Zhongli](https://github.com/user-attachments/assets/80eeb9c7-0443-4758-a1ce-55ac59e64bd6) |

---
| **陈鲁豫 Chen Luyu** | **杨澜 Yang Lan** |
|:---:|:---:|
| [陈鲁豫Chen_Luyu.webm](https://github.com/user-attachments/assets/5c6585ae-830d-47b1-992d-ee3691f48cf4) | [Yang_Lan.webm](https://github.com/user-attachments/assets/2fb3d00c-abc3-410e-932f-46ba204fb1d7) |

---
| **余承东 Richard Yu** | **马云 Jack Ma** |
|:---:|:---:|
| [Yu_Chengdong.webm](https://github.com/user-attachments/assets/78feca02-84bb-4d3a-a770-0cfd02f1a8da) | [Ma_Yun.webm](https://github.com/user-attachments/assets/2d54e2eb-cec4-4c2f-8c84-8fe587da321b) |

---
| **刘德华 Andy Lau** | **徐志胜 Xu Zhisheng** |
|:---:|:---:|
| [Liu_Dehua.webm](https://github.com/user-attachments/assets/195b5e97-1fee-4955-b954-6d10fa04f1d7) | [Xu_Zhisheng.webm](https://github.com/user-attachments/assets/dd812af9-76bd-4e26-9988-9cdb9ccbb87b) |

---
| **哪吒 Nezha** | **李靖 Li Jing** |
|:---:|:---:|
| [Ne_Zha.webm](https://github.com/user-attachments/assets/8c608037-a17a-46d4-8588-4db34b49ed1d) | [Li_Jing.webm](https://github.com/user-attachments/assets/aa8ba091-097c-4156-b4e3-6445da5ea101) |
306 | 307 | 308 | ## To-Do List 309 | 310 | - [x] Release the Spark-TTS paper. 311 | - [ ] Release the training code. 312 | - [ ] Release the training dataset, VoxBox. 313 | 314 | 315 | ## Citation 316 | 317 | ``` 318 | @misc{wang2025sparktts, 319 | title={Spark-TTS: An Efficient LLM-Based Text-to-Speech Model with Single-Stream Decoupled Speech Tokens}, 320 | author={Xinsheng Wang and Mingqi Jiang and Ziyang Ma and Ziyu Zhang and Songxiang Liu and Linqin Li and Zheng Liang and Qixi Zheng and Rui Wang and Xiaoqin Feng and Weizhen Bian and Zhen Ye and Sitong Cheng and Ruibin Yuan and Zhixian Zhao and Xinfa Zhu and Jiahao Pan and Liumeng Xue and Pengcheng Zhu and Yunlin Chen and Zhifei Li and Xie Chen and Lei Xie and Yike Guo and Wei Xue}, 321 | year={2025}, 322 | eprint={2503.01710}, 323 | archivePrefix={arXiv}, 324 | primaryClass={cs.SD}, 325 | url={https://arxiv.org/abs/2503.01710}, 326 | } 327 | ``` 328 | 329 | 330 | ## ⚠️ Usage Disclaimer 331 | 332 | This project provides a zero-shot voice cloning TTS model intended for academic research, educational purposes, and legitimate applications, such as personalized speech synthesis, assistive technologies, and linguistic research. 333 | 334 | Please note: 335 | 336 | - Do not use this model for unauthorized voice cloning, impersonation, fraud, scams, deepfakes, or any illegal activities. 337 | 338 | - Ensure compliance with local laws and regulations when using this model and uphold ethical standards. 339 | 340 | - The developers assume no liability for any misuse of this model. 341 | 342 | We advocate for the responsible development and use of AI and encourage the community to uphold safety and ethical principles in AI research and applications. If you have any concerns regarding ethics or misuse, please contact us. -------------------------------------------------------------------------------- /cli/SparkTTS.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 SparkAudio 2 | # 2025 Xinsheng Wang (w.xinshawn@gmail.com) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import re 17 | import torch 18 | from typing import Tuple 19 | from pathlib import Path 20 | from transformers import AutoTokenizer, AutoModelForCausalLM 21 | 22 | from sparktts.utils.file import load_config 23 | from sparktts.models.audio_tokenizer import BiCodecTokenizer 24 | from sparktts.utils.token_parser import LEVELS_MAP, GENDER_MAP, TASK_TOKEN_MAP 25 | 26 | 27 | class SparkTTS: 28 | """ 29 | Spark-TTS for text-to-speech generation. 30 | """ 31 | 32 | def __init__(self, model_dir: Path, device: torch.device = torch.device("cuda:0")): 33 | """ 34 | Initializes the SparkTTS model with the provided configurations and device. 35 | 36 | Args: 37 | model_dir (Path): Directory containing the model and config files. 38 | device (torch.device): The device (CPU/GPU) to run the model on. 
39 | """ 40 | self.device = device 41 | self.model_dir = model_dir 42 | self.configs = load_config(f"{model_dir}/config.yaml") 43 | self.sample_rate = self.configs["sample_rate"] 44 | self._initialize_inference() 45 | 46 | def _initialize_inference(self): 47 | """Initializes the tokenizer, model, and audio tokenizer for inference.""" 48 | self.tokenizer = AutoTokenizer.from_pretrained(f"{self.model_dir}/LLM") 49 | self.model = AutoModelForCausalLM.from_pretrained(f"{self.model_dir}/LLM") 50 | self.audio_tokenizer = BiCodecTokenizer(self.model_dir, device=self.device) 51 | self.model.to(self.device) 52 | 53 | def process_prompt( 54 | self, 55 | text: str, 56 | prompt_speech_path: Path, 57 | prompt_text: str = None, 58 | ) -> Tuple[str, torch.Tensor]: 59 | """ 60 | Process input for voice cloning. 61 | 62 | Args: 63 | text (str): The text input to be converted to speech. 64 | prompt_speech_path (Path): Path to the audio file used as a prompt. 65 | prompt_text (str, optional): Transcript of the prompt audio. 66 | 67 | Return: 68 | Tuple[str, torch.Tensor]: Input prompt; global tokens 69 | """ 70 | 71 | global_token_ids, semantic_token_ids = self.audio_tokenizer.tokenize( 72 | prompt_speech_path 73 | ) 74 | global_tokens = "".join( 75 | [f"<|bicodec_global_{i}|>" for i in global_token_ids.squeeze()] 76 | ) 77 | 78 | # Prepare the input tokens for the model 79 | if prompt_text is not None: 80 | semantic_tokens = "".join( 81 | [f"<|bicodec_semantic_{i}|>" for i in semantic_token_ids.squeeze()] 82 | ) 83 | inputs = [ 84 | TASK_TOKEN_MAP["tts"], 85 | "<|start_content|>", 86 | prompt_text, 87 | text, 88 | "<|end_content|>", 89 | "<|start_global_token|>", 90 | global_tokens, 91 | "<|end_global_token|>", 92 | "<|start_semantic_token|>", 93 | semantic_tokens, 94 | ] 95 | else: 96 | inputs = [ 97 | TASK_TOKEN_MAP["tts"], 98 | "<|start_content|>", 99 | text, 100 | "<|end_content|>", 101 | "<|start_global_token|>", 102 | global_tokens, 103 | "<|end_global_token|>", 104 | ] 105 | 106 | inputs = "".join(inputs) 107 | 108 | return inputs, global_token_ids 109 | 110 | def process_prompt_control( 111 | self, 112 | gender: str, 113 | pitch: str, 114 | speed: str, 115 | text: str, 116 | ): 117 | """ 118 | Process input for voice creation. 119 | 120 | Args: 121 | gender (str): female | male. 122 | pitch (str): very_low | low | moderate | high | very_high 123 | speed (str): very_low | low | moderate | high | very_high 124 | text (str): The text input to be converted to speech. 
125 | 126 | Return: 127 | str: Input prompt 128 | """ 129 | assert gender in GENDER_MAP.keys() 130 | assert pitch in LEVELS_MAP.keys() 131 | assert speed in LEVELS_MAP.keys() 132 | 133 | gender_id = GENDER_MAP[gender] 134 | pitch_level_id = LEVELS_MAP[pitch] 135 | speed_level_id = LEVELS_MAP[speed] 136 | 137 | pitch_label_tokens = f"<|pitch_label_{pitch_level_id}|>" 138 | speed_label_tokens = f"<|speed_label_{speed_level_id}|>" 139 | gender_tokens = f"<|gender_{gender_id}|>" 140 | 141 | attribte_tokens = "".join( 142 | [gender_tokens, pitch_label_tokens, speed_label_tokens] 143 | ) 144 | 145 | control_tts_inputs = [ 146 | TASK_TOKEN_MAP["controllable_tts"], 147 | "<|start_content|>", 148 | text, 149 | "<|end_content|>", 150 | "<|start_style_label|>", 151 | attribte_tokens, 152 | "<|end_style_label|>", 153 | ] 154 | 155 | return "".join(control_tts_inputs) 156 | 157 | @torch.no_grad() 158 | def inference( 159 | self, 160 | text: str, 161 | prompt_speech_path: Path = None, 162 | prompt_text: str = None, 163 | gender: str = None, 164 | pitch: str = None, 165 | speed: str = None, 166 | temperature: float = 0.8, 167 | top_k: float = 50, 168 | top_p: float = 0.95, 169 | ) -> torch.Tensor: 170 | """ 171 | Performs inference to generate speech from text, incorporating prompt audio and/or text. 172 | 173 | Args: 174 | text (str): The text input to be converted to speech. 175 | prompt_speech_path (Path): Path to the audio file used as a prompt. 176 | prompt_text (str, optional): Transcript of the prompt audio. 177 | gender (str): female | male. 178 | pitch (str): very_low | low | moderate | high | very_high 179 | speed (str): very_low | low | moderate | high | very_high 180 | temperature (float, optional): Sampling temperature for controlling randomness. Default is 0.8. 181 | top_k (float, optional): Top-k sampling parameter. Default is 50. 182 | top_p (float, optional): Top-p (nucleus) sampling parameter. Default is 0.95. 183 | 184 | Returns: 185 | torch.Tensor: Generated waveform as a tensor. 
186 | """ 187 | if gender is not None: 188 | prompt = self.process_prompt_control(gender, pitch, speed, text) 189 | 190 | else: 191 | prompt, global_token_ids = self.process_prompt( 192 | text, prompt_speech_path, prompt_text 193 | ) 194 | model_inputs = self.tokenizer([prompt], return_tensors="pt").to(self.device) 195 | 196 | # Generate speech using the model 197 | generated_ids = self.model.generate( 198 | **model_inputs, 199 | max_new_tokens=3000, 200 | do_sample=True, 201 | top_k=top_k, 202 | top_p=top_p, 203 | temperature=temperature, 204 | ) 205 | 206 | # Trim the output tokens to remove the input tokens 207 | generated_ids = [ 208 | output_ids[len(input_ids) :] 209 | for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids) 210 | ] 211 | 212 | # Decode the generated tokens into text 213 | predicts = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0] 214 | 215 | # Extract semantic token IDs from the generated text 216 | pred_semantic_ids = ( 217 | torch.tensor([int(token) for token in re.findall(r"bicodec_semantic_(\d+)", predicts)]) 218 | .long() 219 | .unsqueeze(0) 220 | ) 221 | 222 | if gender is not None: 223 | global_token_ids = ( 224 | torch.tensor([int(token) for token in re.findall(r"bicodec_global_(\d+)", predicts)]) 225 | .long() 226 | .unsqueeze(0) 227 | .unsqueeze(0) 228 | ) 229 | 230 | # Convert semantic tokens back to waveform 231 | wav = self.audio_tokenizer.detokenize( 232 | global_token_ids.to(self.device).squeeze(0), 233 | pred_semantic_ids.to(self.device), 234 | ) 235 | 236 | return wav -------------------------------------------------------------------------------- /cli/inference.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 SparkAudio 2 | # 2025 Xinsheng Wang (w.xinshawn@gmail.com) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | 17 | import os 18 | import argparse 19 | import torch 20 | import soundfile as sf 21 | import logging 22 | from datetime import datetime 23 | import platform 24 | 25 | from cli.SparkTTS import SparkTTS 26 | 27 | 28 | def parse_args(): 29 | """Parse command-line arguments.""" 30 | parser = argparse.ArgumentParser(description="Run TTS inference.") 31 | 32 | parser.add_argument( 33 | "--model_dir", 34 | type=str, 35 | default="pretrained_models/Spark-TTS-0.5B", 36 | help="Path to the model directory", 37 | ) 38 | parser.add_argument( 39 | "--save_dir", 40 | type=str, 41 | default="example/results", 42 | help="Directory to save generated audio files", 43 | ) 44 | parser.add_argument("--device", type=int, default=0, help="CUDA device number") 45 | parser.add_argument( 46 | "--text", type=str, required=True, help="Text for TTS generation" 47 | ) 48 | parser.add_argument("--prompt_text", type=str, help="Transcript of prompt audio") 49 | parser.add_argument( 50 | "--prompt_speech_path", 51 | type=str, 52 | help="Path to the prompt audio file", 53 | ) 54 | parser.add_argument("--gender", choices=["male", "female"]) 55 | parser.add_argument( 56 | "--pitch", choices=["very_low", "low", "moderate", "high", "very_high"] 57 | ) 58 | parser.add_argument( 59 | "--speed", choices=["very_low", "low", "moderate", "high", "very_high"] 60 | ) 61 | return parser.parse_args() 62 | 63 | 64 | def run_tts(args): 65 | """Perform TTS inference and save the generated audio.""" 66 | logging.info(f"Using model from: {args.model_dir}") 67 | logging.info(f"Saving audio to: {args.save_dir}") 68 | 69 | # Ensure the save directory exists 70 | os.makedirs(args.save_dir, exist_ok=True) 71 | 72 | # Convert device argument to torch.device 73 | if platform.system() == "Darwin" and torch.backends.mps.is_available(): 74 | # macOS with MPS support (Apple Silicon) 75 | device = torch.device(f"mps:{args.device}") 76 | logging.info(f"Using MPS device: {device}") 77 | elif torch.cuda.is_available(): 78 | # System with CUDA support 79 | device = torch.device(f"cuda:{args.device}") 80 | logging.info(f"Using CUDA device: {device}") 81 | else: 82 | # Fall back to CPU 83 | device = torch.device("cpu") 84 | logging.info("GPU acceleration not available, using CPU") 85 | 86 | # Initialize the model 87 | model = SparkTTS(args.model_dir, device) 88 | 89 | # Generate unique filename using timestamp 90 | timestamp = datetime.now().strftime("%Y%m%d%H%M%S") 91 | save_path = os.path.join(args.save_dir, f"{timestamp}.wav") 92 | 93 | logging.info("Starting inference...") 94 | 95 | # Perform inference and save the output audio 96 | with torch.no_grad(): 97 | wav = model.inference( 98 | args.text, 99 | args.prompt_speech_path, 100 | prompt_text=args.prompt_text, 101 | gender=args.gender, 102 | pitch=args.pitch, 103 | speed=args.speed, 104 | ) 105 | sf.write(save_path, wav, samplerate=16000) 106 | 107 | logging.info(f"Audio saved at: {save_path}") 108 | 109 | 110 | if __name__ == "__main__": 111 | logging.basicConfig( 112 | level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" 113 | ) 114 | 115 | args = parse_args() 116 | run_tts(args) 117 | -------------------------------------------------------------------------------- /example/infer.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Copyright (c) 2025 SparkAudio 4 | # 2025 Xinsheng Wang (w.xinshawn@gmail.com) 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file 
except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | 18 | 19 | # Get the absolute path of the script's directory 20 | script_dir=$(dirname "$(realpath "$0")") 21 | 22 | # Get the root directory 23 | root_dir=$(dirname "$script_dir") 24 | 25 | # Set default parameters 26 | device=0 27 | save_dir='example/results' 28 | model_dir="pretrained_models/Spark-TTS-0.5B" 29 | text="身临其境,换新体验。塑造开源语音合成新范式,让智能语音更自然。" 30 | prompt_text="吃燕窝就选燕之屋,本节目由26年专注高品质燕窝的燕之屋冠名播出。豆奶牛奶换着喝,营养更均衡,本节目由豆本豆豆奶特约播出。" 31 | prompt_speech_path="example/prompt_audio.wav" 32 | 33 | # Change directory to the root directory 34 | cd "$root_dir" || exit 35 | 36 | source sparktts/utils/parse_options.sh 37 | 38 | # Run inference 39 | python -m cli.inference \ 40 | --text "${text}" \ 41 | --device "${device}" \ 42 | --save_dir "${save_dir}" \ 43 | --model_dir "${model_dir}" \ 44 | --prompt_text "${prompt_text}" \ 45 | --prompt_speech_path "${prompt_speech_path}" 46 | 47 | -------------------------------------------------------------------------------- /example/prompt_audio.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SparkAudio/Spark-TTS/2f1ea9082400547242641f5271b6f941c9f439d1/example/prompt_audio.wav -------------------------------------------------------------------------------- /example/results/20250225113521.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SparkAudio/Spark-TTS/2f1ea9082400547242641f5271b6f941c9f439d1/example/results/20250225113521.wav -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | einops==0.8.1 2 | einx==0.3.0 3 | numpy==2.2.3 4 | omegaconf==2.3.0 5 | packaging==24.2 6 | safetensors==0.5.2 7 | soundfile==0.12.1 8 | soxr==0.5.0.post1 9 | torch==2.5.1 10 | torchaudio==2.5.1 11 | tqdm==4.66.5 12 | transformers==4.46.2 13 | gradio==5.18.0 -------------------------------------------------------------------------------- /runtime/triton_trtllm/Dockerfile.server: -------------------------------------------------------------------------------- 1 | FROM nvcr.io/nvidia/tritonserver:25.02-trtllm-python-py3 2 | RUN apt-get update && apt-get install -y cmake 3 | RUN git clone https://github.com/pytorch/audio.git && cd audio && git checkout c670ad8 && PATH=/usr/local/cuda/bin:$PATH python3 setup.py develop 4 | RUN pip install einx==0.3.0 omegaconf==2.3.0 soundfile==0.12.1 soxr==0.5.0.post1 gradio tritonclient librosa 5 | WORKDIR /workspace -------------------------------------------------------------------------------- /runtime/triton_trtllm/README.md: -------------------------------------------------------------------------------- 1 | ## Nvidia Triton Inference Serving Best Practice for Spark TTS 2 | 3 | ### Quick Start 4 | Directly launch the service using docker compose. 5 | ```sh 6 | docker compose up 7 | ``` 8 | 9 | ### Build Image 10 | Build the docker image from scratch. 11 | ```sh 12 | docker build . 
-f Dockerfile.server -t soar97/triton-spark-tts:25.02 13 | ``` 14 | 15 | ### Create Docker Container 16 | ```sh 17 | your_mount_dir=/mnt:/mnt 18 | docker run -it --name "spark-tts-server" --gpus all --net host -v $your_mount_dir --shm-size=2g soar97/triton-spark-tts:25.02 19 | ``` 20 | 21 | ### Understanding `run.sh` 22 | 23 | The `run.sh` script automates various steps using stages. You can run specific stages using: 24 | ```sh 25 | bash run.sh <start_stage> <stop_stage> [service_type] 26 | ``` 27 | - `<start_stage>`: The stage to begin execution from (0-5). 28 | - `<stop_stage>`: The stage to end execution at (0-5). 29 | - `[service_type]`: Optional, specifies the service type ('streaming' or 'offline', defaults may apply based on script logic). Required for stages 4 and 5. 30 | 31 | Stages: 32 | - **Stage 0**: Download Spark-TTS-0.5B model from HuggingFace. 33 | - **Stage 1**: Convert HuggingFace checkpoint to TensorRT-LLM format and build TensorRT engines. 34 | - **Stage 2**: Create the Triton model repository structure and configure model files (adjusts for streaming/offline). 35 | - **Stage 3**: Launch the Triton Inference Server. 36 | - **Stage 4**: Run the gRPC benchmark client. 37 | - **Stage 5**: Run the single utterance client (gRPC for streaming, HTTP for offline). 38 | 39 | ### Export Models to TensorRT-LLM and Launch Server 40 | Inside the docker container, you can prepare the models and launch the Triton server by running stages 0 through 3. This involves downloading the original model, converting it to TensorRT-LLM format, building the optimized TensorRT engines, creating the necessary model repository structure for Triton, and finally starting the server. 41 | ```sh 42 | # This runs stages 0, 1, 2, and 3 43 | bash run.sh 0 3 44 | ``` 45 | *Note: Stage 2 prepares the model repository differently based on whether you intend to run streaming or offline inference later. You might need to re-run stage 2 if switching service types.* 46 | 47 | 48 | ### Single Utterance Client 49 | Run a single inference request. Specify `streaming` or `offline` as the third argument. 50 | 51 | **Streaming Mode (gRPC):** 52 | ```sh 53 | bash run.sh 5 5 streaming 54 | ``` 55 | This executes the `client_grpc.py` script with predefined example text and prompt audio in streaming mode. 56 | 57 | **Offline Mode (HTTP):** 58 | ```sh 59 | bash run.sh 5 5 offline 60 | ``` 61 | 62 | ### Benchmark using Dataset 63 | Run the benchmark client against the running Triton server. Specify `streaming` or `offline` as the third argument. 64 | ```sh 65 | # Run benchmark in streaming mode 66 | bash run.sh 4 4 streaming 67 | 68 | # Run benchmark in offline mode 69 | bash run.sh 4 4 offline 70 | 71 | # You can also customize parameters like num_task directly in client_grpc.py or via args if supported 72 | # Example from run.sh (streaming): 73 | # python3 client_grpc.py \ 74 | # --server-addr localhost \ 75 | # --model-name spark_tts \ 76 | # --num-tasks 2 \ 77 | # --mode streaming \ 78 | # --log-dir ./log_concurrent_tasks_2_streaming_new 79 | 80 | # Example customizing dataset (requires modifying client_grpc.py or adding args): 81 | # python3 client_grpc.py --num-tasks 2 --huggingface-dataset yuekai/seed_tts --split-name wenetspeech4tts --mode [streaming|offline] 82 | ``` 83 | 84 | ### Benchmark Results 85 | Decoding on a single L20 GPU, using 26 different prompt_audio/target_text [pairs](https://huggingface.co/datasets/yuekai/seed_tts), total audio duration 169 secs.
86 | 87 | | Mode | Note | Concurrency | Avg Latency | First Chunk Latency (P50) | RTF | 88 | |-------|-----------|-----------------------|---------|----------------|-| 89 | | Offline | [Code Commit](https://github.com/SparkAudio/Spark-TTS/tree/4d769ff782a868524f29e0be851ca64f8b22ebf1/runtime/triton_trtllm) | 1 | 876.24 ms |-| 0.1362| 90 | | Offline | [Code Commit](https://github.com/SparkAudio/Spark-TTS/tree/4d769ff782a868524f29e0be851ca64f8b22ebf1/runtime/triton_trtllm) | 2 | 920.97 ms |-|0.0737| 91 | | Offline | [Code Commit](https://github.com/SparkAudio/Spark-TTS/tree/4d769ff782a868524f29e0be851ca64f8b22ebf1/runtime/triton_trtllm) | 4 | 1611.51 ms |-| 0.0704| 92 | | Streaming | [Code Commit](https://github.com/yuekaizhang/Spark-TTS/commit/0e978a327f99aa49f0735f86eb09372f16410d86) | 1 | 913.28 ms |210.42 ms| 0.1501 | 93 | | Streaming | [Code Commit](https://github.com/yuekaizhang/Spark-TTS/commit/0e978a327f99aa49f0735f86eb09372f16410d86) | 2 | 1009.23 ms |226.08 ms |0.0862 | 94 | | Streaming | [Code Commit](https://github.com/yuekaizhang/Spark-TTS/commit/0e978a327f99aa49f0735f86eb09372f16410d86) | 4 | 1793.86 ms |1017.70 ms| 0.0824 | -------------------------------------------------------------------------------- /runtime/triton_trtllm/client_http.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # Redistribution and use in source and binary forms, with or without 4 | # modification, are permitted provided that the following conditions 5 | # are met: 6 | # * Redistributions of source code must retain the above copyright 7 | # notice, this list of conditions and the following disclaimer. 8 | # * Redistributions in binary form must reproduce the above copyright 9 | # notice, this list of conditions and the following disclaimer in the 10 | # documentation and/or other materials provided with the distribution. 11 | # * Neither the name of NVIDIA CORPORATION nor the names of its 12 | # contributors may be used to endorse or promote products derived 13 | # from this software without specific prior written permission. 14 | # 15 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY 16 | # EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 17 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 18 | # PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR 19 | # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 20 | # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 21 | # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 22 | # PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY 23 | # OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 24 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 25 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 26 | import requests 27 | import soundfile as sf 28 | import json 29 | import numpy as np 30 | import argparse 31 | 32 | def get_args(): 33 | parser = argparse.ArgumentParser( 34 | formatter_class=argparse.ArgumentDefaultsHelpFormatter 35 | ) 36 | 37 | parser.add_argument( 38 | "--server-url", 39 | type=str, 40 | default="localhost:8000", 41 | help="Address of the server", 42 | ) 43 | 44 | parser.add_argument( 45 | "--reference-audio", 46 | type=str, 47 | default="../../example/prompt_audio.wav", 48 | help="Path to a single audio file. 
It can't be specified at the same time with --manifest-dir", 49 | ) 50 | 51 | parser.add_argument( 52 | "--reference-text", 53 | type=str, 54 | default="吃燕窝就选燕之屋,本节目由26年专注高品质燕窝的燕之屋冠名播出。豆奶牛奶换着喝,营养更均衡,本节目由豆本豆豆奶特约播出。", 55 | help="", 56 | ) 57 | 58 | parser.add_argument( 59 | "--target-text", 60 | type=str, 61 | default="身临其境,换新体验。塑造开源语音合成新范式,让智能语音更自然。", 62 | help="", 63 | ) 64 | 65 | parser.add_argument( 66 | "--model-name", 67 | type=str, 68 | default="spark_tts", 69 | choices=[ 70 | "f5_tts", "spark_tts" 71 | ], 72 | help="triton model_repo module name to request: transducer for k2, attention_rescoring for wenet offline, streaming_wenet for wenet streaming, infer_pipeline for paraformer large offline", 73 | ) 74 | 75 | parser.add_argument( 76 | "--output-audio", 77 | type=str, 78 | default="output.wav", 79 | help="Path to save the output audio", 80 | ) 81 | return parser.parse_args() 82 | 83 | def prepare_request( 84 | waveform, 85 | reference_text, 86 | target_text, 87 | sample_rate=16000, 88 | padding_duration: int = None, 89 | audio_save_dir: str = "./", 90 | ): 91 | assert len(waveform.shape) == 1, "waveform should be 1D" 92 | lengths = np.array([[len(waveform)]], dtype=np.int32) 93 | if padding_duration: 94 | # padding to nearset 10 seconds 95 | samples = np.zeros( 96 | ( 97 | 1, 98 | padding_duration 99 | * sample_rate 100 | * ((int(duration) // padding_duration) + 1), 101 | ), 102 | dtype=np.float32, 103 | ) 104 | 105 | samples[0, : len(waveform)] = waveform 106 | else: 107 | samples = waveform 108 | 109 | samples = samples.reshape(1, -1).astype(np.float32) 110 | 111 | data = { 112 | "inputs":[ 113 | { 114 | "name": "reference_wav", 115 | "shape": samples.shape, 116 | "datatype": "FP32", 117 | "data": samples.tolist() 118 | }, 119 | { 120 | "name": "reference_wav_len", 121 | "shape": lengths.shape, 122 | "datatype": "INT32", 123 | "data": lengths.tolist(), 124 | }, 125 | { 126 | "name": "reference_text", 127 | "shape": [1, 1], 128 | "datatype": "BYTES", 129 | "data": [reference_text] 130 | }, 131 | { 132 | "name": "target_text", 133 | "shape": [1, 1], 134 | "datatype": "BYTES", 135 | "data": [target_text] 136 | } 137 | ] 138 | } 139 | 140 | return data 141 | 142 | if __name__ == "__main__": 143 | args = get_args() 144 | server_url = args.server_url 145 | if not server_url.startswith(("http://", "https://")): 146 | server_url = f"http://{server_url}" 147 | 148 | url = f"{server_url}/v2/models/{args.model_name}/infer" 149 | waveform, sr = sf.read(args.reference_audio) 150 | assert sr == 16000, "sample rate hardcoded in server" 151 | 152 | samples = np.array(waveform, dtype=np.float32) 153 | data = prepare_request(samples, args.reference_text, args.target_text) 154 | 155 | rsp = requests.post( 156 | url, 157 | headers={"Content-Type": "application/json"}, 158 | json=data, 159 | verify=False, 160 | params={"request_id": '0'} 161 | ) 162 | result = rsp.json() 163 | audio = result["outputs"][0]["data"] 164 | audio = np.array(audio, dtype=np.float32) 165 | sf.write(args.output_audio, audio, 16000, "PCM_16") -------------------------------------------------------------------------------- /runtime/triton_trtllm/docker-compose.yml: -------------------------------------------------------------------------------- 1 | services: 2 | tts: 3 | image: soar97/triton-spark-tts:25.02 4 | shm_size: '1gb' 5 | ports: 6 | - "8000:8000" 7 | - "8001:8001" 8 | - "8002:8002" 9 | environment: 10 | - PYTHONIOENCODING=utf-8 11 | - MODEL_ID=${MODEL_ID} 12 | deploy: 13 | resources: 14 | reservations: 15 | devices: 
16 | - driver: nvidia 17 | device_ids: ['0'] 18 | capabilities: [gpu] 19 | command: > 20 | /bin/bash -c "rm -rf Spark-TTS && git clone https://github.com/SparkAudio/Spark-TTS.git && cd Spark-TTS/runtime/triton_trtllm && bash run.sh 0 3" 21 | -------------------------------------------------------------------------------- /runtime/triton_trtllm/model_repo/audio_tokenizer/1/model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # Redistribution and use in source and binary forms, with or without 4 | # modification, are permitted provided that the following conditions 5 | # are met: 6 | # * Redistributions of source code must retain the above copyright 7 | # notice, this list of conditions and the following disclaimer. 8 | # * Redistributions in binary form must reproduce the above copyright 9 | # notice, this list of conditions and the following disclaimer in the 10 | # documentation and/or other materials provided with the distribution. 11 | # * Neither the name of NVIDIA CORPORATION nor the names of its 12 | # contributors may be used to endorse or promote products derived 13 | # from this software without specific prior written permission. 14 | # 15 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY 16 | # EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 17 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 18 | # PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR 19 | # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 20 | # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 21 | # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 22 | # PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY 23 | # OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 24 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 25 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 26 | import json 27 | import torch 28 | from torch.utils.dlpack import to_dlpack 29 | 30 | import triton_python_backend_utils as pb_utils 31 | 32 | import os 33 | import numpy as np 34 | 35 | from sparktts.models.audio_tokenizer import BiCodecTokenizer 36 | 37 | class TritonPythonModel: 38 | """Triton Python model for audio tokenization. 39 | 40 | This model takes reference audio input and extracts semantic and global tokens 41 | using BiCodec tokenizer. 42 | """ 43 | 44 | def initialize(self, args): 45 | """Initialize the model. 46 | 47 | Args: 48 | args: Dictionary containing model configuration 49 | """ 50 | # Parse model parameters 51 | parameters = json.loads(args['model_config'])['parameters'] 52 | model_params = {k: v["string_value"] for k, v in parameters.items()} 53 | 54 | # Initialize tokenizer 55 | self.device = torch.device("cuda") 56 | self.audio_tokenizer = BiCodecTokenizer(model_params["model_dir"], 57 | device=self.device) 58 | 59 | def get_ref_clip(self, wav: np.ndarray) -> np.ndarray: 60 | """Extract reference audio clip for speaker embedding. 
61 | 62 | Args: 63 | wav: Input waveform array 64 | 65 | Returns: 66 | Reference clip of fixed duration 67 | """ 68 | SAMPLE_RATE = 16000 69 | REF_SEGMENT_DURATION = 6 # seconds 70 | LATENT_HOP_LENGTH = 320 71 | 72 | ref_segment_length = ( 73 | int(SAMPLE_RATE * REF_SEGMENT_DURATION) 74 | // LATENT_HOP_LENGTH 75 | * LATENT_HOP_LENGTH 76 | ) 77 | wav_length = len(wav) 78 | 79 | if ref_segment_length > wav_length: 80 | # Repeat and truncate if input is too short 81 | repeat_times = ref_segment_length // wav_length + 1 82 | wav = np.tile(wav, repeat_times) 83 | 84 | return wav[:ref_segment_length] 85 | 86 | def execute(self, requests): 87 | """Execute inference on the batched requests. 88 | 89 | Args: 90 | requests: List of inference requests 91 | 92 | Returns: 93 | List of inference responses containing tokenized outputs 94 | """ 95 | reference_wav_list = [] 96 | reference_wav_ref_clip_list = [] 97 | 98 | # Process each request in batch 99 | for request in requests: 100 | # Extract input tensors 101 | wav_array = pb_utils.get_input_tensor_by_name( 102 | request, "reference_wav").as_numpy() 103 | wav_len = pb_utils.get_input_tensor_by_name( 104 | request, "reference_wav_len").as_numpy().item() 105 | 106 | # Prepare inputs 107 | wav = wav_array[:, :wav_len].squeeze(0) 108 | reference_wav_list.append(wav) 109 | 110 | wav_ref_clip = self.get_ref_clip(wav) 111 | reference_wav_ref_clip_list.append(torch.from_numpy(wav_ref_clip)) 112 | 113 | # Batch process through tokenizer 114 | ref_wav_clip_tensor = torch.stack(reference_wav_ref_clip_list, dim=0) 115 | wav2vec2_features = self.audio_tokenizer.extract_wav2vec2_features( 116 | reference_wav_list) 117 | 118 | audio_tokenizer_input = { 119 | "ref_wav": ref_wav_clip_tensor.to(self.device), 120 | "feat": wav2vec2_features.to(self.device), 121 | } 122 | semantic_tokens, global_tokens = self.audio_tokenizer.model.tokenize( 123 | audio_tokenizer_input) 124 | 125 | # Prepare responses 126 | responses = [] 127 | for i in range(len(requests)): 128 | global_tokens_tensor = pb_utils.Tensor.from_dlpack( 129 | "global_tokens", to_dlpack(global_tokens[i])) 130 | semantic_tokens_tensor = pb_utils.Tensor.from_dlpack( 131 | "semantic_tokens", to_dlpack(semantic_tokens[i])) 132 | 133 | inference_response = pb_utils.InferenceResponse( 134 | output_tensors=[global_tokens_tensor, semantic_tokens_tensor]) 135 | responses.append(inference_response) 136 | 137 | return responses 138 | -------------------------------------------------------------------------------- /runtime/triton_trtllm/model_repo/audio_tokenizer/config.pbtxt: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
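# Note: the ${...} placeholders below (triton_max_batch_size,
# max_queue_delay_microseconds, model_dir) are substituted by
# scripts/fill_template.py during stage 2 of run.sh.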
14 | 15 | name: "audio_tokenizer" 16 | backend: "python" 17 | max_batch_size: ${triton_max_batch_size} 18 | dynamic_batching { 19 | max_queue_delay_microseconds: ${max_queue_delay_microseconds} 20 | } 21 | parameters [ 22 | { 23 | key: "model_dir", 24 | value: {string_value:"${model_dir}"} 25 | } 26 | ] 27 | 28 | input [ 29 | { 30 | name: "reference_wav" 31 | data_type: TYPE_FP32 32 | dims: [-1] 33 | }, 34 | { 35 | name: "reference_wav_len" 36 | data_type: TYPE_INT32 37 | dims: [1] 38 | } 39 | ] 40 | output [ 41 | { 42 | name: "global_tokens" 43 | data_type: TYPE_INT32 44 | dims: [-1] 45 | }, 46 | { 47 | name: "semantic_tokens" 48 | data_type: TYPE_INT32 49 | dims: [-1] 50 | } 51 | ] 52 | 53 | instance_group [ 54 | { 55 | count: 1 56 | kind: KIND_CPU 57 | } 58 | ] -------------------------------------------------------------------------------- /runtime/triton_trtllm/model_repo/spark_tts/config.pbtxt: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | name: "spark_tts" 16 | backend: "python" 17 | max_batch_size: ${triton_max_batch_size} 18 | dynamic_batching { 19 | max_queue_delay_microseconds: ${max_queue_delay_microseconds} 20 | } 21 | model_transaction_policy { 22 | decoupled: ${decoupled_mode} 23 | } 24 | parameters [ 25 | { 26 | key: "llm_tokenizer_dir", 27 | value: {string_value:"${llm_tokenizer_dir}"} 28 | }, 29 | { 30 | key: "audio_chunk_duration", 31 | value: {string_value:"${audio_chunk_duration}"} 32 | }, 33 | { 34 | key: "audio_chunk_size_scale_factor", 35 | value: {string_value:"${audio_chunk_size_scale_factor}"} 36 | }, 37 | { 38 | key: "max_audio_chunk_duration", 39 | value: {string_value:"${max_audio_chunk_duration}"} 40 | }, 41 | { 42 | key: "audio_chunk_overlap_duration", 43 | value: {string_value:"${audio_chunk_overlap_duration}"} 44 | }, 45 | { 46 | key: "audio_tokenizer_frame_rate", 47 | value: {string_value:"50"} 48 | } 49 | ] 50 | 51 | input [ 52 | { 53 | name: "reference_wav" 54 | data_type: TYPE_FP32 55 | dims: [-1] 56 | }, 57 | { 58 | name: "reference_wav_len" 59 | data_type: TYPE_INT32 60 | dims: [1] 61 | }, 62 | { 63 | name: "reference_text" 64 | data_type: TYPE_STRING 65 | dims: [1] 66 | }, 67 | { 68 | name: "target_text" 69 | data_type: TYPE_STRING 70 | dims: [1] 71 | } 72 | ] 73 | output [ 74 | { 75 | name: "waveform" 76 | data_type: TYPE_FP32 77 | dims: [ -1 ] 78 | } 79 | ] 80 | 81 | instance_group [ 82 | { 83 | count: ${bls_instance_num} 84 | kind: KIND_CPU 85 | } 86 | ] -------------------------------------------------------------------------------- /runtime/triton_trtllm/model_repo/tensorrt_llm/1/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SparkAudio/Spark-TTS/2f1ea9082400547242641f5271b6f941c9f439d1/runtime/triton_trtllm/model_repo/tensorrt_llm/1/.gitkeep 
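(The tensorrt_llm/1 model version directory is intentionally empty apart from this
.gitkeep; run.sh stage 2 points the TensorRT-LLM backend at the engines built under
./trt_engines_bfloat16 through the engine_dir template variable.)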
-------------------------------------------------------------------------------- /runtime/triton_trtllm/model_repo/vocoder/1/model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # Redistribution and use in source and binary forms, with or without 4 | # modification, are permitted provided that the following conditions 5 | # are met: 6 | # * Redistributions of source code must retain the above copyright 7 | # notice, this list of conditions and the following disclaimer. 8 | # * Redistributions in binary form must reproduce the above copyright 9 | # notice, this list of conditions and the following disclaimer in the 10 | # documentation and/or other materials provided with the distribution. 11 | # * Neither the name of NVIDIA CORPORATION nor the names of its 12 | # contributors may be used to endorse or promote products derived 13 | # from this software without specific prior written permission. 14 | # 15 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY 16 | # EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 17 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 18 | # PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR 19 | # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 20 | # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 21 | # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 22 | # PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY 23 | # OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 24 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 25 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 26 | 27 | import json 28 | import os 29 | import logging 30 | from typing import List, Dict 31 | 32 | import torch 33 | from torch.utils.dlpack import to_dlpack 34 | 35 | import triton_python_backend_utils as pb_utils 36 | 37 | from sparktts.models.bicodec import BiCodec 38 | 39 | # Configure logging 40 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') 41 | logger = logging.getLogger(__name__) 42 | 43 | class TritonPythonModel: 44 | """Triton Python model for vocoder. 45 | 46 | This model takes global and semantic tokens as input and generates audio waveforms 47 | using the BiCodec vocoder. 48 | """ 49 | 50 | def initialize(self, args): 51 | """Initialize the model. 52 | 53 | Args: 54 | args: Dictionary containing model configuration 55 | """ 56 | # Parse model parameters 57 | parameters = json.loads(args['model_config'])['parameters'] 58 | model_params = {key: value["string_value"] for key, value in parameters.items()} 59 | model_dir = model_params["model_dir"] 60 | 61 | # Initialize device and vocoder 62 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 63 | logger.info(f"Initializing vocoder from {model_dir} on {self.device}") 64 | 65 | self.vocoder = BiCodec.load_from_checkpoint(f"{model_dir}/BiCodec") 66 | del self.vocoder.encoder, self.vocoder.postnet 67 | self.vocoder.eval().to(self.device) # Set model to evaluation mode 68 | 69 | logger.info("Vocoder initialized successfully") 70 | 71 | 72 | def execute(self, requests): 73 | """Execute inference on the batched requests. 
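
Global and semantic tokens from all requests are concatenated along the batch
dimension and decoded with a single BiCodec.detokenize call; the resulting
waveforms are then split back into one response per request.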
74 | 75 | Args: 76 | requests: List of inference requests 77 | 78 | Returns: 79 | List of inference responses containing generated waveforms 80 | """ 81 | global_tokens_list, semantic_tokens_list = [], [] 82 | 83 | # Process each request in batch 84 | for request in requests: 85 | global_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "global_tokens").as_numpy() 86 | semantic_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "semantic_tokens").as_numpy() 87 | global_tokens_list.append(torch.from_numpy(global_tokens_tensor).to(self.device)) 88 | semantic_tokens_list.append(torch.from_numpy(semantic_tokens_tensor).to(self.device)) 89 | 90 | # Concatenate tokens for batch processing 91 | global_tokens = torch.cat(global_tokens_list, dim=0) 92 | semantic_tokens = torch.cat(semantic_tokens_list, dim=0) 93 | 94 | 95 | # Generate waveforms 96 | with torch.no_grad(): 97 | wavs = self.vocoder.detokenize(semantic_tokens, global_tokens.unsqueeze(1)) 98 | 99 | # Prepare responses 100 | responses = [] 101 | for i in range(len(requests)): 102 | wav_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(wavs[i])) 103 | inference_response = pb_utils.InferenceResponse(output_tensors=[wav_tensor]) 104 | responses.append(inference_response) 105 | 106 | return responses 107 | -------------------------------------------------------------------------------- /runtime/triton_trtllm/model_repo/vocoder/config.pbtxt: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
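# Note: KIND_CPU below only controls where the Python backend instance is
# scheduled; model.py still moves the BiCodec vocoder to CUDA when a GPU is
# available. The ${...} placeholders are filled by scripts/fill_template.py.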
14 | 15 | name: "vocoder" 16 | backend: "python" 17 | max_batch_size: ${triton_max_batch_size} 18 | dynamic_batching { 19 | max_queue_delay_microseconds: ${max_queue_delay_microseconds} 20 | } 21 | parameters [ 22 | { 23 | key: "model_dir", 24 | value: {string_value:"${model_dir}"} 25 | } 26 | ] 27 | 28 | input [ 29 | { 30 | name: "global_tokens" 31 | data_type: TYPE_INT32 32 | dims: [-1] 33 | }, 34 | { 35 | name: "semantic_tokens" 36 | data_type: TYPE_INT32 37 | dims: [-1] 38 | } 39 | ] 40 | output [ 41 | { 42 | name: "waveform" 43 | data_type: TYPE_FP32 44 | dims: [ -1 ] 45 | } 46 | ] 47 | 48 | instance_group [ 49 | { 50 | count: 1 51 | kind: KIND_CPU 52 | } 53 | ] -------------------------------------------------------------------------------- /runtime/triton_trtllm/run.sh: -------------------------------------------------------------------------------- 1 | export PYTHONPATH=../../../Spark-TTS/ 2 | export CUDA_VISIBLE_DEVICES=0 3 | stage=$1 4 | stop_stage=$2 5 | service_type=$3 6 | echo "Start stage: $stage, Stop stage: $stop_stage service_type: $service_type" 7 | 8 | huggingface_model_local_dir=../../pretrained_models/Spark-TTS-0.5B 9 | trt_dtype=bfloat16 10 | trt_weights_dir=./tllm_checkpoint_${trt_dtype} 11 | trt_engines_dir=./trt_engines_${trt_dtype} 12 | 13 | model_repo=./model_repo_test 14 | 15 | if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then 16 | echo "Downloading Spark-TTS-0.5B from HuggingFace" 17 | huggingface-cli download SparkAudio/Spark-TTS-0.5B --local-dir $huggingface_model_local_dir || exit 1 18 | fi 19 | 20 | 21 | if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then 22 | echo "Converting checkpoint to TensorRT weights" 23 | python scripts/convert_checkpoint.py --model_dir $huggingface_model_local_dir/LLM \ 24 | --output_dir $trt_weights_dir \ 25 | --dtype $trt_dtype || exit 1 26 | 27 | echo "Building TensorRT engines" 28 | trtllm-build --checkpoint_dir $trt_weights_dir \ 29 | --output_dir $trt_engines_dir \ 30 | --max_batch_size 16 \ 31 | --max_num_tokens 32768 \ 32 | --gemm_plugin $trt_dtype || exit 1 33 | fi 34 | 35 | if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then 36 | echo "Creating model repository" 37 | rm -rf $model_repo 38 | mkdir -p $model_repo 39 | spark_tts_dir="spark_tts" 40 | 41 | cp -r ./model_repo/${spark_tts_dir} $model_repo 42 | cp -r ./model_repo/audio_tokenizer $model_repo 43 | cp -r ./model_repo/tensorrt_llm $model_repo 44 | cp -r ./model_repo/vocoder $model_repo 45 | 46 | ENGINE_PATH=$trt_engines_dir 47 | MAX_QUEUE_DELAY_MICROSECONDS=0 48 | MODEL_DIR=$huggingface_model_local_dir 49 | LLM_TOKENIZER_DIR=$huggingface_model_local_dir/LLM 50 | BLS_INSTANCE_NUM=4 51 | TRITON_MAX_BATCH_SIZE=16 52 | # streaming TTS parameters 53 | AUDIO_CHUNK_DURATION=1.0 54 | MAX_AUDIO_CHUNK_DURATION=30.0 55 | AUDIO_CHUNK_SIZE_SCALE_FACTOR=8.0 56 | AUDIO_CHUNK_OVERLAP_DURATION=0.1 57 | python3 scripts/fill_template.py -i ${model_repo}/vocoder/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS} 58 | python3 scripts/fill_template.py -i ${model_repo}/audio_tokenizer/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS} 59 | if [ "$service_type" == "streaming" ]; then 60 | DECOUPLED_MODE=True 61 | else 62 | DECOUPLED_MODE=False 63 | fi 64 | python3 scripts/fill_template.py -i ${model_repo}/${spark_tts_dir}/config.pbtxt 
bls_instance_num:${BLS_INSTANCE_NUM},llm_tokenizer_dir:${LLM_TOKENIZER_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS},audio_chunk_duration:${AUDIO_CHUNK_DURATION},max_audio_chunk_duration:${MAX_AUDIO_CHUNK_DURATION},audio_chunk_size_scale_factor:${AUDIO_CHUNK_SIZE_SCALE_FACTOR},audio_chunk_overlap_duration:${AUDIO_CHUNK_OVERLAP_DURATION} 65 | python3 scripts/fill_template.py -i ${model_repo}/tensorrt_llm/config.pbtxt triton_backend:tensorrtllm,triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},max_beam_width:1,engine_dir:${ENGINE_PATH},max_tokens_in_paged_kv_cache:2560,max_attention_window_size:2560,kv_cache_free_gpu_mem_fraction:0.5,exclude_input_in_output:True,enable_kv_cache_reuse:False,batching_strategy:inflight_fused_batching,max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS},encoder_input_features_data_type:TYPE_FP16,logits_datatype:TYPE_FP32 66 | 67 | fi 68 | 69 | if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then 70 | echo "Starting Triton server" 71 | tritonserver --model-repository ${model_repo} 72 | fi 73 | 74 | 75 | if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then 76 | echo "Running benchmark client" 77 | num_task=2 78 | if [ "$service_type" == "streaming" ]; then 79 | mode="streaming" 80 | else 81 | mode="offline" 82 | fi 83 | python3 client_grpc.py \ 84 | --server-addr localhost \ 85 | --model-name spark_tts \ 86 | --num-tasks $num_task \ 87 | --mode $mode \ 88 | --log-dir ./log_concurrent_tasks_${num_task}_${mode}_new 89 | fi 90 | 91 | if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then 92 | echo "Running single utterance client" 93 | if [ "$service_type" == "streaming" ]; then 94 | python client_grpc.py \ 95 | --server-addr localhost \ 96 | --reference-audio ../../example/prompt_audio.wav \ 97 | --reference-text "吃燕窝就选燕之屋,本节目由26年专注高品质燕窝的燕之屋冠名播出。豆奶牛奶换着喝,营养更均衡,本节目由豆本豆豆奶特约播出。" \ 98 | --target-text "身临其境,换新体验。塑造开源语音合成新范式,让智能语音更自然。" \ 99 | --model-name spark_tts \ 100 | --chunk-overlap-duration 0.1 \ 101 | --mode streaming 102 | else 103 | python client_http.py \ 104 | --reference-audio ../../example/prompt_audio.wav \ 105 | --reference-text "吃燕窝就选燕之屋,本节目由26年专注高品质燕窝的燕之屋冠名播出。豆奶牛奶换着喝,营养更均衡,本节目由豆本豆豆奶特约播出。" \ 106 | --target-text "身临其境,换新体验。塑造开源语音合成新范式,让智能语音更自然。" \ 107 | --model-name spark_tts 108 | fi 109 | fi -------------------------------------------------------------------------------- /runtime/triton_trtllm/scripts/fill_template.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python3 2 | from argparse import ArgumentParser 3 | from string import Template 4 | 5 | 6 | def split(string, delimiter): 7 | """Split a string using delimiter. Supports escaping. 8 | 9 | Args: 10 | string (str): The string to split. 11 | delimiter (str): The delimiter to split the string with. 12 | 13 | Returns: 14 | list: A list of strings. 
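
Example (illustrative):
split("key:value", ":") -> ["key", "value"]
split(r"a\,b,c", ",") -> ["a,b", "c"]   # a backslash escapes the delimiter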
15 | """ 16 | result = [] 17 | current = "" 18 | escape = False 19 | for char in string: 20 | if escape: 21 | current += char 22 | escape = False 23 | elif char == delimiter: 24 | result.append(current) 25 | current = "" 26 | elif char == "\\": 27 | escape = True 28 | else: 29 | current += char 30 | result.append(current) 31 | return result 32 | 33 | 34 | def main(file_path, substitutions, in_place): 35 | with open(file_path) as f: 36 | pbtxt = Template(f.read()) 37 | 38 | sub_dict = { 39 | "max_queue_size": 0, 40 | 'max_queue_delay_microseconds': 0, 41 | } 42 | for sub in split(substitutions, ","): 43 | key, value = split(sub, ":") 44 | sub_dict[key] = value 45 | 46 | assert key in pbtxt.template, f"key '{key}' does not exist in the file {file_path}." 47 | 48 | pbtxt = pbtxt.safe_substitute(sub_dict) 49 | 50 | if in_place: 51 | with open(file_path, "w") as f: 52 | f.write(pbtxt) 53 | else: 54 | print(pbtxt) 55 | 56 | 57 | if __name__ == "__main__": 58 | parser = ArgumentParser() 59 | parser.add_argument("file_path", help="path of the .pbtxt to modify") 60 | parser.add_argument( 61 | "substitutions", 62 | help= 63 | "substitutions to perform, in the format variable_name_1:value_1,variable_name_2:value_2..." 64 | ) 65 | parser.add_argument("--in_place", 66 | "-i", 67 | action="store_true", 68 | help="do the operation in-place") 69 | args = parser.parse_args() 70 | main(**vars(args)) 71 | -------------------------------------------------------------------------------- /sparktts/models/audio_tokenizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 SparkAudio 2 | # 2025 Xinsheng Wang (w.xinshawn@gmail.com) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | 17 | import torch 18 | import numpy as np 19 | 20 | from pathlib import Path 21 | from typing import Any, Dict, Tuple 22 | from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2Model 23 | 24 | from sparktts.utils.file import load_config 25 | from sparktts.utils.audio import load_audio 26 | from sparktts.models.bicodec import BiCodec 27 | 28 | 29 | class BiCodecTokenizer: 30 | """BiCodec tokenizer for handling audio input and tokenization.""" 31 | 32 | def __init__(self, model_dir: Path, device: torch.device = None, **kwargs): 33 | super().__init__() 34 | """ 35 | Args: 36 | model_dir: Path to the model directory. 37 | device: Device to run the model on (default is GPU if available). 
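
Example (mirrors the __main__ test at the bottom of this file):
tokenizer = BiCodecTokenizer("pretrained_models/Spark-TTS-0.5B", device=torch.device("cuda"))
global_tokens, semantic_tokens = tokenizer.tokenize("example/prompt_audio.wav")
wav = tokenizer.detokenize(global_tokens.squeeze(0), semantic_tokens)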
38 | """ 39 | self.device = device 40 | self.model_dir = model_dir 41 | self.config = load_config(f"{model_dir}/config.yaml") 42 | self._initialize_model() 43 | 44 | def _initialize_model(self): 45 | """Load and initialize the BiCodec model and Wav2Vec2 feature extractor.""" 46 | self.model = BiCodec.load_from_checkpoint(f"{self.model_dir}/BiCodec").to( 47 | self.device 48 | ) 49 | self.processor = Wav2Vec2FeatureExtractor.from_pretrained( 50 | f"{self.model_dir}/wav2vec2-large-xlsr-53" 51 | ) 52 | self.feature_extractor = Wav2Vec2Model.from_pretrained( 53 | f"{self.model_dir}/wav2vec2-large-xlsr-53" 54 | ).to(self.device) 55 | self.feature_extractor.config.output_hidden_states = True 56 | 57 | def get_ref_clip(self, wav: np.ndarray) -> np.ndarray: 58 | """Get reference audio clip for speaker embedding.""" 59 | ref_segment_length = ( 60 | int(self.config["sample_rate"] * self.config["ref_segment_duration"]) 61 | // self.config["latent_hop_length"] 62 | * self.config["latent_hop_length"] 63 | ) 64 | wav_length = len(wav) 65 | 66 | if ref_segment_length > wav_length: 67 | # Repeat and truncate to handle insufficient length 68 | wav = np.tile(wav, ref_segment_length // wav_length + 1) 69 | 70 | return wav[:ref_segment_length] 71 | 72 | def process_audio(self, wav_path: Path) -> Tuple[np.ndarray, torch.Tensor]: 73 | """load auido and get reference audio from wav path""" 74 | wav = load_audio( 75 | wav_path, 76 | sampling_rate=self.config["sample_rate"], 77 | volume_normalize=self.config["volume_normalize"], 78 | ) 79 | 80 | wav_ref = self.get_ref_clip(wav) 81 | 82 | wav_ref = torch.from_numpy(wav_ref).unsqueeze(0).float() 83 | return wav, wav_ref 84 | 85 | def extract_wav2vec2_features(self, wavs: torch.Tensor) -> torch.Tensor: 86 | """extract wav2vec2 features""" 87 | inputs = self.processor( 88 | wavs, 89 | sampling_rate=16000, 90 | return_tensors="pt", 91 | padding=True, 92 | output_hidden_states=True, 93 | ).input_values 94 | feat = self.feature_extractor(inputs.to(self.feature_extractor.device)) 95 | feats_mix = ( 96 | feat.hidden_states[11] + feat.hidden_states[14] + feat.hidden_states[16] 97 | ) / 3 98 | 99 | return feats_mix 100 | 101 | def tokenize_batch(self, batch: Dict[str, Any]) -> torch.Tensor: 102 | """tokenize the batch of audio 103 | 104 | Args: 105 | batch: 106 | wavs (List[np.ndarray]): batch of audio 107 | ref_wavs (torch.Tensor): reference audio. shape: (batch_size, seq_len) 108 | 109 | Returns: 110 | semantic_tokens: semantic tokens. shape: (batch_size, seq_len, latent_dim) 111 | global_tokens: global tokens. shape: (batch_size, seq_len, global_dim) 112 | """ 113 | feats = self.extract_wav2vec2_features(batch["wav"]) 114 | batch["feat"] = feats 115 | semantic_tokens, global_tokens = self.model.tokenize(batch) 116 | 117 | return global_tokens, semantic_tokens 118 | 119 | def tokenize(self, audio_path: str) -> Tuple[torch.Tensor, torch.Tensor]: 120 | """tokenize the audio""" 121 | wav, ref_wav = self.process_audio(audio_path) 122 | feat = self.extract_wav2vec2_features(wav) 123 | batch = { 124 | "wav": torch.from_numpy(wav).unsqueeze(0).float().to(self.device), 125 | "ref_wav": ref_wav.to(self.device), 126 | "feat": feat.to(self.device), 127 | } 128 | semantic_tokens, global_tokens = self.model.tokenize(batch) 129 | 130 | return global_tokens, semantic_tokens 131 | 132 | def detokenize( 133 | self, global_tokens: torch.Tensor, semantic_tokens: torch.Tensor 134 | ) -> np.array: 135 | """detokenize the tokens to waveform 136 | 137 | Args: 138 | global_tokens: global tokens. 
shape: (batch_size, global_dim) 139 | semantic_tokens: semantic tokens. shape: (batch_size, latent_dim) 140 | 141 | Returns: 142 | wav_rec: waveform. shape: (batch_size, seq_len) for batch or (seq_len,) for single 143 | """ 144 | global_tokens = global_tokens.unsqueeze(1) 145 | wav_rec = self.model.detokenize(semantic_tokens, global_tokens) 146 | return wav_rec.detach().squeeze().cpu().numpy() 147 | 148 | 149 | # test 150 | if __name__ == "__main__": 151 | import soundfile as sf 152 | 153 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 154 | tokenizer = BiCodecTokenizer( 155 | model_dir="pretrained_models/Spark-TTS-0.5B", 156 | device=device, 157 | ) 158 | wav_path = "example/prompt_audio.wav" 159 | 160 | global_tokens, semantic_tokens = tokenizer.tokenize(wav_path) 161 | 162 | wav_rec = tokenizer.detokenize(global_tokens.squeeze(0), semantic_tokens) 163 | sf.write("example/prompt_recon.wav", wav_rec, 16000) 164 | -------------------------------------------------------------------------------- /sparktts/models/bicodec.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 SparkAudio 2 | # 2025 Xinsheng Wang (w.xinshawn@gmail.com) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import torch 17 | import torch.nn as nn 18 | from pathlib import Path 19 | from typing import Dict, Any 20 | from omegaconf import DictConfig 21 | from safetensors.torch import load_file 22 | 23 | from sparktts.utils.file import load_config 24 | from sparktts.modules.speaker.speaker_encoder import SpeakerEncoder 25 | from sparktts.modules.encoder_decoder.feat_encoder import Encoder 26 | from sparktts.modules.encoder_decoder.feat_decoder import Decoder 27 | from sparktts.modules.encoder_decoder.wave_generator import WaveGenerator 28 | from sparktts.modules.vq.factorized_vector_quantize import FactorizedVectorQuantize 29 | 30 | 31 | class BiCodec(nn.Module): 32 | """ 33 | BiCodec model for speech synthesis, incorporating a speaker encoder, feature encoder/decoder, 34 | quantizer, and wave generator. 35 | """ 36 | 37 | def __init__( 38 | self, 39 | mel_params: Dict[str, Any], 40 | encoder: nn.Module, 41 | decoder: nn.Module, 42 | quantizer: nn.Module, 43 | speaker_encoder: nn.Module, 44 | prenet: nn.Module, 45 | postnet: nn.Module, 46 | **kwargs 47 | ) -> None: 48 | """ 49 | Initializes the BiCodec model with the required components. 50 | 51 | Args: 52 | mel_params (dict): Parameters for the mel-spectrogram transformer. 53 | encoder (nn.Module): Encoder module. 54 | decoder (nn.Module): Decoder module. 55 | quantizer (nn.Module): Quantizer module. 56 | speaker_encoder (nn.Module): Speaker encoder module. 57 | prenet (nn.Module): Prenet network. 58 | postnet (nn.Module): Postnet network. 
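
Data flow (see forward/tokenize/detokenize below): wav2vec2 features are encoded
and quantized into semantic tokens, a mel spectrogram of the reference audio is
mapped by the speaker encoder to global tokens / a d-vector condition, and
prenet + decoder (WaveGenerator) turn the quantized latents plus that condition
back into a waveform; postnet predicts auxiliary features.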
59 | """ 60 | super().__init__() 61 | self.encoder = encoder 62 | self.decoder = decoder 63 | self.quantizer = quantizer 64 | self.speaker_encoder = speaker_encoder 65 | self.prenet = prenet 66 | self.postnet = postnet 67 | self.init_mel_transformer(mel_params) 68 | 69 | @classmethod 70 | def load_from_checkpoint(cls, model_dir: Path, **kwargs) -> "BiCodec": 71 | """ 72 | Loads the model from a checkpoint. 73 | 74 | Args: 75 | model_dir (Path): Path to the model directory containing checkpoint and config. 76 | 77 | Returns: 78 | BiCodec: The initialized BiCodec model. 79 | """ 80 | ckpt_path = f'{model_dir}/model.safetensors' 81 | config = load_config(f'{model_dir}/config.yaml')['audio_tokenizer'] 82 | mel_params = config["mel_params"] 83 | encoder = Encoder(**config["encoder"]) 84 | quantizer = FactorizedVectorQuantize(**config["quantizer"]) 85 | prenet = Decoder(**config["prenet"]) 86 | postnet = Decoder(**config["postnet"]) 87 | decoder = WaveGenerator(**config["decoder"]) 88 | speaker_encoder = SpeakerEncoder(**config["speaker_encoder"]) 89 | 90 | model = cls( 91 | mel_params=mel_params, 92 | encoder=encoder, 93 | decoder=decoder, 94 | quantizer=quantizer, 95 | speaker_encoder=speaker_encoder, 96 | prenet=prenet, 97 | postnet=postnet, 98 | ) 99 | 100 | state_dict = load_file(ckpt_path) 101 | missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) 102 | 103 | for key in missing_keys: 104 | print(f"Missing tensor: {key}") 105 | for key in unexpected_keys: 106 | print(f"Unexpected tensor: {key}") 107 | 108 | model.eval() 109 | model.remove_weight_norm() 110 | 111 | return model 112 | 113 | def forward(self, batch: Dict[str, Any]) -> Dict[str, Any]: 114 | """ 115 | Performs a forward pass through the model. 116 | 117 | Args: 118 | batch (dict): A dictionary containing features, reference waveform, and target waveform. 119 | 120 | Returns: 121 | dict: A dictionary containing the reconstruction, features, and other metrics. 122 | """ 123 | feat = batch["feat"] 124 | mel = self.mel_transformer(batch["ref_wav"]).squeeze(1) 125 | 126 | z = self.encoder(feat.transpose(1, 2)) 127 | vq_outputs = self.quantizer(z) 128 | 129 | x_vector, d_vector = self.speaker_encoder(mel.transpose(1, 2)) 130 | 131 | conditions = d_vector 132 | with_speaker_loss = False 133 | 134 | x = self.prenet(vq_outputs["z_q"], conditions) 135 | pred_feat = self.postnet(x) 136 | x = x + conditions.unsqueeze(-1) 137 | wav_recon = self.decoder(x) 138 | 139 | return { 140 | "vq_loss": vq_outputs["vq_loss"], 141 | "perplexity": vq_outputs["perplexity"], 142 | "cluster_size": vq_outputs["active_num"], 143 | "recons": wav_recon, 144 | "pred_feat": pred_feat, 145 | "x_vector": x_vector, 146 | "d_vector": d_vector, 147 | "audios": batch["wav"].unsqueeze(1), 148 | "with_speaker_loss": with_speaker_loss, 149 | } 150 | 151 | @torch.no_grad() 152 | def tokenize(self, batch: Dict[str, Any]): 153 | """ 154 | Tokenizes the input audio into semantic and global tokens. 155 | 156 | Args: 157 | batch (dict): The input audio features and reference waveform. 158 | 159 | Returns: 160 | tuple: Semantic tokens and global tokens. 
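
Note:
detokenize(semantic_tokens, global_tokens) is the inverse operation and
returns the reconstructed waveform.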
161 | """ 162 | feat = batch["feat"] 163 | mel = self.mel_transformer(batch["ref_wav"]).squeeze(1) 164 | 165 | z = self.encoder(feat.transpose(1, 2)) 166 | semantic_tokens = self.quantizer.tokenize(z) 167 | global_tokens = self.speaker_encoder.tokenize(mel.transpose(1, 2)) 168 | 169 | return semantic_tokens, global_tokens 170 | 171 | @torch.no_grad() 172 | def detokenize(self, semantic_tokens, global_tokens): 173 | """ 174 | Detokenizes the semantic and global tokens into a waveform. 175 | 176 | Args: 177 | semantic_tokens (tensor): Semantic tokens. 178 | global_tokens (tensor): Global tokens. 179 | 180 | Returns: 181 | tensor: Reconstructed waveform. 182 | """ 183 | z_q = self.quantizer.detokenize(semantic_tokens) 184 | d_vector = self.speaker_encoder.detokenize(global_tokens) 185 | x = self.prenet(z_q, d_vector) 186 | x = x + d_vector.unsqueeze(-1) 187 | wav_recon = self.decoder(x) 188 | 189 | return wav_recon 190 | 191 | def init_mel_transformer(self, config: Dict[str, Any]): 192 | """ 193 | Initializes the MelSpectrogram transformer based on the provided configuration. 194 | 195 | Args: 196 | config (dict): Configuration parameters for MelSpectrogram. 197 | """ 198 | import torchaudio.transforms as TT 199 | 200 | self.mel_transformer = TT.MelSpectrogram( 201 | config["sample_rate"], 202 | config["n_fft"], 203 | config["win_length"], 204 | config["hop_length"], 205 | config["mel_fmin"], 206 | config["mel_fmax"], 207 | n_mels=config["num_mels"], 208 | power=1, 209 | norm="slaney", 210 | mel_scale="slaney", 211 | ) 212 | 213 | def remove_weight_norm(self): 214 | """Removes weight normalization from all layers.""" 215 | def _remove_weight_norm(m): 216 | try: 217 | torch.nn.utils.remove_weight_norm(m) 218 | except ValueError: 219 | pass # The module didn't have weight norm 220 | 221 | self.apply(_remove_weight_norm) 222 | 223 | 224 | # Test the model 225 | if __name__ == "__main__": 226 | 227 | config = load_config("pretrained_models/SparkTTS-0.5B/BiCodec/config.yaml") 228 | model = BiCodec.load_from_checkpoint( 229 | model_dir="pretrained_models/SparkTTS-0.5B/BiCodec", 230 | ) 231 | 232 | # Generate random inputs for testing 233 | duration = 0.96 234 | x = torch.randn(20, 1, int(duration * 16000)) 235 | feat = torch.randn(20, int(duration * 50), 1024) 236 | inputs = {"feat": feat, "wav": x, "ref_wav": x} 237 | 238 | # Forward pass 239 | outputs = model(inputs) 240 | semantic_tokens, global_tokens = model.tokenize(inputs) 241 | wav_recon = model.detokenize(semantic_tokens, global_tokens) 242 | 243 | # Verify if the reconstruction matches 244 | if torch.allclose(outputs["recons"].detach(), wav_recon): 245 | print("Test successful") 246 | else: 247 | print("Test failed") 248 | -------------------------------------------------------------------------------- /sparktts/modules/blocks/layers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 SparkAudio 2 | # 2025 Xinsheng Wang (w.xinshawn@gmail.com) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Adapted from https://github.com/descriptinc/descript-audio-codec under the Apache License 2.0 17 | 18 | 19 | import torch 20 | import torch.nn as nn 21 | from torch.nn.utils import weight_norm 22 | 23 | 24 | def WNConv1d(*args, **kwargs): 25 | return weight_norm(nn.Conv1d(*args, **kwargs)) 26 | 27 | 28 | def WNConvTranspose1d(*args, **kwargs): 29 | return weight_norm(nn.ConvTranspose1d(*args, **kwargs)) 30 | 31 | 32 | # Scripting this brings model speed up 1.4x 33 | @torch.jit.script 34 | def snake(x, alpha): 35 | shape = x.shape 36 | x = x.reshape(shape[0], shape[1], -1) 37 | x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2) 38 | x = x.reshape(shape) 39 | return x 40 | 41 | 42 | class Snake1d(nn.Module): 43 | def __init__(self, channels): 44 | super().__init__() 45 | self.alpha = nn.Parameter(torch.ones(1, channels, 1)) 46 | 47 | def forward(self, x): 48 | return snake(x, self.alpha) 49 | 50 | 51 | class ResidualUnit(nn.Module): 52 | def __init__(self, dim: int = 16, dilation: int = 1): 53 | super().__init__() 54 | pad = ((7 - 1) * dilation) // 2 55 | self.block = nn.Sequential( 56 | Snake1d(dim), 57 | WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad), 58 | Snake1d(dim), 59 | WNConv1d(dim, dim, kernel_size=1), 60 | ) 61 | 62 | def forward(self, x): 63 | y = self.block(x) 64 | pad = (x.shape[-1] - y.shape[-1]) // 2 65 | if pad > 0: 66 | x = x[..., pad:-pad] 67 | return x + y 68 | 69 | 70 | def init_weights(m): 71 | if isinstance(m, nn.Conv1d): 72 | nn.init.trunc_normal_(m.weight, std=0.02) 73 | nn.init.constant_(m.bias, 0) 74 | -------------------------------------------------------------------------------- /sparktts/modules/blocks/samper.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 SparkAudio 2 | # 2025 Xinsheng Wang (w.xinshawn@gmail.com) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
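
# SamplingBlock (below) sums a learned path (ConvTranspose1d for upsampling,
# strided Conv1d for downsampling) with a parameter-free path (repeat_interleave /
# avg_pool1d). As used by feat_encoder.py and feat_decoder.py, forward() takes
# (batch, time, channels) input, transposes it once on entry, and returns
# (batch, channels, time).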
15 | 16 | 17 | import torch 18 | import torch.nn as nn 19 | import torch.nn.functional as F 20 | 21 | 22 | class SamplingBlock(nn.Module): 23 | """Sampling block for upsampling or downsampling""" 24 | 25 | def __init__( 26 | self, 27 | dim: int, 28 | groups: int = 1, 29 | upsample_scale: int = 1, 30 | downsample_scale: int = 1, 31 | ) -> None: 32 | """ 33 | Args: 34 | dim: input dimension 35 | groups: number of groups 36 | upsample_scale: upsampling scale 37 | downsample_scale: downsampling scale 38 | """ 39 | super(SamplingBlock, self).__init__() 40 | 41 | self.upsample_scale = upsample_scale 42 | self.downsample_scale = downsample_scale 43 | 44 | if self.upsample_scale > 1: 45 | self.de_conv_upsampler = nn.Sequential( 46 | nn.LeakyReLU(0.2), 47 | nn.ConvTranspose1d( 48 | dim, 49 | dim, 50 | kernel_size=upsample_scale * 2, 51 | stride=upsample_scale, 52 | padding=upsample_scale // 2 + upsample_scale % 2, 53 | output_padding=upsample_scale % 2, 54 | groups=groups, 55 | ), 56 | ) 57 | 58 | if self.downsample_scale > 1: 59 | self.conv_downsampler = nn.Sequential( 60 | nn.LeakyReLU(0.2), 61 | nn.Conv1d( 62 | dim, 63 | dim, 64 | kernel_size=2 * downsample_scale, 65 | stride=downsample_scale, 66 | padding=downsample_scale // 2 + downsample_scale % 2, 67 | groups=groups, 68 | ), 69 | ) 70 | 71 | @staticmethod 72 | def repeat_upsampler(x, upsample_scale): 73 | return x.repeat_interleave(upsample_scale, dim=2) 74 | 75 | @staticmethod 76 | def skip_downsampler(x, downsample_scale): 77 | return F.avg_pool1d(x, kernel_size=downsample_scale, stride=downsample_scale) 78 | 79 | def forward(self, x): 80 | x = x.transpose(1, 2) 81 | if self.upsample_scale > 1: 82 | repeat_res = self.repeat_upsampler(x, self.upsample_scale) 83 | deconv_res = self.de_conv_upsampler(x) 84 | upmerge_res = repeat_res + deconv_res 85 | else: 86 | upmerge_res = x 87 | repeat_res = x 88 | 89 | if self.downsample_scale > 1: 90 | conv_res = self.conv_downsampler(upmerge_res) 91 | skip2_res = self.skip_downsampler(upmerge_res, self.downsample_scale) 92 | skip1_res = self.skip_downsampler(repeat_res, self.downsample_scale) 93 | else: 94 | conv_res = upmerge_res 95 | skip2_res = upmerge_res 96 | skip1_res = repeat_res 97 | 98 | final_res = conv_res + skip1_res + skip2_res 99 | 100 | return final_res 101 | 102 | 103 | # test 104 | if __name__ == "__main__": 105 | test_input = torch.randn(8, 1024, 50) # Batch size = 8, 1024 channels, length = 50 106 | model = SamplingBlock(1024, 1024, upsample_scale=2) 107 | model_down = SamplingBlock(1024, 1024, downsample_scale=2) 108 | output = model(test_input) 109 | output_down = model_down(test_input) 110 | print("shape after upsample * 2", output.shape) # torch.Size([8, 1024, 100]) 111 | print("shape after downsample * 2", output_down.shape) # torch.Size([8, 1024, 25]) 112 | if output.shape == torch.Size([8, 1024, 100]) and output_down.shape == torch.Size( 113 | [8, 1024, 25] 114 | ): 115 | print("test successful") 116 | -------------------------------------------------------------------------------- /sparktts/modules/encoder_decoder/feat_decoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 SparkAudio 2 | # 2025 Xinsheng Wang (w.xinshawn@gmail.com) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | 
16 | 
17 | import torch
18 | import torch.nn as nn
19 | 
20 | from typing import List
21 | 
22 | from sparktts.modules.blocks.vocos import VocosBackbone
23 | from sparktts.modules.blocks.samper import SamplingBlock
24 | 
25 | 
26 | class Decoder(nn.Module):
27 | """Decoder module with convnext and upsampling blocks
28 | 
29 | Args:
30 | sample_ratios (List[int]): sample ratios
31 | example: [2, 2] means two 2x upsampling stages (4x upsampling overall)
32 | """
33 | 
34 | def __init__(
35 | self,
36 | input_channels: int,
37 | vocos_dim: int,
38 | vocos_intermediate_dim: int,
39 | vocos_num_layers: int,
40 | out_channels: int,
41 | condition_dim: int = None,
42 | sample_ratios: List[int] = [1, 1],
43 | use_tanh_at_final: bool = False,
44 | ):
45 | super().__init__()
46 | 
47 | self.linear_pre = nn.Linear(input_channels, vocos_dim)
48 | modules = [
49 | nn.Sequential(
50 | SamplingBlock(
51 | dim=vocos_dim,
52 | groups=vocos_dim,
53 | upsample_scale=ratio,
54 | ),
55 | VocosBackbone(
56 | input_channels=vocos_dim,
57 | dim=vocos_dim,
58 | intermediate_dim=vocos_intermediate_dim,
59 | num_layers=2,
60 | condition_dim=None,
61 | ),
62 | )
63 | for ratio in sample_ratios
64 | ]
65 | 
66 | self.downsample = nn.Sequential(*modules)
67 | 
68 | self.vocos_backbone = VocosBackbone(
69 | input_channels=vocos_dim,
70 | dim=vocos_dim,
71 | intermediate_dim=vocos_intermediate_dim,
72 | num_layers=vocos_num_layers,
73 | condition_dim=condition_dim,
74 | )
75 | self.linear = nn.Linear(vocos_dim, out_channels)
76 | self.use_tanh_at_final = use_tanh_at_final
77 | 
78 | def forward(self, x: torch.Tensor, c: torch.Tensor = None):
79 | """Decoder forward.
80 | 81 | Args: 82 | x (torch.Tensor): (batch_size, input_channels, length) 83 | 84 | Returns: 85 | x (torch.Tensor): (batch_size, encode_channels, length) 86 | """ 87 | x = self.linear_pre(x.transpose(1, 2)) 88 | x = self.downsample(x).transpose(1, 2) 89 | x = self.vocos_backbone(x, condition=c) 90 | x = self.linear(x).transpose(1, 2) 91 | if self.use_tanh_at_final: 92 | x = torch.tanh(x) 93 | 94 | return x 95 | 96 | 97 | # test 98 | if __name__ == "__main__": 99 | test_input = torch.randn(8, 1024, 50) # Batch size = 8, 1024 channels, length = 50 100 | condition = torch.randn(8, 256) 101 | decoder = Decoder( 102 | input_channels=1024, 103 | vocos_dim=384, 104 | vocos_intermediate_dim=2048, 105 | vocos_num_layers=12, 106 | out_channels=256, 107 | condition_dim=256, 108 | sample_ratios=[2, 2], 109 | ) 110 | output = decoder(test_input, condition) 111 | print(output.shape) # torch.Size([8, 256, 200]) 112 | if output.shape == torch.Size([8, 256, 200]): 113 | print("Decoder test passed") 114 | else: 115 | print("Decoder test failed") 116 | -------------------------------------------------------------------------------- /sparktts/modules/encoder_decoder/feat_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 SparkAudio 2 | # 2025 Xinsheng Wang (w.xinshawn@gmail.com) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | 17 | import torch 18 | import torch.nn as nn 19 | 20 | from typing import List 21 | 22 | from sparktts.modules.blocks.vocos import VocosBackbone 23 | from sparktts.modules.blocks.samper import SamplingBlock 24 | 25 | 26 | class Encoder(nn.Module): 27 | """Encoder module with convnext and downsampling blocks""" 28 | 29 | def __init__( 30 | self, 31 | input_channels: int, 32 | vocos_dim: int, 33 | vocos_intermediate_dim: int, 34 | vocos_num_layers: int, 35 | out_channels: int, 36 | sample_ratios: List[int] = [1, 1], 37 | ): 38 | super().__init__() 39 | """ 40 | Encoder module with VocosBackbone and sampling blocks. 
41 | 42 | Args: 43 | sample_ratios (List[int]): sample ratios 44 | example: [2, 2] means downsample by 2x and then upsample by 2x 45 | """ 46 | self.encoder = VocosBackbone( 47 | input_channels=input_channels, 48 | dim=vocos_dim, 49 | intermediate_dim=vocos_intermediate_dim, 50 | num_layers=vocos_num_layers, 51 | condition_dim=None, 52 | ) 53 | 54 | modules = [ 55 | nn.Sequential( 56 | SamplingBlock( 57 | dim=vocos_dim, 58 | groups=vocos_dim, 59 | downsample_scale=ratio, 60 | ), 61 | VocosBackbone( 62 | input_channels=vocos_dim, 63 | dim=vocos_dim, 64 | intermediate_dim=vocos_intermediate_dim, 65 | num_layers=2, 66 | condition_dim=None, 67 | ), 68 | ) 69 | for ratio in sample_ratios 70 | ] 71 | 72 | self.downsample = nn.Sequential(*modules) 73 | 74 | self.project = nn.Linear(vocos_dim, out_channels) 75 | 76 | def forward(self, x: torch.Tensor, *args): 77 | """ 78 | Args: 79 | x (torch.Tensor): (batch_size, input_channels, length) 80 | 81 | Returns: 82 | x (torch.Tensor): (batch_size, encode_channels, length) 83 | """ 84 | x = self.encoder(x) 85 | x = self.downsample(x) 86 | x = self.project(x) 87 | return x.transpose(1, 2) 88 | 89 | 90 | # test 91 | if __name__ == "__main__": 92 | test_input = torch.randn(8, 1024, 50) # Batch size = 8, 1024 channels, length = 50 93 | encoder = Encoder( 94 | input_channels=1024, 95 | vocos_dim=384, 96 | vocos_intermediate_dim=2048, 97 | vocos_num_layers=12, 98 | out_channels=256, 99 | sample_ratios=[2, 2], 100 | ) 101 | 102 | output = encoder(test_input) 103 | print(output.shape) # torch.Size([8, 256, 12]) 104 | if output.shape == torch.Size([8, 256, 12]): 105 | print("test successful") 106 | -------------------------------------------------------------------------------- /sparktts/modules/encoder_decoder/wave_generator.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Xinsheng Wang (w.xinshawn@gmail.com) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | # Adapted from https://github.com/descriptinc/descript-audio-codec under the Apache License 2.0 16 | 17 | 18 | import torch.nn as nn 19 | 20 | from sparktts.modules.blocks.layers import ( 21 | Snake1d, 22 | WNConv1d, 23 | ResidualUnit, 24 | WNConvTranspose1d, 25 | init_weights, 26 | ) 27 | 28 | 29 | class DecoderBlock(nn.Module): 30 | def __init__( 31 | self, 32 | input_dim: int = 16, 33 | output_dim: int = 8, 34 | kernel_size: int = 2, 35 | stride: int = 1, 36 | ): 37 | super().__init__() 38 | self.block = nn.Sequential( 39 | Snake1d(input_dim), 40 | WNConvTranspose1d( 41 | input_dim, 42 | output_dim, 43 | kernel_size=kernel_size, 44 | stride=stride, 45 | padding=(kernel_size - stride) // 2, 46 | ), 47 | ResidualUnit(output_dim, dilation=1), 48 | ResidualUnit(output_dim, dilation=3), 49 | ResidualUnit(output_dim, dilation=9), 50 | ) 51 | 52 | def forward(self, x): 53 | return self.block(x) 54 | 55 | 56 | class WaveGenerator(nn.Module): 57 | def __init__( 58 | self, 59 | input_channel, 60 | channels, 61 | rates, 62 | kernel_sizes, 63 | d_out: int = 1, 64 | ): 65 | super().__init__() 66 | 67 | # Add first conv layer 68 | layers = [WNConv1d(input_channel, channels, kernel_size=7, padding=3)] 69 | 70 | # Add upsampling + MRF blocks 71 | for i, (kernel_size, stride) in enumerate(zip(kernel_sizes, rates)): 72 | input_dim = channels // 2**i 73 | output_dim = channels // 2 ** (i + 1) 74 | layers += [DecoderBlock(input_dim, output_dim, kernel_size, stride)] 75 | 76 | # Add final conv layer 77 | layers += [ 78 | Snake1d(output_dim), 79 | WNConv1d(output_dim, d_out, kernel_size=7, padding=3), 80 | nn.Tanh(), 81 | ] 82 | 83 | self.model = nn.Sequential(*layers) 84 | 85 | self.apply(init_weights) 86 | 87 | def forward(self, x): 88 | return self.model(x) 89 | -------------------------------------------------------------------------------- /sparktts/modules/fsq/finite_scalar_quantization.py: -------------------------------------------------------------------------------- 1 | """ 2 | Finite Scalar Quantization: VQ-VAE Made Simple - https://arxiv.org/abs/2309.15505 3 | Code adapted from Jax version in Appendix A.1 4 | """ 5 | 6 | from __future__ import annotations 7 | from functools import wraps, partial 8 | from contextlib import nullcontext 9 | from typing import List, Tuple 10 | 11 | import torch 12 | import torch.nn as nn 13 | from torch.nn import Module 14 | from torch import Tensor, int32 15 | from torch.amp import autocast 16 | 17 | from einops import rearrange, pack, unpack 18 | 19 | # helper functions 20 | 21 | 22 | def exists(v): 23 | return v is not None 24 | 25 | 26 | def default(*args): 27 | for arg in args: 28 | if exists(arg): 29 | return arg 30 | return None 31 | 32 | 33 | def maybe(fn): 34 | @wraps(fn) 35 | def inner(x, *args, **kwargs): 36 | if not exists(x): 37 | return x 38 | return fn(x, *args, **kwargs) 39 | 40 | return inner 41 | 42 | 43 | def pack_one(t, pattern): 44 | return pack([t], pattern) 45 | 46 | 47 | def unpack_one(t, ps, pattern): 48 | return unpack(t, ps, pattern)[0] 49 | 50 | 51 | # tensor helpers 52 | 53 | 54 | def round_ste(z: Tensor) -> Tensor: 55 | """Round with straight through gradients.""" 56 | zhat = z.round() 57 | return z + (zhat - z).detach() 58 | 59 | 60 | # main class 61 | 62 | 63 | class FSQ(Module): 64 | def __init__( 65 | self, 66 | levels: List[int], 67 | dim: int | None = None, 68 | num_codebooks=1, 69 | keep_num_codebooks_dim: bool | None = None, 70 | scale: float | None = None, 71 | allowed_dtypes: Tuple[torch.dtype, ...] 
= (torch.float32, torch.float64), 72 | channel_first: bool = False, 73 | projection_has_bias: bool = True, 74 | return_indices=True, 75 | force_quantization_f32=True, 76 | ): 77 | super().__init__() 78 | _levels = torch.tensor(levels, dtype=int32) 79 | self.register_buffer("_levels", _levels, persistent=False) 80 | 81 | _basis = torch.cumprod(torch.tensor([1] + levels[:-1]), dim=0, dtype=int32) 82 | self.register_buffer("_basis", _basis, persistent=False) 83 | 84 | self.scale = scale 85 | 86 | codebook_dim = len(levels) 87 | self.codebook_dim = codebook_dim 88 | 89 | effective_codebook_dim = codebook_dim * num_codebooks 90 | self.num_codebooks = num_codebooks 91 | self.effective_codebook_dim = effective_codebook_dim 92 | 93 | keep_num_codebooks_dim = default(keep_num_codebooks_dim, num_codebooks > 1) 94 | assert not (num_codebooks > 1 and not keep_num_codebooks_dim) 95 | self.keep_num_codebooks_dim = keep_num_codebooks_dim 96 | 97 | self.dim = default(dim, len(_levels) * num_codebooks) 98 | 99 | self.channel_first = channel_first 100 | 101 | has_projections = self.dim != effective_codebook_dim 102 | self.project_in = ( 103 | nn.Linear(self.dim, effective_codebook_dim, bias=projection_has_bias) 104 | if has_projections 105 | else nn.Identity() 106 | ) 107 | self.project_out = ( 108 | nn.Linear(effective_codebook_dim, self.dim, bias=projection_has_bias) 109 | if has_projections 110 | else nn.Identity() 111 | ) 112 | 113 | self.has_projections = has_projections 114 | 115 | self.return_indices = return_indices 116 | if return_indices: 117 | self.codebook_size = self._levels.prod().item() 118 | implicit_codebook = self._indices_to_codes(torch.arange(self.codebook_size)) 119 | self.register_buffer( 120 | "implicit_codebook", implicit_codebook, persistent=False 121 | ) 122 | 123 | self.allowed_dtypes = allowed_dtypes 124 | self.force_quantization_f32 = force_quantization_f32 125 | 126 | def bound(self, z, eps: float = 1e-3): 127 | """Bound `z`, an array of shape (..., d).""" 128 | half_l = (self._levels - 1) * (1 + eps) / 2 129 | offset = torch.where(self._levels % 2 == 0, 0.5, 0.0) 130 | shift = (offset / half_l).atanh() 131 | return (z + shift).tanh() * half_l - offset 132 | 133 | def quantize(self, z): 134 | """Quantizes z, returns quantized zhat, same shape as z.""" 135 | quantized = round_ste(self.bound(z)) 136 | half_width = self._levels // 2 # Renormalize to [-1, 1]. 137 | return quantized / half_width 138 | 139 | def _scale_and_shift(self, zhat_normalized): 140 | half_width = self._levels // 2 141 | return (zhat_normalized * half_width) + half_width 142 | 143 | def _scale_and_shift_inverse(self, zhat): 144 | half_width = self._levels // 2 145 | return (zhat - half_width) / half_width 146 | 147 | def _indices_to_codes(self, indices): 148 | level_indices = self.indices_to_level_indices(indices) 149 | codes = self._scale_and_shift_inverse(level_indices) 150 | return codes 151 | 152 | def codes_to_indices(self, zhat): 153 | """Converts a `code` to an index in the codebook.""" 154 | assert zhat.shape[-1] == self.codebook_dim 155 | zhat = self._scale_and_shift(zhat) 156 | return (zhat * self._basis).sum(dim=-1).to(int32) 157 | 158 | def indices_to_level_indices(self, indices): 159 | """Converts indices to indices at each level, perhaps needed for a transformer with factorized embeddings""" 160 | indices = rearrange(indices, "... -> ... 
1") 161 | codes_non_centered = (indices // self._basis) % self._levels 162 | return codes_non_centered 163 | 164 | def indices_to_codes(self, indices): 165 | """Inverse of `codes_to_indices`.""" 166 | assert exists(indices) 167 | 168 | is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim)) 169 | 170 | codes = self._indices_to_codes(indices) 171 | 172 | if self.keep_num_codebooks_dim: 173 | codes = rearrange(codes, "... c d -> ... (c d)") 174 | 175 | codes = self.project_out(codes) 176 | 177 | if is_img_or_video or self.channel_first: 178 | codes = rearrange(codes, "b ... d -> b d ...") 179 | 180 | return codes 181 | 182 | def forward(self, z): 183 | """ 184 | einstein notation 185 | b - batch 186 | n - sequence (or flattened spatial dimensions) 187 | d - feature dimension 188 | c - number of codebook dim 189 | """ 190 | 191 | is_img_or_video = z.ndim >= 4 192 | need_move_channel_last = is_img_or_video or self.channel_first 193 | 194 | # standardize image or video into (batch, seq, dimension) 195 | 196 | if need_move_channel_last: 197 | z = rearrange(z, "b d ... -> b ... d") 198 | z, ps = pack_one(z, "b * d") 199 | 200 | assert ( 201 | z.shape[-1] == self.dim 202 | ), f"expected dimension of {self.dim} but found dimension of {z.shape[-1]}" 203 | 204 | z = self.project_in(z) 205 | 206 | z = rearrange(z, "b n (c d) -> b n c d", c=self.num_codebooks) 207 | 208 | # whether to force quantization step to be full precision or not 209 | 210 | force_f32 = self.force_quantization_f32 211 | quantization_context = ( 212 | partial(autocast, "cuda", enabled=False) if force_f32 else nullcontext 213 | ) 214 | 215 | with quantization_context(): 216 | orig_dtype = z.dtype 217 | 218 | if force_f32 and orig_dtype not in self.allowed_dtypes: 219 | z = z.float() 220 | 221 | codes = self.quantize(z) 222 | 223 | # returning indices could be optional 224 | 225 | indices = None 226 | 227 | if self.return_indices: 228 | indices = self.codes_to_indices(codes) 229 | 230 | codes = rearrange(codes, "b n c d -> b n (c d)") 231 | 232 | codes = codes.type(orig_dtype) 233 | 234 | # project out 235 | 236 | out = self.project_out(codes) 237 | 238 | # reconstitute image or video dimensions 239 | 240 | if need_move_channel_last: 241 | out = unpack_one(out, ps, "b * d") 242 | out = rearrange(out, "b ... d -> b d ...") 243 | 244 | indices = maybe(unpack_one)(indices, ps, "b * c") 245 | 246 | if not self.keep_num_codebooks_dim and self.return_indices: 247 | indices = maybe(rearrange)(indices, "... 
1 -> ...") 248 | 249 | # return quantized output and indices 250 | 251 | return out, indices 252 | -------------------------------------------------------------------------------- /sparktts/modules/fsq/residual_fsq.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | import torch.nn.functional as F 4 | import torch.distributed as dist 5 | 6 | from typing import List 7 | from torch import nn 8 | from torch.nn import Module 9 | from torch.amp import autocast 10 | from einx import get_at 11 | from einops import rearrange, reduce, pack, unpack 12 | 13 | from sparktts.modules.fsq.finite_scalar_quantization import FSQ 14 | 15 | 16 | def exists(val): 17 | return val is not None 18 | 19 | 20 | def first(l): 21 | return l[0] 22 | 23 | 24 | def default(val, d): 25 | return val if exists(val) else d 26 | 27 | 28 | def round_up_multiple(num, mult): 29 | return ceil(num / mult) * mult 30 | 31 | 32 | # distributed helpers 33 | 34 | 35 | def is_distributed(): 36 | return dist.is_initialized() and dist.get_world_size() > 1 37 | 38 | 39 | def get_maybe_sync_seed(device, max_size=10_000): 40 | rand_int = torch.randint(0, max_size, (), device=device) 41 | 42 | if is_distributed(): 43 | dist.all_reduce(rand_int) 44 | 45 | return rand_int.item() 46 | 47 | 48 | class ResidualFSQ(Module): 49 | """Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf""" 50 | 51 | def __init__( 52 | self, 53 | *, 54 | levels: List[int], 55 | num_quantizers, 56 | dim=None, 57 | is_channel_first=False, 58 | quantize_dropout=False, 59 | quantize_dropout_cutoff_index=0, 60 | quantize_dropout_multiple_of=1, 61 | **kwargs, 62 | ): 63 | super().__init__() 64 | codebook_dim = len(levels) 65 | dim = default(dim, codebook_dim) 66 | 67 | requires_projection = codebook_dim != dim 68 | self.project_in = ( 69 | nn.Linear(dim, codebook_dim) if requires_projection else nn.Identity() 70 | ) 71 | self.project_out = ( 72 | nn.Linear(codebook_dim, dim) if requires_projection else nn.Identity() 73 | ) 74 | self.has_projections = requires_projection 75 | 76 | self.is_channel_first = is_channel_first 77 | self.num_quantizers = num_quantizers 78 | 79 | self.levels = levels 80 | self.layers = nn.ModuleList([]) 81 | 82 | levels_tensor = torch.Tensor(levels) 83 | 84 | scales = [] 85 | 86 | for ind in range(num_quantizers): 87 | scales.append((levels_tensor - 1) ** -ind) 88 | 89 | fsq = FSQ(levels=levels, dim=codebook_dim, **kwargs) 90 | 91 | self.layers.append(fsq) 92 | 93 | assert all([not fsq.has_projections for fsq in self.layers]) 94 | 95 | self.codebook_size = self.layers[0].codebook_size 96 | 97 | self.register_buffer("scales", torch.stack(scales), persistent=False) 98 | 99 | self.quantize_dropout = quantize_dropout and num_quantizers > 1 100 | 101 | assert quantize_dropout_cutoff_index >= 0 102 | 103 | self.quantize_dropout_cutoff_index = quantize_dropout_cutoff_index 104 | self.quantize_dropout_multiple_of = quantize_dropout_multiple_of # encodec paper proposes structured dropout, believe this was set to 4 105 | 106 | @property 107 | def codebooks(self): 108 | codebooks = [layer.implicit_codebook for layer in self.layers] 109 | codebooks = torch.stack(codebooks, dim=0) 110 | return codebooks 111 | 112 | def get_codes_from_indices(self, indices): 113 | 114 | batch, quantize_dim = indices.shape[0], indices.shape[-1] 115 | 116 | # may also receive indices in the shape of 'b h w q' (accept_image_fmap) 117 | 118 | indices, ps = pack([indices], "b * q") 119 | 120 | # because of 
quantize dropout, one can pass in indices that are coarse 121 | # and the network should be able to reconstruct 122 | 123 | if quantize_dim < self.num_quantizers: 124 | assert ( 125 | self.quantize_dropout > 0.0 126 | ), "quantize dropout must be greater than 0 if you wish to reconstruct from a signal with less fine quantizations" 127 | indices = F.pad(indices, (0, self.num_quantizers - quantize_dim), value=-1) 128 | 129 | # take care of quantizer dropout 130 | 131 | mask = indices == -1 132 | indices = indices.masked_fill( 133 | mask, 0 134 | ) # have it fetch a dummy code to be masked out later 135 | 136 | all_codes = get_at("q [c] d, b n q -> q b n d", self.codebooks, indices) 137 | 138 | # mask out any codes that were dropout-ed 139 | 140 | all_codes = all_codes.masked_fill(rearrange(mask, "b n q -> q b n 1"), 0.0) 141 | 142 | # scale the codes 143 | 144 | scales = rearrange(self.scales, "q d -> q 1 1 d") 145 | all_codes = all_codes * scales 146 | 147 | # if (accept_image_fmap = True) then return shape (quantize, batch, height, width, dimension) 148 | 149 | (all_codes,) = unpack(all_codes, ps, "q b * d") 150 | 151 | return all_codes 152 | 153 | def get_output_from_indices(self, indices): 154 | codes = self.get_codes_from_indices(indices) 155 | codes_summed = reduce(codes, "q ... -> ...", "sum") 156 | return self.project_out(codes_summed) 157 | 158 | def forward(self, x, return_all_codes=False, rand_quantize_dropout_fixed_seed=None): 159 | num_quant, quant_dropout_multiple_of, device = ( 160 | self.num_quantizers, 161 | self.quantize_dropout_multiple_of, 162 | x.device, 163 | ) 164 | 165 | # handle channel first 166 | 167 | if self.is_channel_first: 168 | x = rearrange(x, "b d ... -> b ... d") 169 | x, ps = pack([x], "b * d") 170 | 171 | # maybe project in 172 | 173 | x = self.project_in(x) 174 | 175 | quantized_out = 0.0 176 | residual = x 177 | 178 | all_indices = [] 179 | 180 | should_quantize_dropout = self.training and self.quantize_dropout 181 | 182 | # sample a layer index at which to dropout further residual quantization 183 | # also prepare null indices 184 | 185 | if should_quantize_dropout: 186 | 187 | # check if seed is manually passed in 188 | 189 | if not exists(rand_quantize_dropout_fixed_seed): 190 | rand_quantize_dropout_fixed_seed = get_maybe_sync_seed(device) 191 | 192 | rand = random.Random(rand_quantize_dropout_fixed_seed) 193 | 194 | rand_quantize_dropout_index = rand.randrange( 195 | self.quantize_dropout_cutoff_index, num_quant 196 | ) 197 | 198 | if quant_dropout_multiple_of != 1: 199 | rand_quantize_dropout_index = ( 200 | round_up_multiple( 201 | rand_quantize_dropout_index + 1, quant_dropout_multiple_of 202 | ) 203 | - 1 204 | ) 205 | 206 | null_indices = torch.full( 207 | x.shape[:2], -1.0, device=device, dtype=torch.long 208 | ) 209 | 210 | # go through the layers 211 | 212 | with autocast("cuda", enabled=False): 213 | for quantizer_index, (layer, scale) in enumerate( 214 | zip(self.layers, self.scales) 215 | ): 216 | 217 | if ( 218 | should_quantize_dropout 219 | and quantizer_index > rand_quantize_dropout_index 220 | ): 221 | all_indices.append(null_indices) 222 | continue 223 | 224 | quantized, indices = layer(residual / scale) 225 | 226 | quantized = quantized * scale 227 | 228 | residual = residual - quantized.detach() 229 | quantized_out = quantized_out + quantized 230 | 231 | all_indices.append(indices) 232 | 233 | # project out, if needed 234 | 235 | quantized_out = self.project_out(quantized_out) 236 | 237 | # stack all indices 238 | 239 | 
all_indices = torch.stack(all_indices, dim=-1) 240 | 241 | # channel first out 242 | 243 | if self.is_channel_first: 244 | (quantized_out,) = unpack(quantized_out, ps, "b * d") 245 | (all_indices,) = unpack(all_indices, ps, "b * d") 246 | 247 | quantized_out = rearrange(quantized_out, "b ... d -> b d ...") 248 | all_indices = rearrange(all_indices, "b ... d -> b d ...") 249 | 250 | # return 251 | 252 | ret = (quantized_out, all_indices) 253 | 254 | if not return_all_codes: 255 | return ret 256 | 257 | # whether to return all codes from all codebooks across layers 258 | 259 | all_codes = self.get_codes_from_indices(all_indices) 260 | 261 | # will return all codes in shape (quantizer, batch, sequence length, codebook dimension) 262 | 263 | return (*ret, all_codes) 264 | 265 | 266 | # grouped residual fsq 267 | 268 | 269 | class GroupedResidualFSQ(Module): 270 | def __init__(self, *, dim, groups=1, accept_image_fmap=False, **kwargs): 271 | super().__init__() 272 | self.dim = dim 273 | self.groups = groups 274 | assert (dim % groups) == 0 275 | dim_per_group = dim // groups 276 | 277 | self.accept_image_fmap = accept_image_fmap 278 | 279 | self.rvqs = nn.ModuleList([]) 280 | 281 | for _ in range(groups): 282 | self.rvqs.append(ResidualFSQ(dim=dim_per_group, **kwargs)) 283 | 284 | self.codebook_size = self.rvqs[0].codebook_size 285 | 286 | @property 287 | def codebooks(self): 288 | return torch.stack(tuple(rvq.codebooks for rvq in self.rvqs)) 289 | 290 | @property 291 | def split_dim(self): 292 | return 1 if self.accept_image_fmap else -1 293 | 294 | def get_codes_from_indices(self, indices): 295 | codes = tuple( 296 | rvq.get_codes_from_indices(chunk_indices) 297 | for rvq, chunk_indices in zip(self.rvqs, indices) 298 | ) 299 | return torch.stack(codes) 300 | 301 | def get_output_from_indices(self, indices): 302 | outputs = tuple( 303 | rvq.get_output_from_indices(chunk_indices) 304 | for rvq, chunk_indices in zip(self.rvqs, indices) 305 | ) 306 | return torch.cat(outputs, dim=self.split_dim) 307 | 308 | def forward(self, x, return_all_codes=False): 309 | shape, split_dim, device = x.shape, self.split_dim, x.device 310 | assert shape[split_dim] == self.dim 311 | 312 | # split the feature dimension into groups 313 | 314 | x = x.chunk(self.groups, dim=split_dim) 315 | 316 | forward_kwargs = dict( 317 | return_all_codes=return_all_codes, 318 | rand_quantize_dropout_fixed_seed=( 319 | get_maybe_sync_seed(device) if self.training else None 320 | ), 321 | ) 322 | 323 | # invoke residual vq on each group 324 | 325 | out = tuple(rvq(chunk, **forward_kwargs) for rvq, chunk in zip(self.rvqs, x)) 326 | out = tuple(zip(*out)) 327 | 328 | # otherwise, get all the zipped outputs and combine them 329 | 330 | quantized, all_indices, *maybe_all_codes = out 331 | 332 | quantized = torch.cat(quantized, dim=split_dim) 333 | all_indices = torch.stack(all_indices) 334 | 335 | ret = (quantized, all_indices, *maybe_all_codes) 336 | return ret 337 | 338 | 339 | if __name__ == "__main__": 340 | model = ResidualFSQ( 341 | levels=[4, 4, 4, 4, 4, 4], 342 | num_quantizers=1, 343 | dim=30, 344 | is_channel_first=True, 345 | quantize_dropout=False, 346 | ) 347 | x = torch.randn(2, 30, 10) 348 | quantize, embed_ind = model(x) 349 | 350 | emb_from_ind = model.get_output_from_indices(embed_ind.transpose(1, 2)) 351 | 352 | print(quantize == emb_from_ind.transpose(1, 2)) 353 | 354 | print("quantize shape", quantize.shape) 355 | print("embed_ind", embed_ind) 356 | 
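# A minimal usage sketch for the GroupedResidualFSQ class above (illustrative only;
# the dim/groups/levels values are arbitrary choices, and the sketch is kept in
# comments so the module still imports cleanly). The feature dimension is split into
# `groups` chunks, each chunk is quantized by its own ResidualFSQ, and the quantized
# chunks are concatenated back along the split dimension:
#
#     grouped = GroupedResidualFSQ(
#         dim=60, groups=2, levels=[4, 4, 4, 4, 4, 4], num_quantizers=1
#     )
#     x = torch.randn(2, 16, 60)       # (batch, seq, dim) with accept_image_fmap=False
#     quantized, indices = grouped(x)  # quantized: (2, 16, 60)
#                                      # indices: (groups, batch, seq, num_quantizers) = (2, 2, 16, 1)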
-------------------------------------------------------------------------------- /sparktts/modules/speaker/ecapa_tdnn.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Zhengyang Chen (chenzhengyang117@gmail.com) 2 | # 2022 Hongji Wang (jijijiang77@gmail.com) 3 | # 2023 Bing Han (hanbing97@sjtu.edu.cn) 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """ This implementation is adapted from github repo: 18 | https://github.com/lawlict/ECAPA-TDNN. 19 | """ 20 | 21 | import torch 22 | import torch.nn as nn 23 | import torch.nn.functional as F 24 | 25 | import sparktts.modules.speaker.pooling_layers as pooling_layers 26 | 27 | 28 | class Res2Conv1dReluBn(nn.Module): 29 | """ 30 | in_channels == out_channels == channels 31 | """ 32 | 33 | def __init__( 34 | self, 35 | channels, 36 | kernel_size=1, 37 | stride=1, 38 | padding=0, 39 | dilation=1, 40 | bias=True, 41 | scale=4, 42 | ): 43 | super().__init__() 44 | assert channels % scale == 0, "{} % {} != 0".format(channels, scale) 45 | self.scale = scale 46 | self.width = channels // scale 47 | self.nums = scale if scale == 1 else scale - 1 48 | 49 | self.convs = [] 50 | self.bns = [] 51 | for i in range(self.nums): 52 | self.convs.append( 53 | nn.Conv1d( 54 | self.width, 55 | self.width, 56 | kernel_size, 57 | stride, 58 | padding, 59 | dilation, 60 | bias=bias, 61 | ) 62 | ) 63 | self.bns.append(nn.BatchNorm1d(self.width)) 64 | self.convs = nn.ModuleList(self.convs) 65 | self.bns = nn.ModuleList(self.bns) 66 | 67 | def forward(self, x): 68 | out = [] 69 | spx = torch.split(x, self.width, 1) 70 | sp = spx[0] 71 | for i, (conv, bn) in enumerate(zip(self.convs, self.bns)): 72 | # Order: conv -> relu -> bn 73 | if i >= 1: 74 | sp = sp + spx[i] 75 | sp = conv(sp) 76 | sp = bn(F.relu(sp)) 77 | out.append(sp) 78 | if self.scale != 1: 79 | out.append(spx[self.nums]) 80 | out = torch.cat(out, dim=1) 81 | 82 | return out 83 | 84 | 85 | """ Conv1d + BatchNorm1d + ReLU 86 | """ 87 | 88 | 89 | class Conv1dReluBn(nn.Module): 90 | 91 | def __init__( 92 | self, 93 | in_channels, 94 | out_channels, 95 | kernel_size=1, 96 | stride=1, 97 | padding=0, 98 | dilation=1, 99 | bias=True, 100 | ): 101 | super().__init__() 102 | self.conv = nn.Conv1d( 103 | in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias 104 | ) 105 | self.bn = nn.BatchNorm1d(out_channels) 106 | 107 | def forward(self, x): 108 | return self.bn(F.relu(self.conv(x))) 109 | 110 | 111 | """ The SE connection of 1D case. 
112 | """ 113 | 114 | 115 | class SE_Connect(nn.Module): 116 | 117 | def __init__(self, channels, se_bottleneck_dim=128): 118 | super().__init__() 119 | self.linear1 = nn.Linear(channels, se_bottleneck_dim) 120 | self.linear2 = nn.Linear(se_bottleneck_dim, channels) 121 | 122 | def forward(self, x): 123 | out = x.mean(dim=2) 124 | out = F.relu(self.linear1(out)) 125 | out = torch.sigmoid(self.linear2(out)) 126 | out = x * out.unsqueeze(2) 127 | 128 | return out 129 | 130 | 131 | """ SE-Res2Block of the ECAPA-TDNN architecture. 132 | """ 133 | 134 | 135 | class SE_Res2Block(nn.Module): 136 | 137 | def __init__(self, channels, kernel_size, stride, padding, dilation, scale): 138 | super().__init__() 139 | self.se_res2block = nn.Sequential( 140 | Conv1dReluBn(channels, channels, kernel_size=1, stride=1, padding=0), 141 | Res2Conv1dReluBn( 142 | channels, kernel_size, stride, padding, dilation, scale=scale 143 | ), 144 | Conv1dReluBn(channels, channels, kernel_size=1, stride=1, padding=0), 145 | SE_Connect(channels), 146 | ) 147 | 148 | def forward(self, x): 149 | return x + self.se_res2block(x) 150 | 151 | 152 | class ECAPA_TDNN(nn.Module): 153 | 154 | def __init__( 155 | self, 156 | channels=512, 157 | feat_dim=80, 158 | embed_dim=192, 159 | pooling_func="ASTP", 160 | global_context_att=False, 161 | emb_bn=False, 162 | ): 163 | super().__init__() 164 | 165 | self.layer1 = Conv1dReluBn(feat_dim, channels, kernel_size=5, padding=2) 166 | self.layer2 = SE_Res2Block( 167 | channels, kernel_size=3, stride=1, padding=2, dilation=2, scale=8 168 | ) 169 | self.layer3 = SE_Res2Block( 170 | channels, kernel_size=3, stride=1, padding=3, dilation=3, scale=8 171 | ) 172 | self.layer4 = SE_Res2Block( 173 | channels, kernel_size=3, stride=1, padding=4, dilation=4, scale=8 174 | ) 175 | 176 | cat_channels = channels * 3 177 | out_channels = 512 * 3 178 | self.conv = nn.Conv1d(cat_channels, out_channels, kernel_size=1) 179 | self.pool = getattr(pooling_layers, pooling_func)( 180 | in_dim=out_channels, global_context_att=global_context_att 181 | ) 182 | self.pool_out_dim = self.pool.get_out_dim() 183 | self.bn = nn.BatchNorm1d(self.pool_out_dim) 184 | self.linear = nn.Linear(self.pool_out_dim, embed_dim) 185 | self.emb_bn = emb_bn 186 | if emb_bn: # better in SSL for SV 187 | self.bn2 = nn.BatchNorm1d(embed_dim) 188 | else: 189 | self.bn2 = nn.Identity() 190 | 191 | def forward(self, x, return_latent=False): 192 | x = x.permute(0, 2, 1) # (B,T,F) -> (B,F,T) 193 | 194 | out1 = self.layer1(x) 195 | out2 = self.layer2(out1) 196 | out3 = self.layer3(out2) 197 | out4 = self.layer4(out3) 198 | 199 | out = torch.cat([out2, out3, out4], dim=1) 200 | latent = F.relu(self.conv(out)) 201 | out = self.bn(self.pool(latent)) 202 | out = self.linear(out) 203 | if self.emb_bn: 204 | out = self.bn2(out) 205 | 206 | if return_latent: 207 | return out, latent 208 | return out 209 | 210 | 211 | def ECAPA_TDNN_c1024(feat_dim, embed_dim, pooling_func="ASTP", emb_bn=False): 212 | return ECAPA_TDNN( 213 | channels=1024, 214 | feat_dim=feat_dim, 215 | embed_dim=embed_dim, 216 | pooling_func=pooling_func, 217 | emb_bn=emb_bn, 218 | ) 219 | 220 | 221 | def ECAPA_TDNN_GLOB_c1024(feat_dim, embed_dim, pooling_func="ASTP", emb_bn=False): 222 | return ECAPA_TDNN( 223 | channels=1024, 224 | feat_dim=feat_dim, 225 | embed_dim=embed_dim, 226 | pooling_func=pooling_func, 227 | global_context_att=True, 228 | emb_bn=emb_bn, 229 | ) 230 | 231 | 232 | def ECAPA_TDNN_c512(feat_dim, embed_dim, pooling_func="ASTP", emb_bn=False): 233 | return 
ECAPA_TDNN( 234 | channels=512, 235 | feat_dim=feat_dim, 236 | embed_dim=embed_dim, 237 | pooling_func=pooling_func, 238 | emb_bn=emb_bn, 239 | ) 240 | 241 | 242 | def ECAPA_TDNN_GLOB_c512(feat_dim, embed_dim, pooling_func="ASTP", emb_bn=False): 243 | return ECAPA_TDNN( 244 | channels=512, 245 | feat_dim=feat_dim, 246 | embed_dim=embed_dim, 247 | pooling_func=pooling_func, 248 | global_context_att=True, 249 | emb_bn=emb_bn, 250 | ) 251 | 252 | 253 | if __name__ == "__main__": 254 | x = torch.zeros(1, 200, 100) 255 | model = ECAPA_TDNN_GLOB_c512(feat_dim=100, embed_dim=256, pooling_func="ASTP") 256 | model.eval() 257 | out, latent = model(x, True) 258 | print(out.shape) 259 | print(latent.shape) 260 | 261 | num_params = sum(param.numel() for param in model.parameters()) 262 | print("{} M".format(num_params / 1e6)) 263 | 264 | # from thop import profile 265 | # x_np = torch.randn(1, 200, 80) 266 | # flops, params = profile(model, inputs=(x_np, )) 267 | # print("FLOPs: {} G, Params: {} M".format(flops / 1e9, params / 1e6)) 268 | -------------------------------------------------------------------------------- /sparktts/modules/speaker/perceiver_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 SparkAudio 2 | # 2025 Xinsheng Wang (w.xinshawn@gmail.com) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
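# Module overview: this file provides the attention stack used by the speaker
# encoder. `Attend` wraps (optionally flash) scaled dot-product attention,
# `Attention` adds the query/key/value projections, and `PerceiverResampler`
# cross-attends a fixed set of learned latents against the input features,
# mapping (batch, T, dim_context) to (batch, num_latents, dim).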
15 | 16 | # Adapted from https://github.com/lucidrains/naturalspeech2-pytorch/blob/659bec7f7543e7747e809e950cc2f84242fbeec7/naturalspeech2_pytorch/naturalspeech2_pytorch.py#L532 17 | 18 | from collections import namedtuple 19 | from functools import wraps 20 | 21 | import torch 22 | import torch.nn.functional as F 23 | from einops import rearrange, repeat 24 | from einops.layers.torch import Rearrange 25 | from packaging import version 26 | from torch import einsum, nn 27 | 28 | 29 | def exists(val): 30 | return val is not None 31 | 32 | 33 | def once(fn): 34 | called = False 35 | 36 | @wraps(fn) 37 | def inner(x): 38 | nonlocal called 39 | if called: 40 | return 41 | called = True 42 | return fn(x) 43 | 44 | return inner 45 | 46 | 47 | print_once = once(print) 48 | 49 | # main class 50 | 51 | 52 | class Attend(nn.Module): 53 | def __init__(self, dropout=0.0, causal=False, use_flash=False): 54 | super().__init__() 55 | self.dropout = dropout 56 | self.attn_dropout = nn.Dropout(dropout) 57 | 58 | self.causal = causal 59 | self.register_buffer("mask", None, persistent=False) 60 | 61 | self.use_flash = use_flash 62 | assert not ( 63 | use_flash and version.parse(torch.__version__) < version.parse("2.0.0") 64 | ), "in order to use flash attention, you must be using pytorch 2.0 or above" 65 | 66 | # determine efficient attention configs for cuda and cpu 67 | self.config = namedtuple( 68 | "EfficientAttentionConfig", 69 | ["enable_flash", "enable_math", "enable_mem_efficient"], 70 | ) 71 | self.cpu_config = self.config(True, True, True) 72 | self.cuda_config = None 73 | 74 | if not torch.cuda.is_available() or not use_flash: 75 | return 76 | 77 | device_properties = torch.cuda.get_device_properties(torch.device("cuda")) 78 | 79 | if device_properties.major == 8 and device_properties.minor == 0: 80 | print_once( 81 | "A100 GPU detected, using flash attention if input tensor is on cuda" 82 | ) 83 | self.cuda_config = self.config(True, False, False) 84 | else: 85 | print_once( 86 | "Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda" 87 | ) 88 | self.cuda_config = self.config(False, True, True) 89 | 90 | def get_mask(self, n, device): 91 | if exists(self.mask) and self.mask.shape[-1] >= n: 92 | return self.mask[:n, :n] 93 | 94 | mask = torch.ones((n, n), device=device, dtype=torch.bool).triu(1) 95 | self.register_buffer("mask", mask, persistent=False) 96 | return mask 97 | 98 | def flash_attn(self, q, k, v, mask=None): 99 | _, heads, q_len, _, k_len, is_cuda = *q.shape, k.shape[-2], q.is_cuda 100 | 101 | # Recommended for multi-query single-key-value attention by Tri Dao 102 | # kv shape torch.Size([1, 512, 64]) -> torch.Size([1, 8, 512, 64]) 103 | 104 | if k.ndim == 3: 105 | k = rearrange(k, "b ... -> b 1 ...").expand_as(q) 106 | 107 | if v.ndim == 3: 108 | v = rearrange(v, "b ... 
-> b 1 ...").expand_as(q) 109 | 110 | # Check if mask exists and expand to compatible shape 111 | # The mask is B L, so it would have to be expanded to B H N L 112 | 113 | if exists(mask): 114 | mask = rearrange(mask, "b j -> b 1 1 j") 115 | mask = mask.expand(-1, heads, q_len, -1) 116 | 117 | # Check if there is a compatible device for flash attention 118 | 119 | config = self.cuda_config if is_cuda else self.cpu_config 120 | 121 | # pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale 122 | 123 | with torch.backends.cuda.sdp_kernel(**config._asdict()): 124 | out = F.scaled_dot_product_attention( 125 | q, 126 | k, 127 | v, 128 | attn_mask=mask, 129 | dropout_p=self.dropout if self.training else 0.0, 130 | is_causal=self.causal, 131 | ) 132 | 133 | return out 134 | 135 | def forward(self, q, k, v, mask=None): 136 | """ 137 | einstein notation 138 | b - batch 139 | h - heads 140 | n, i, j - sequence length (base sequence length, source, target) 141 | d - feature dimension 142 | """ 143 | 144 | n, device = q.shape[-2], q.device 145 | 146 | scale = q.shape[-1] ** -0.5 147 | 148 | if self.use_flash: 149 | return self.flash_attn(q, k, v, mask=mask) 150 | 151 | kv_einsum_eq = "b j d" if k.ndim == 3 else "b h j d" 152 | 153 | # similarity 154 | 155 | sim = einsum(f"b h i d, {kv_einsum_eq} -> b h i j", q, k) * scale 156 | 157 | # key padding mask 158 | 159 | if exists(mask): 160 | mask = rearrange(mask, "b j -> b 1 1 j") 161 | sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max) 162 | 163 | # causal mask 164 | 165 | if self.causal: 166 | causal_mask = self.get_mask(n, device) 167 | sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max) 168 | 169 | # attention 170 | 171 | attn = sim.softmax(dim=-1) 172 | attn = self.attn_dropout(attn) 173 | 174 | # aggregate values 175 | 176 | out = einsum(f"b h i j, {kv_einsum_eq} -> b h i d", attn, v) 177 | 178 | return out 179 | 180 | 181 | def Sequential(*mods): 182 | return nn.Sequential(*filter(exists, mods)) 183 | 184 | 185 | def exists(x): 186 | return x is not None 187 | 188 | 189 | def default(val, d): 190 | if exists(val): 191 | return val 192 | return d() if callable(d) else d 193 | 194 | 195 | class RMSNorm(nn.Module): 196 | def __init__(self, dim, scale=True, dim_cond=None): 197 | super().__init__() 198 | self.cond = exists(dim_cond) 199 | self.to_gamma_beta = nn.Linear(dim_cond, dim * 2) if self.cond else None 200 | 201 | self.scale = dim**0.5 202 | self.gamma = nn.Parameter(torch.ones(dim)) if scale else None 203 | 204 | def forward(self, x, cond=None): 205 | gamma = default(self.gamma, 1) 206 | out = F.normalize(x, dim=-1) * self.scale * gamma 207 | 208 | if not self.cond: 209 | return out 210 | 211 | assert exists(cond) 212 | gamma, beta = self.to_gamma_beta(cond).chunk(2, dim=-1) 213 | gamma, beta = map(lambda t: rearrange(t, "b d -> b 1 d"), (gamma, beta)) 214 | return out * gamma + beta 215 | 216 | 217 | class CausalConv1d(nn.Conv1d): 218 | def __init__(self, *args, **kwargs): 219 | super().__init__(*args, **kwargs) 220 | (kernel_size,) = self.kernel_size 221 | (dilation,) = self.dilation 222 | (stride,) = self.stride 223 | 224 | assert stride == 1 225 | self.causal_padding = dilation * (kernel_size - 1) 226 | 227 | def forward(self, x): 228 | causal_padded_x = F.pad(x, (self.causal_padding, 0), value=0.0) 229 | return super().forward(causal_padded_x) 230 | 231 | 232 | class GEGLU(nn.Module): 233 | def forward(self, x): 234 | x, gate = x.chunk(2, dim=-1) 235 | return F.gelu(gate) * x 236 | 237 | 238 | def 
FeedForward(dim, mult=4, causal_conv=False): 239 | dim_inner = int(dim * mult * 2 / 3) 240 | 241 | conv = None 242 | if causal_conv: 243 | conv = nn.Sequential( 244 | Rearrange("b n d -> b d n"), 245 | CausalConv1d(dim_inner, dim_inner, 3), 246 | Rearrange("b d n -> b n d"), 247 | ) 248 | 249 | return Sequential( 250 | nn.Linear(dim, dim_inner * 2), GEGLU(), conv, nn.Linear(dim_inner, dim) 251 | ) 252 | 253 | 254 | class Attention(nn.Module): 255 | def __init__( 256 | self, 257 | dim, 258 | *, 259 | dim_context=None, 260 | causal=False, 261 | dim_head=64, 262 | heads=8, 263 | dropout=0.0, 264 | use_flash=False, 265 | cross_attn_include_queries=False, 266 | ): 267 | super().__init__() 268 | self.scale = dim_head**-0.5 269 | self.heads = heads 270 | self.cross_attn_include_queries = cross_attn_include_queries 271 | 272 | dim_inner = dim_head * heads 273 | dim_context = default(dim_context, dim) 274 | 275 | self.attend = Attend(causal=causal, dropout=dropout, use_flash=use_flash) 276 | self.to_q = nn.Linear(dim, dim_inner, bias=False) 277 | self.to_kv = nn.Linear(dim_context, dim_inner * 2, bias=False) 278 | self.to_out = nn.Linear(dim_inner, dim, bias=False) 279 | 280 | def forward(self, x, context=None, mask=None): 281 | h, has_context = self.heads, exists(context) 282 | 283 | context = default(context, x) 284 | 285 | if has_context and self.cross_attn_include_queries: 286 | context = torch.cat((x, context), dim=-2) 287 | 288 | q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim=-1)) 289 | q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v)) 290 | 291 | out = self.attend(q, k, v, mask=mask) 292 | 293 | out = rearrange(out, "b h n d -> b n (h d)") 294 | return self.to_out(out) 295 | 296 | 297 | class PerceiverResampler(nn.Module): 298 | def __init__( 299 | self, 300 | *, 301 | dim, 302 | depth=2, 303 | dim_context=None, 304 | num_latents=32, 305 | dim_head=64, 306 | heads=8, 307 | ff_mult=4, 308 | use_flash_attn=False, 309 | ): 310 | super().__init__() 311 | dim_context = default(dim_context, dim) 312 | 313 | self.proj_context = ( 314 | nn.Linear(dim_context, dim) if dim_context != dim else nn.Identity() 315 | ) 316 | 317 | self.latents = nn.Parameter(torch.randn(num_latents, dim)) 318 | nn.init.normal_(self.latents, std=0.02) 319 | 320 | self.layers = nn.ModuleList([]) 321 | for _ in range(depth): 322 | self.layers.append( 323 | nn.ModuleList( 324 | [ 325 | Attention( 326 | dim=dim, 327 | dim_head=dim_head, 328 | heads=heads, 329 | use_flash=use_flash_attn, 330 | cross_attn_include_queries=True, 331 | ), 332 | FeedForward(dim=dim, mult=ff_mult), 333 | ] 334 | ) 335 | ) 336 | 337 | self.norm = RMSNorm(dim) 338 | 339 | def forward(self, x, mask=None): 340 | batch = x.shape[0] 341 | 342 | x = self.proj_context(x) 343 | 344 | latents = repeat(self.latents, "n d -> b n d", b=batch) 345 | 346 | for attn, ff in self.layers: 347 | latents = attn(latents, x, mask=mask) + latents 348 | latents = ff(latents) + latents 349 | 350 | return self.norm(latents) 351 | 352 | 353 | if __name__ == "__main__": 354 | model = PerceiverResampler(dim=256, dim_context=80) 355 | x = torch.randn(8, 200, 80) 356 | out = model(x) 357 | print(out.shape) # [8, 32, 80] 358 | 359 | num_params = sum(param.numel() for param in model.parameters()) 360 | print("{} M".format(num_params / 1e6)) 361 | -------------------------------------------------------------------------------- /sparktts/modules/speaker/pooling_layers.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Shuai Wang (wsstriving@gmail.com) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | Pooling functions to aggregate frame-level deep features 16 | into segment-level speaker embeddings 17 | 18 | High-order statistics are surprisingly effective, TSDP acts similarly as TSTP, 19 | even though we remove the mean statistic, on Voxceleb. 20 | """ 21 | 22 | import torch 23 | import torch.nn as nn 24 | import torch.nn.functional as F 25 | 26 | 27 | class TAP(nn.Module): 28 | """ 29 | Temporal average pooling, only first-order mean is considered 30 | """ 31 | 32 | def __init__(self, in_dim=0, **kwargs): 33 | super(TAP, self).__init__() 34 | self.in_dim = in_dim 35 | 36 | def forward(self, x): 37 | pooling_mean = x.mean(dim=-1) 38 | # To be compatable with 2D input 39 | pooling_mean = pooling_mean.flatten(start_dim=1) 40 | return pooling_mean 41 | 42 | def get_out_dim(self): 43 | self.out_dim = self.in_dim 44 | return self.out_dim 45 | 46 | 47 | class TSDP(nn.Module): 48 | """ 49 | Temporal standard deviation pooling, only second-order std is considered 50 | """ 51 | 52 | def __init__(self, in_dim=0, **kwargs): 53 | super(TSDP, self).__init__() 54 | self.in_dim = in_dim 55 | 56 | def forward(self, x): 57 | # The last dimension is the temporal axis 58 | pooling_std = torch.sqrt(torch.var(x, dim=-1) + 1e-7) 59 | pooling_std = pooling_std.flatten(start_dim=1) 60 | return pooling_std 61 | 62 | def get_out_dim(self): 63 | self.out_dim = self.in_dim 64 | return self.out_dim 65 | 66 | 67 | class TSTP(nn.Module): 68 | """ 69 | Temporal statistics pooling, concatenate mean and std, which is used in 70 | x-vector 71 | Comment: simple concatenation can not make full use of both statistics 72 | """ 73 | 74 | def __init__(self, in_dim=0, **kwargs): 75 | super(TSTP, self).__init__() 76 | self.in_dim = in_dim 77 | 78 | def forward(self, x): 79 | # The last dimension is the temporal axis 80 | pooling_mean = x.mean(dim=-1) 81 | pooling_std = torch.sqrt(torch.var(x, dim=-1) + 1e-7) 82 | pooling_mean = pooling_mean.flatten(start_dim=1) 83 | pooling_std = pooling_std.flatten(start_dim=1) 84 | stats = torch.cat((pooling_mean, pooling_std), 1) 85 | return stats 86 | 87 | def get_out_dim(self): 88 | self.out_dim = self.in_dim * 2 89 | return self.out_dim 90 | 91 | 92 | class ASTP(nn.Module): 93 | """ Attentive statistics pooling: Channel- and context-dependent 94 | statistics pooling, first used in ECAPA_TDNN. 95 | """ 96 | 97 | def __init__(self, 98 | in_dim, 99 | bottleneck_dim=128, 100 | global_context_att=False, 101 | **kwargs): 102 | super(ASTP, self).__init__() 103 | self.in_dim = in_dim 104 | self.global_context_att = global_context_att 105 | 106 | # Use Conv1d with stride == 1 rather than Linear, then we don't 107 | # need to transpose inputs. 
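        # With global_context_att=True, forward() concatenates each frame with the
        # utterance-level mean and std (x_in = [x, mean, std]), which is why
        # linear1 takes in_dim * 3 input channels.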
108 | if global_context_att: 109 | self.linear1 = nn.Conv1d( 110 | in_dim * 3, bottleneck_dim, 111 | kernel_size=1) # equals W and b in the paper 112 | else: 113 | self.linear1 = nn.Conv1d( 114 | in_dim, bottleneck_dim, 115 | kernel_size=1) # equals W and b in the paper 116 | self.linear2 = nn.Conv1d(bottleneck_dim, in_dim, 117 | kernel_size=1) # equals V and k in the paper 118 | 119 | def forward(self, x): 120 | """ 121 | x: a 3-dimensional tensor in tdnn-based architecture (B,F,T) 122 | or a 4-dimensional tensor in resnet architecture (B,C,F,T) 123 | 0-dim: batch-dimension, last-dim: time-dimension (frame-dimension) 124 | """ 125 | if len(x.shape) == 4: 126 | x = x.reshape(x.shape[0], x.shape[1] * x.shape[2], x.shape[3]) 127 | assert len(x.shape) == 3 128 | 129 | if self.global_context_att: 130 | context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x) 131 | context_std = torch.sqrt( 132 | torch.var(x, dim=-1, keepdim=True) + 1e-7).expand_as(x) 133 | x_in = torch.cat((x, context_mean, context_std), dim=1) 134 | else: 135 | x_in = x 136 | 137 | # DON'T use ReLU here! ReLU may be hard to converge. 138 | alpha = torch.tanh( 139 | self.linear1(x_in)) # alpha = F.relu(self.linear1(x_in)) 140 | alpha = torch.softmax(self.linear2(alpha), dim=2) 141 | mean = torch.sum(alpha * x, dim=2) 142 | var = torch.sum(alpha * (x**2), dim=2) - mean**2 143 | std = torch.sqrt(var.clamp(min=1e-7)) 144 | return torch.cat([mean, std], dim=1) 145 | 146 | def get_out_dim(self): 147 | self.out_dim = 2 * self.in_dim 148 | return self.out_dim 149 | 150 | 151 | class MHASTP(torch.nn.Module): 152 | """ Multi head attentive statistics pooling 153 | Reference: 154 | Self Multi-Head Attention for Speaker Recognition 155 | https://arxiv.org/pdf/1906.09890.pdf 156 | """ 157 | 158 | def __init__(self, 159 | in_dim, 160 | layer_num=2, 161 | head_num=2, 162 | d_s=1, 163 | bottleneck_dim=64, 164 | **kwargs): 165 | super(MHASTP, self).__init__() 166 | assert (in_dim % head_num 167 | ) == 0 # make sure that head num can be divided by input_dim 168 | self.in_dim = in_dim 169 | self.head_num = head_num 170 | d_model = int(in_dim / head_num) 171 | channel_dims = [bottleneck_dim for i in range(layer_num + 1)] 172 | if d_s > 1: 173 | d_s = d_model 174 | else: 175 | d_s = 1 176 | self.d_s = d_s 177 | channel_dims[0], channel_dims[-1] = d_model, d_s 178 | heads_att_trans = [] 179 | for i in range(self.head_num): 180 | att_trans = nn.Sequential() 181 | for i in range(layer_num - 1): 182 | att_trans.add_module( 183 | 'att_' + str(i), 184 | nn.Conv1d(channel_dims[i], channel_dims[i + 1], 1, 1)) 185 | att_trans.add_module('tanh' + str(i), nn.Tanh()) 186 | att_trans.add_module( 187 | 'att_' + str(layer_num - 1), 188 | nn.Conv1d(channel_dims[layer_num - 1], channel_dims[layer_num], 189 | 1, 1)) 190 | heads_att_trans.append(att_trans) 191 | self.heads_att_trans = nn.ModuleList(heads_att_trans) 192 | 193 | def forward(self, input): 194 | """ 195 | input: a 3-dimensional tensor in xvector architecture 196 | or a 4-dimensional tensor in resnet architecture 197 | 0-dim: batch-dimension, last-dim: time-dimension (frame-dimension) 198 | """ 199 | if len(input.shape) == 4: # B x F x T 200 | input = input.reshape(input.shape[0], 201 | input.shape[1] * input.shape[2], 202 | input.shape[3]) 203 | assert len(input.shape) == 3 204 | bs, f_dim, t_dim = input.shape 205 | chunks = torch.chunk(input, self.head_num, 1) 206 | # split 207 | chunks_out = [] 208 | # for i in range(self.head_num): 209 | # att_score = self.heads_att_trans[i](chunks[i]) 210 | 
for i, layer in enumerate(self.heads_att_trans): 211 | att_score = layer(chunks[i]) 212 | alpha = F.softmax(att_score, dim=-1) 213 | mean = torch.sum(alpha * chunks[i], dim=2) 214 | var = torch.sum(alpha * chunks[i]**2, dim=2) - mean**2 215 | std = torch.sqrt(var.clamp(min=1e-7)) 216 | chunks_out.append(torch.cat((mean, std), dim=1)) 217 | out = torch.cat(chunks_out, dim=1) 218 | return out 219 | 220 | def get_out_dim(self): 221 | self.out_dim = 2 * self.in_dim 222 | return self.out_dim 223 | 224 | 225 | class MQMHASTP(torch.nn.Module): 226 | """ An attentive pooling 227 | Reference: 228 | multi query multi head attentive statistics pooling 229 | https://arxiv.org/pdf/2110.05042.pdf 230 | Args: 231 | in_dim: the feature dimension of input 232 | layer_num: the number of layer in the pooling layer 233 | query_num: the number of querys 234 | head_num: the number of heads 235 | bottleneck_dim: the bottleneck dimension 236 | 237 | SA (H = 1, Q = 1, n = 2, d_s = 1) ref: 238 | https://www.danielpovey.com/files/2018_interspeech_xvector_attention.pdf 239 | MHA (H > 1, Q = 1, n = 1, d_s = 1) ref: 240 | https://arxiv.org/pdf/1906.09890.pdf 241 | AS (H = 1, Q > 1, n = 2, d_s = 1) ref: 242 | https://arxiv.org/pdf/1803.10963.pdf 243 | VSA (H = 1, Q > 1, n = 2, d_s = d_h) ref: 244 | http://www.interspeech2020.org/uploadfile/pdf/Mon-2-10-5.pdf 245 | """ 246 | 247 | def __init__(self, 248 | in_dim, 249 | layer_num=2, 250 | query_num=2, 251 | head_num=8, 252 | d_s=2, 253 | bottleneck_dim=64, 254 | **kwargs): 255 | super(MQMHASTP, self).__init__() 256 | self.n_query = nn.ModuleList([ 257 | MHASTP(in_dim, 258 | layer_num=layer_num, 259 | head_num=head_num, 260 | d_s=d_s, 261 | bottleneck_dim=bottleneck_dim) for i in range(query_num) 262 | ]) 263 | self.query_num = query_num 264 | self.in_dim = in_dim 265 | 266 | def forward(self, input): 267 | """ 268 | input: a 3-dimensional tensor in xvector architecture 269 | or a 4-dimensional tensor in resnet architecture 270 | 0-dim: batch-dimension, last-dim: time-dimension (frame-dimension) 271 | """ 272 | if len(input.shape) == 4: # B x F x T 273 | input = input.reshape(input.shape[0], 274 | input.shape[1] * input.shape[2], 275 | input.shape[3]) 276 | assert len(input.shape) == 3 277 | res = [] 278 | for i, layer in enumerate(self.n_query): 279 | res.append(layer(input)) 280 | out = torch.cat(res, dim=-1) 281 | return out 282 | 283 | def get_out_dim(self): 284 | self.out_dim = self.in_dim * 2 * self.query_num 285 | return self.out_dim 286 | 287 | 288 | if __name__ == '__main__': 289 | data = torch.randn(16, 512, 10, 35) 290 | # model = StatisticsPooling() 291 | model = MQMHASTP(512 * 10) 292 | model = MHASTP(512 * 10) 293 | model = MQMHASTP(512 * 10, context=False) 294 | print(model) 295 | 296 | out = model(data) 297 | print(out.shape) 298 | print(model.get_out_dim()) -------------------------------------------------------------------------------- /sparktts/modules/speaker/speaker_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 SparkAudio 2 | # 2025 Xinsheng Wang (w.xinshawn@gmail.com) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import torch 17 | import torch.nn as nn 18 | 19 | from typing import List, Tuple 20 | from sparktts.modules.fsq.residual_fsq import ResidualFSQ 21 | from sparktts.modules.speaker.ecapa_tdnn import ECAPA_TDNN_GLOB_c512 22 | from sparktts.modules.speaker.perceiver_encoder import PerceiverResampler 23 | 24 | """ 25 | x-vector + d-vector 26 | """ 27 | 28 | 29 | class SpeakerEncoder(nn.Module): 30 | """ 31 | 32 | Args: 33 | input_dim (int): acoustic feature dimension 34 | out_dim (int): output dimension of x-vector and d-vector 35 | latent_dim (int): latent dimension before quantization 36 | token_num (int): sequence length of speaker tokens 37 | fsq_levels (List[int]): number of levels for each quantizer 38 | fsq_num_quantizers (int): number of quantizers 39 | 40 | Return: 41 | speaker_embs: (B, T2, out_dim) 42 | """ 43 | 44 | def __init__( 45 | self, 46 | input_dim: int = 100, 47 | out_dim: int = 512, 48 | latent_dim: int = 128, 49 | token_num: int = 32, 50 | fsq_levels: List[int] = [4, 4, 4, 4, 4, 4], 51 | fsq_num_quantizers: int = 1, 52 | ): 53 | super(SpeakerEncoder, self).__init__() 54 | 55 | self.speaker_encoder = ECAPA_TDNN_GLOB_c512( 56 | feat_dim=input_dim, embed_dim=out_dim 57 | ) 58 | self.perceiver_sampler = PerceiverResampler( 59 | dim=latent_dim, dim_context=512 * 3, num_latents=token_num 60 | ) 61 | self.quantizer = ResidualFSQ( 62 | levels=fsq_levels, 63 | num_quantizers=fsq_num_quantizers, 64 | dim=latent_dim, 65 | is_channel_first=True, 66 | quantize_dropout=False, 67 | ) 68 | 69 | self.project = nn.Linear(latent_dim * token_num, out_dim) 70 | 71 | def get_codes_from_indices(self, indices: torch.Tensor) -> torch.Tensor: 72 | zq = self.quantizer.get_codes_from_indices(indices.transpose(1, 2)) 73 | return zq.transpose(1, 2) 74 | 75 | def get_indices(self, mels: torch.Tensor) -> torch.Tensor: 76 | mels = mels.transpose(1, 2) 77 | x = self.perceiver_sampler(mels).transpose(1, 2) 78 | zq, indices = self.quantizer(x) 79 | return indices 80 | 81 | def forward(self, mels: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 82 | """ 83 | Args: 84 | mels: (B, D_mel, T1) 85 | 86 | Return: 87 | x_vector: (B, out_dim) 88 | d_vector: (B, out_dim) 89 | """ 90 | # mels = mels.transpose(1,2) 91 | 92 | x_vector, features = self.speaker_encoder(mels, True) 93 | x = self.perceiver_sampler(features.transpose(1, 2)).transpose(1, 2) 94 | zq, indices = self.quantizer(x) # zq: (B, latent_dim, T2, latent_dim) 95 | x = zq.reshape(zq.shape[0], -1) 96 | d_vector = self.project(x) 97 | 98 | return x_vector, d_vector 99 | 100 | def tokenize(self, mels: torch.Tensor) -> torch.Tensor: 101 | """tokenize the input mel spectrogram""" 102 | _, features = self.speaker_encoder(mels, True) 103 | x = self.perceiver_sampler(features.transpose(1, 2)).transpose(1, 2) 104 | zq, indices = self.quantizer(x) 105 | return indices 106 | 107 | def detokenize(self, indices: torch.Tensor) -> torch.Tensor: 108 | """detokenize the input indices to d-vector""" 109 | zq = self.quantizer.get_output_from_indices(indices.transpose(1, 2)).transpose(1, 2) 110 | x = 
zq.reshape(zq.shape[0], -1) 111 | d_vector = self.project(x) 112 | return d_vector 113 | 114 | if __name__ == "__main__": 115 | model = SpeakerEncoder( 116 | input_dim=100, 117 | latent_dim=128, 118 | token_num=32, 119 | fsq_levels=[4, 4, 4, 4, 4, 4], 120 | fsq_num_quantizers=1, 121 | ) 122 | mel = torch.randn(8, 200, 100) 123 | x_vector, d_vector = model(mel) 124 | print("x-vector shape", x_vector.shape) 125 | print("d-vector shape", d_vector.shape) 126 | 127 | indices = model.tokenize(mel) 128 | print("indices shape", indices.shape) 129 | d_vector_post = model.detokenize(indices) 130 | print("d-vector shape", d_vector_post.shape) 131 | if d_vector_post.all() == d_vector.all(): 132 | print("d-vector post and d-vector are the same") 133 | else: 134 | print("d-vector post and d-vector are different") 135 | num_params = sum(param.numel() for param in model.parameters()) 136 | print("{} M".format(num_params / 1e6)) -------------------------------------------------------------------------------- /sparktts/modules/vq/factorized_vector_quantize.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 SparkAudio 2 | # 2025 Xinsheng Wang (w.xinshawn@gmail.com) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
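# Module overview: FactorizedVectorQuantize projects the input from `input_dim`
# down to a small `codebook_dim` space, performs an L2-normalized nearest-neighbour
# lookup against an nn.Embedding codebook (so the distance reduces to cosine
# distance), then projects back up. forward() takes a (B, D, T) tensor and returns
# a dict with "z_q" (B, D, T), "indices" (B, T), "dists", "vq_loss", "perplexity"
# and "active_num"; gradients pass through the straight-through estimator
# z_e + (z_q - z_e).detach().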
15 | 16 | # Heavily based on https://github.com/lucidrains/vector-quantize-pytorch 17 | 18 | 19 | from typing import Any, Dict 20 | 21 | import torch 22 | import torch.nn as nn 23 | import torch.nn.functional as F 24 | from einops import rearrange 25 | from torch.nn.utils import weight_norm 26 | 27 | 28 | def WNConv1d(*args, **kwargs): 29 | return weight_norm(nn.Conv1d(*args, **kwargs)) 30 | 31 | 32 | def ema_inplace(moving_avg, new, decay): 33 | moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay)) 34 | 35 | 36 | class FactorizedVectorQuantize(nn.Module): 37 | def __init__( 38 | self, 39 | input_dim: int, 40 | codebook_size: int, 41 | codebook_dim: int, 42 | commitment: float, 43 | codebook_loss_weight: float = 1.0, 44 | decay: float = 0.99, 45 | threshold_ema_dead_code: float = 2, 46 | momentum: float = 0.99, 47 | **kwargs, 48 | ): 49 | super().__init__() 50 | self.input_dim = input_dim 51 | self.codebook_size = codebook_size 52 | self.codebook_dim = codebook_dim 53 | self.commitment = commitment 54 | self.codebook_loss_weight = codebook_loss_weight 55 | self.decay = decay 56 | self.threshold_ema_dead_code = threshold_ema_dead_code 57 | self.momentum = momentum 58 | 59 | if input_dim != self.codebook_dim: 60 | self.in_project = WNConv1d(input_dim, self.codebook_dim, kernel_size=1) 61 | self.out_project = WNConv1d(self.codebook_dim, input_dim, kernel_size=1) 62 | 63 | else: 64 | self.in_project = nn.Identity() 65 | self.out_project = nn.Identity() 66 | 67 | self.codebook = nn.Embedding(self.codebook_size, self.codebook_dim) 68 | self.register_buffer("cluster_size", torch.zeros(self.codebook_size)) 69 | 70 | def forward(self, z: torch.Tensor) -> Dict[str, Any]: 71 | """Quantized the input tensor using a fixed codebook and returns 72 | the corresponding codebook vectors 73 | 74 | Parameters 75 | ---------- 76 | z : Tensor[B x D x T] 77 | 78 | Returns 79 | ------- 80 | Tensor[B x D x T] 81 | Quantized continuous representation of input 82 | Tensor[1] 83 | Commitment loss to train encoder to predict vectors closer to codebook 84 | entries 85 | Tensor[1] 86 | Codebook loss to update the codebook 87 | Tensor[B x T] 88 | Codebook indices (quantized discrete representation of input) 89 | Tensor[B x D x T] 90 | Projected latents (continuous representation of input before quantization) 91 | """ 92 | # transpose since we use linear 93 | 94 | # Factorized codes project input into low-dimensional space if self.input_dim != self.codebook_dim 95 | z_e = self.in_project(z) 96 | z_q, indices, dists = self.decode_latents(z_e) 97 | 98 | # statistic the usage of codes 99 | embed_onehot = F.one_hot(indices, self.codebook_size).type(z_e.dtype) 100 | avg_probs = torch.mean(embed_onehot.reshape(-1, self.codebook_size), dim=0) 101 | perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10))) 102 | 103 | active_num = (embed_onehot.sum(0).sum(0) > 0).sum() 104 | if self.training: 105 | # We do the expiry of code at that point as buffers are in sync 106 | # and all the workers will take the same decision. 
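            # EMA-update the per-code usage counts; codes whose running count stays
            # below `threshold_ema_dead_code` are excluded from `active_num` below.
            # Note that `active_num` is reported for monitoring only: no codes are
            # actually expired or re-initialized in this implementation.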
107 | ema_inplace(self.cluster_size, embed_onehot.sum(0).sum(0), self.decay) 108 | active_num = sum(self.cluster_size > self.threshold_ema_dead_code) 109 | 110 | if self.training: 111 | commit_loss = ( 112 | F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2]) 113 | * self.commitment 114 | ) 115 | 116 | codebook_loss = ( 117 | F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2]) 118 | * self.codebook_loss_weight 119 | ) 120 | 121 | else: 122 | commit_loss = torch.zeros(0, device=z.device) 123 | codebook_loss = torch.zeros(0, device=z.device) 124 | 125 | z_q = ( 126 | z_e + (z_q - z_e).detach() 127 | ) # noop in forward pass, straight-through gradient estimator in backward pass 128 | 129 | z_q = self.out_project(z_q) 130 | 131 | vq_loss = (commit_loss + codebook_loss).mean() 132 | 133 | return { 134 | "z_q": z_q, 135 | "indices": indices, 136 | "dists": dists, 137 | "vq_loss": vq_loss, 138 | "perplexity": perplexity, 139 | "active_num": active_num.float(), 140 | } 141 | 142 | def vq2emb(self, vq, out_proj=True): 143 | emb = self.embed_code(vq) 144 | if out_proj: 145 | emb = self.out_project(emb) 146 | return emb 147 | 148 | def tokenize(self, z: torch.Tensor) -> torch.Tensor: 149 | """tokenize the input tensor""" 150 | z_e = self.in_project(z) 151 | _, indices, _ = self.decode_latents(z_e) 152 | return indices 153 | 154 | def detokenize(self, indices): 155 | """detokenize the input indices""" 156 | z_q = self.decode_code(indices) 157 | z_q = self.out_project(z_q) 158 | return z_q 159 | 160 | def get_emb(self): 161 | return self.codebook.weight 162 | 163 | def embed_code(self, embed_id): 164 | return F.embedding(embed_id, self.codebook.weight) 165 | 166 | def decode_code(self, embed_id): 167 | return self.embed_code(embed_id).transpose(1, 2) 168 | 169 | def decode_latents(self, latents): 170 | encodings = rearrange(latents, "b d t -> (b t) d") 171 | codebook = self.codebook.weight 172 | 173 | # L2 normalize encodings and codebook 174 | encodings = F.normalize(encodings) 175 | codebook = F.normalize(codebook) 176 | 177 | # Compute euclidean distance between encodings and codebook, 178 | # with L2 normalization, the distance is equal to cosine distance 179 | dist = ( 180 | encodings.pow(2).sum(1, keepdim=True) 181 | - 2 * encodings @ codebook.t() 182 | + codebook.pow(2).sum(1, keepdim=True).t() 183 | ) 184 | indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0)) 185 | z_q = self.decode_code(indices) 186 | 187 | return z_q, indices, dist 188 | -------------------------------------------------------------------------------- /sparktts/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SparkAudio/Spark-TTS/2f1ea9082400547242641f5271b6f941c9f439d1/sparktts/utils/__init__.py -------------------------------------------------------------------------------- /sparktts/utils/audio.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 SparkAudio 2 | # 2025 Xinsheng Wang (w.xinshawn@gmail.com) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """ 16 | Description: 17 | This script contains a collection of functions designed to handle various 18 | audio processing. 19 | """ 20 | 21 | import random 22 | import soxr 23 | import soundfile 24 | import torch 25 | import torchaudio 26 | import numpy as np 27 | 28 | from pathlib import Path 29 | from typing import Tuple 30 | from numpy.lib.stride_tricks import sliding_window_view 31 | 32 | 33 | def audio_volume_normalize(audio: np.ndarray, coeff: float = 0.2) -> np.ndarray: 34 | """ 35 | Normalize the volume of an audio signal. 36 | 37 | Parameters: 38 | audio (numpy array): Input audio signal array. 39 | coeff (float): Target coefficient for normalization, default is 0.2. 40 | 41 | Returns: 42 | numpy array: The volume-normalized audio signal. 43 | """ 44 | # Sort the absolute values of the audio signal 45 | temp = np.sort(np.abs(audio)) 46 | 47 | # If the maximum value is less than 0.1, scale the array to have a maximum of 0.1 48 | if temp[-1] < 0.1: 49 | scaling_factor = max( 50 | temp[-1], 1e-3 51 | ) # Prevent division by zero with a small constant 52 | audio = audio / scaling_factor * 0.1 53 | 54 | # Filter out values less than 0.01 from temp 55 | temp = temp[temp > 0.01] 56 | L = temp.shape[0] # Length of the filtered array 57 | 58 | # If there are fewer than or equal to 10 significant values, return the audio without further processing 59 | if L <= 10: 60 | return audio 61 | 62 | # Compute the average of the top 10% to 1% of values in temp 63 | volume = np.mean(temp[int(0.9 * L) : int(0.99 * L)]) 64 | 65 | # Normalize the audio to the target coefficient level, clamping the scale factor between 0.1 and 10 66 | audio = audio * np.clip(coeff / volume, a_min=0.1, a_max=10) 67 | 68 | # Ensure the maximum absolute value in the audio does not exceed 1 69 | max_value = np.max(np.abs(audio)) 70 | if max_value > 1: 71 | audio = audio / max_value 72 | 73 | return audio 74 | 75 | 76 | def load_audio( 77 | adfile: Path, 78 | sampling_rate: int = None, 79 | length: int = None, 80 | volume_normalize: bool = False, 81 | segment_duration: int = None, 82 | ) -> np.ndarray: 83 | r"""Load audio file with target sampling rate and lsength 84 | 85 | Args: 86 | adfile (Path): path to audio file. 87 | sampling_rate (int, optional): target sampling rate. Defaults to None. 88 | length (int, optional): target audio length. Defaults to None. 89 | volume_normalize (bool, optional): whether perform volume normalization. Defaults to False. 90 | segment_duration (int): random select a segment with duration of {segment_duration}s. 91 | Defualt to None which means the whole audio will be used. 
92 | 93 | Returns: 94 | audio (np.ndarray): audio 95 | """ 96 | 97 | audio, sr = soundfile.read(adfile) 98 | if len(audio.shape) > 1: 99 | audio = audio[:, 0] 100 | 101 | if sampling_rate is not None and sr != sampling_rate: 102 | audio = soxr.resample(audio, sr, sampling_rate, quality="VHQ") 103 | sr = sampling_rate 104 | 105 | if segment_duration is not None: 106 | seg_length = int(sr * segment_duration) 107 | audio = random_select_audio_segment(audio, seg_length) 108 | 109 | # Audio volume normalize 110 | if volume_normalize: 111 | audio = audio_volume_normalize(audio) 112 | # check the audio length 113 | if length is not None: 114 | assert abs(audio.shape[0] - length) < 1000 115 | if audio.shape[0] > length: 116 | audio = audio[:length] 117 | else: 118 | audio = np.pad(audio, (0, int(length - audio.shape[0]))) 119 | return audio 120 | 121 | 122 | def random_select_audio_segment(audio: np.ndarray, length: int) -> np.ndarray: 123 | """get an audio segment given the length 124 | 125 | Args: 126 | audio (np.ndarray): 127 | length (int): audio length = sampling_rate * duration 128 | """ 129 | if audio.shape[0] < length: 130 | audio = np.pad(audio, (0, int(length - audio.shape[0]))) 131 | start_index = random.randint(0, audio.shape[0] - length) 132 | end_index = int(start_index + length) 133 | 134 | return audio[start_index:end_index] 135 | 136 | 137 | def audio_highpass_filter(audio, sample_rate, highpass_cutoff_freq): 138 | """apply highpass fileter to audio 139 | 140 | Args: 141 | audio (np.ndarray): 142 | sample_rate (ind): 143 | highpass_cutoff_freq (int): 144 | """ 145 | 146 | audio = torchaudio.functional.highpass_biquad( 147 | torch.from_numpy(audio), sample_rate, cutoff_freq=highpass_cutoff_freq 148 | ) 149 | return audio.numpy() 150 | 151 | 152 | def stft( 153 | x: torch.Tensor, 154 | fft_size: int, 155 | hop_size: int, 156 | win_length: int, 157 | window: str, 158 | use_complex: bool = False, 159 | ) -> torch.Tensor: 160 | """Perform STFT and convert to magnitude spectrogram. 161 | Args: 162 | x (Tensor): Input signal tensor (B, T). 163 | fft_size (int): FFT size. 164 | hop_size (int): Hop size. 165 | win_length (int): Window length. 166 | window (str): Window function type. 167 | Returns: 168 | Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1). 169 | """ 170 | 171 | x_stft = torch.stft( 172 | x, fft_size, hop_size, win_length, window.to(x.device), return_complex=True 173 | ) 174 | 175 | # clamp is needed to avoid nan or inf 176 | if not use_complex: 177 | return torch.sqrt( 178 | torch.clamp(x_stft.real**2 + x_stft.imag**2, min=1e-7, max=1e3) 179 | ).transpose(2, 1) 180 | else: 181 | res = torch.cat([x_stft.real.unsqueeze(1), x_stft.imag.unsqueeze(1)], dim=1) 182 | res = res.transpose(2, 3) # [B, 2, T, F] 183 | return res 184 | 185 | 186 | def detect_speech_boundaries( 187 | wav: np.ndarray, 188 | sample_rate: int, 189 | window_duration: float = 0.1, 190 | energy_threshold: float = 0.01, 191 | margin_factor: int = 2 192 | ) -> Tuple[int, int]: 193 | """Detect the start and end points of speech in an audio signal using RMS energy. 
194 | 195 | Args: 196 | wav: Input audio signal array with values in [-1, 1] 197 | sample_rate: Audio sample rate in Hz 198 | window_duration: Duration of detection window in seconds 199 | energy_threshold: RMS energy threshold for speech detection 200 | margin_factor: Factor to determine extra margin around detected boundaries 201 | 202 | Returns: 203 | tuple: (start_index, end_index) of speech segment 204 | 205 | Raises: 206 | ValueError: If the audio contains only silence 207 | """ 208 | window_size = int(window_duration * sample_rate) 209 | margin = margin_factor * window_size 210 | step_size = window_size // 10 211 | 212 | # Create sliding windows using stride tricks to avoid loops 213 | windows = sliding_window_view(wav, window_size)[::step_size] 214 | 215 | # Calculate RMS energy for each window 216 | energy = np.sqrt(np.mean(windows ** 2, axis=1)) 217 | speech_mask = energy >= energy_threshold 218 | 219 | if not np.any(speech_mask): 220 | raise ValueError("No speech detected in audio (only silence)") 221 | 222 | start = max(0, np.argmax(speech_mask) * step_size - margin) 223 | end = min(len(wav), (len(speech_mask) - 1 - np.argmax(speech_mask[::-1])) * step_size + margin) 224 | 225 | return start, end 226 | 227 | 228 | def remove_silence_on_both_ends( 229 | wav: np.ndarray, 230 | sample_rate: int, 231 | window_duration: float = 0.1, 232 | volume_threshold: float = 0.01 233 | ) -> np.ndarray: 234 | """Remove silence from both ends of an audio signal. 235 | 236 | Args: 237 | wav: Input audio signal array 238 | sample_rate: Audio sample rate in Hz 239 | window_duration: Duration of detection window in seconds 240 | volume_threshold: Amplitude threshold for silence detection 241 | 242 | Returns: 243 | np.ndarray: Audio signal with silence removed from both ends 244 | 245 | Raises: 246 | ValueError: If the audio contains only silence 247 | """ 248 | start, end = detect_speech_boundaries( 249 | wav, 250 | sample_rate, 251 | window_duration, 252 | volume_threshold 253 | ) 254 | return wav[start:end] 255 | 256 | 257 | 258 | def hertz_to_mel(pitch: float) -> float: 259 | """ 260 | Converts a frequency from the Hertz scale to the Mel scale. 261 | 262 | Parameters: 263 | - pitch: float or ndarray 264 | Frequency in Hertz. 265 | 266 | Returns: 267 | - mel: float or ndarray 268 | Frequency in Mel scale. 269 | """ 270 | mel = 2595 * np.log10(1 + pitch / 700) 271 | return mel -------------------------------------------------------------------------------- /sparktts/utils/file.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 SparkAudio 2 | # 2025 Xinsheng Wang (w.xinshawn@gmail.com) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """ 16 | Description: 17 | This script contains a collection of functions designed to handle various 18 | file reading and writing operations. It provides utilities to read from files, 19 | write data to files, and perform file manipulation tasks. 
20 | """ 21 | 22 | 23 | import os 24 | import json 25 | import json 26 | import csv 27 | 28 | from tqdm import tqdm 29 | from typing import List, Dict, Any, Set, Union 30 | from pathlib import Path 31 | from omegaconf import OmegaConf, DictConfig 32 | 33 | 34 | def resolve_symbolic_link(symbolic_link_path: Path) -> Path: 35 | """ 36 | Resolves the absolute path of a symbolic link. 37 | 38 | Args: 39 | symbolic_link_path (Path): The path to the symbolic link. 40 | 41 | Returns: 42 | Path: The absolute path that the symbolic link points to. 43 | """ 44 | 45 | link_directory = os.path.dirname(symbolic_link_path) 46 | target_path_relative = os.readlink(symbolic_link_path) 47 | return os.path.join(link_directory, target_path_relative) 48 | 49 | 50 | def write_jsonl(metadata: List[dict], file_path: Path) -> None: 51 | """Writes a list of dictionaries to a JSONL file. 52 | 53 | Args: 54 | metadata : List[dict] 55 | A list of dictionaries, each representing a piece of meta. 56 | file_path : Path 57 | The file path to save the JSONL file 58 | 59 | This function writes each dictionary in the list to a new line in the specified file. 60 | """ 61 | with open(file_path, "w", encoding="utf-8") as f: 62 | for meta in tqdm(metadata, desc="writing jsonl"): 63 | # Convert dictionary to JSON string and write it to the file with a newline 64 | json_str = json.dumps(meta, ensure_ascii=False) + "\n" 65 | f.write(json_str) 66 | print(f"jsonl saved to {file_path}") 67 | 68 | 69 | def read_jsonl(file_path: Path) -> List[dict]: 70 | """ 71 | Reads a JSONL file and returns a list of dictionaries. 72 | 73 | Args: 74 | file_path : Path 75 | The path to the JSONL file to be read. 76 | 77 | Returns: 78 | List[dict] 79 | A list of dictionaries parsed from each line of the JSONL file. 80 | """ 81 | metadata = [] 82 | # Open the file for reading 83 | with open(file_path, "r", encoding="utf-8") as f: 84 | # Split the file into lines 85 | lines = f.read().splitlines() 86 | # Process each line 87 | for line in lines: 88 | # Convert JSON string back to dictionary and append to list 89 | meta = json.loads(line) 90 | metadata.append(meta) 91 | # Return the list of metadata 92 | return metadata 93 | 94 | def read_json_as_jsonl(file_path: Path) -> List[dict]: 95 | metadata = [] 96 | with open(file_path, 'r', encoding='utf-8') as infile: 97 | data = json.load(infile) 98 | for k in sorted(data.keys()): 99 | meta = {'index': k} 100 | meta.update(data[k]) 101 | metadata.append(meta) 102 | return metadata 103 | 104 | 105 | 106 | def decode_unicode_strings(meta: Dict[str, Any]) -> Dict[str, Any]: 107 | processed_meta = {} 108 | for k, v in meta.items(): 109 | if isinstance(v, str): 110 | processed_meta[k] = v.encode("utf-8").decode("unicode_escape") 111 | else: 112 | processed_meta[k] = v 113 | return processed_meta 114 | 115 | 116 | def load_config(config_path: Path) -> DictConfig: 117 | """Loads a configuration file and optionally merges it with a base configuration. 118 | 119 | Args: 120 | config_path (Path): Path to the configuration file. 
121 | """ 122 | # Load the initial configuration from the given path 123 | config = OmegaConf.load(config_path) 124 | 125 | # Check if there is a base configuration specified and merge if necessary 126 | if config.get("base_config", None) is not None: 127 | base_config = OmegaConf.load(config["base_config"]) 128 | config = OmegaConf.merge(base_config, config) 129 | 130 | return config 131 | 132 | 133 | 134 | def jsonl_to_csv(jsonl_file_path: str, csv_file_path: str) -> None: 135 | """ 136 | Converts a JSONL file to a CSV file. 137 | 138 | This function reads a JSONL file, determines all unique keys present in the file, 139 | and writes the data to a CSV file with columns for all these keys. 140 | """ 141 | 142 | all_keys = set() 143 | data_rows = [] 144 | 145 | # Read the JSONL file once to extract keys and collect data 146 | with open(jsonl_file_path, 'r') as file: 147 | for line in file: 148 | data = json.loads(line.strip()) 149 | data_rows.append(data) 150 | all_keys.update(data.keys()) 151 | 152 | # Convert the set of keys to a sorted list for consistent column order 153 | sorted_keys = sorted(all_keys) 154 | 155 | # Write the data to a CSV file 156 | with open(csv_file_path, 'w', newline='') as csvfile: 157 | writer = csv.DictWriter(csvfile, fieldnames=sorted_keys) 158 | 159 | # Write the header row 160 | writer.writeheader() 161 | 162 | # Write each row of data 163 | for data in data_rows: 164 | writer.writerow(data) 165 | 166 | print(f"CSV file has been created at {csv_file_path}") 167 | 168 | 169 | def save_metadata(data, filename, headers=None): 170 | """ 171 | Save metadata to a file. 172 | 173 | Args: 174 | data (list of dict): Metadata to be saved. 175 | filename (str): Name of the file to save the metadata. 176 | headers (list of str): The order of column names to be saved; defaults to the keys from the first dictionary in data if not provided. 177 | """ 178 | # Set headers to keys from the first dictionary in data if not explicitly provided 179 | if headers is None: 180 | headers = list(data[0].keys()) 181 | 182 | with open(filename, "w", encoding="utf-8") as file: 183 | # Write the headers to the file 184 | file.write("|".join(headers) + "\n") 185 | for entry in data: 186 | # Retrieve values in the order of headers, replacing any '|' characters with a space to prevent formatting errors 187 | formatted_values = [str(entry.get(key, "")).replace("|", " ") for key in headers] 188 | # Write the formatted values to the file 189 | file.write("|".join(formatted_values) + "\n") 190 | 191 | 192 | def read_metadata(filename, headers=None): 193 | """ 194 | Read metadata from a file. 195 | 196 | Args: 197 | filename (str): The file from which to read the metadata. 198 | 199 | Returns: 200 | list of dict: The metadata read from the file. 201 | list of str: The headers used in the file. 
202 | """ 203 | with open(filename, "r", encoding="utf-8") as file: 204 | lines = file.readlines() 205 | 206 | data = [] 207 | # Set headers from the first line of the file if not provided 208 | if headers is None: 209 | headers = lines[0].strip().split("|") 210 | lines = lines[1:] 211 | 212 | for line in lines: 213 | line = line.strip() 214 | # Skip empty lines 215 | if not line: 216 | continue 217 | # Split the line by '|' and pair with headers to form a dictionary 218 | entry_data = dict(zip(headers, line.split("|"))) 219 | data.append(entry_data) 220 | 221 | return data, headers 222 | -------------------------------------------------------------------------------- /sparktts/utils/parse_options.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Copyright 2012 Johns Hopkins University (Author: Daniel Povey); 4 | # Arnab Ghoshal, Karel Vesely 5 | 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 13 | # KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED 14 | # WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, 15 | # MERCHANTABLITY OR NON-INFRINGEMENT. 16 | # See the Apache 2 License for the specific language governing permissions and 17 | # limitations under the License. 18 | 19 | 20 | # Parse command-line options. 21 | # To be sourced by another script (as in ". parse_options.sh"). 22 | # Option format is: --option-name arg 23 | # and shell variable "option_name" gets set to value "arg." 24 | # The exception is --help, which takes no arguments, but prints the 25 | # $help_message variable (if defined). 26 | 27 | 28 | ### 29 | ### The --config file options have lower priority to command line 30 | ### options, so we need to import them first... 31 | ### 32 | 33 | # Now import all the configs specified by command-line, in left-to-right order 34 | # for ((argpos=1; argpos<$#; argpos++)); do 35 | # if [ "${!argpos}" == "--config" ]; then 36 | # argpos_plus1=$((argpos+1)) 37 | # config=${!argpos_plus1} 38 | # [ ! -r $config ] && echo "$0: missing config '$config'" && exit 1 39 | # . $config # source the config file. 40 | # fi 41 | # done 42 | 43 | 44 | ### 45 | ### No we process the command line options 46 | ### 47 | while true; do 48 | [ -z "${1:-}" ] && break; # break if there are no arguments 49 | case "$1" in 50 | # If the enclosing script is called with --help option, print the help 51 | # message and exit. Scripts should put help messages in $help_message 52 | --help|-h) if [ -z "$help_message" ]; then echo "No help found." 1>&2; 53 | else printf "$help_message\n" 1>&2 ; fi; 54 | exit 0 ;; 55 | --*=*) echo "$0: options to scripts must be of the form --name value, got '$1'" 56 | exit 1 ;; 57 | # If the first command-line argument begins with "--" (e.g. --foo-bar), 58 | # then work out the variable name as $name, which will equal "foo_bar". 59 | --*) name=`echo "$1" | sed s/^--// | sed s/-/_/g`; 60 | # Next we test whether the variable in question is undefned-- if so it's 61 | # an invalid option and we die. Note: $0 evaluates to the name of the 62 | # enclosing script. 63 | # The test [ -z ${foo_bar+xxx} ] will return true if the variable foo_bar 64 | # is undefined. 
We then have to wrap this test inside "eval" because 65 | # foo_bar is itself inside a variable ($name). 66 | eval '[ -z "${'$name'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1; 67 | 68 | oldval="`eval echo \\$$name`"; 69 | # Work out whether we seem to be expecting a Boolean argument. 70 | if [ "$oldval" == "true" ] || [ "$oldval" == "false" ]; then 71 | was_bool=true; 72 | else 73 | was_bool=false; 74 | fi 75 | 76 | # Set the variable to the right value-- the escaped quotes make it work if 77 | # the option had spaces, like --cmd "queue.pl -sync y" 78 | eval $name=\"$2\"; 79 | 80 | # Check that Boolean-valued arguments are really Boolean. 81 | if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then 82 | echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2 83 | exit 1; 84 | fi 85 | shift 2; 86 | ;; 87 | *) break; 88 | esac 89 | done 90 | 91 | 92 | # Check for an empty argument to the --cmd option, which can easily occur as a 93 | # result of scripting errors. 94 | [ ! -z "${cmd+xxx}" ] && [ -z "$cmd" ] && echo "$0: empty argument to --cmd option" 1>&2 && exit 1; 95 | 96 | 97 | true; # so this script returns exit code 0. -------------------------------------------------------------------------------- /sparktts/utils/token_parser.py: -------------------------------------------------------------------------------- 1 | TASK_TOKEN_MAP = { 2 | "vc": "<|task_vc|>", 3 | "tts": "<|task_tts|>", 4 | "asr": "<|task_asr|>", 5 | "s2s": "<|task_s2s|>", 6 | "t2s": "<|task_t2s|>", 7 | "understand": "<|task_understand|>", 8 | "caption": "<|task_cap|>", 9 | "controllable_tts": "<|task_controllable_tts|>", 10 | "prompt_tts": "<|task_prompt_tts|>", 11 | "speech_edit": "<|task_edit|>", 12 | } 13 | 14 | LEVELS_MAP = { 15 | "very_low": 0, 16 | "low": 1, 17 | "moderate": 2, 18 | "high": 3, 19 | "very_high": 4, 20 | } 21 | 22 | LEVELS_MAP_UI = { 23 | 1: 'very_low', 24 | 2: 'low', 25 | 3: 'moderate', 26 | 4: 'high', 27 | 5: 'very_high' 28 | } 29 | 30 | GENDER_MAP = { 31 | "female": 0, 32 | "male": 1, 33 | } 34 | 35 | AGE_MAP = {"Child": 0, "Teenager": 1, "Youth-Adult": 2, "Middle-aged": 3, "Elderly": 4} 36 | 37 | EMO_MAP = { 38 | "UNKNOWN": 0, 39 | "NEUTRAL": 1, 40 | "ANGRY": 2, 41 | "HAPPY": 3, 42 | "SAD": 4, 43 | "FEARFUL": 5, 44 | "DISGUSTED": 6, 45 | "SURPRISED": 7, 46 | "SARCASTIC": 8, 47 | "EXCITED": 9, 48 | "SLEEPY": 10, 49 | "CONFUSED": 11, 50 | "EMPHASIS": 12, 51 | "LAUGHING": 13, 52 | "SINGING": 14, 53 | "WORRIED": 15, 54 | "WHISPER": 16, 55 | "ANXIOUS": 17, 56 | "NO-AGREEMENT": 18, 57 | "APOLOGETIC": 19, 58 | "CONCERNED": 20, 59 | "ENUNCIATED": 21, 60 | "ASSERTIVE": 22, 61 | "ENCOURAGING": 23, 62 | "CONTEMPT": 24, 63 | } 64 | 65 | 66 | class TokenParser: 67 | """Turn speaker attribute labels into special tokens.""" 68 | 69 | def __init__(self): 70 | pass 71 | 72 | 73 | 74 | 75 | 76 | 77 | @staticmethod 78 | def age(age: str) -> str: 79 | """Turn age token.""" 80 | age_id = AGE_MAP[age] 81 | return f"<|age_{age_id}|>" 82 | 83 | @staticmethod 84 | def gender(gender: str) -> str: 85 | """Turn gender token.""" 86 | gender_id = GENDER_MAP[gender] 87 | return f"<|gender_{gender_id}|>" 88 | 89 | @staticmethod 90 | def mel_value(mel: int): 91 | """Turn special token of mel scale pitch.""" 92 | mel = max(0, int(mel)) 93 | mel = min(1000, int(mel)) 94 | return f"<|pitch_value_{mel}|>" 95 | 96 | @staticmethod 97 | def mel_level(level: str): 98 | """Turn special token of mel level.""" 99 | level_tag = LEVELS_MAP[level] 100 | return
f"<|pitch_label_{level_tag}|>" 101 | 102 | @staticmethod 103 | def pitch_var_value(pitch_std: int): 104 | """Turn special token of pitch_std value.""" 105 | assert isinstance(pitch_std, int) 106 | pitch_std = max(0, int(pitch_std)) 107 | pitch_std = min(10, int(pitch_std)) 108 | return f"<|pitch_var_value_{pitch_std}|>" 109 | 110 | @staticmethod 111 | def pitch_var_level(level: str): 112 | """Turn special token of pitch std level.""" 113 | level_tag = LEVELS_MAP[level] 114 | return f"<|pitch_var_label_{level_tag}|>" 115 | 116 | @staticmethod 117 | def loudness_value(loudness: int): 118 | """Turn special toak of loudness value [0, 30]""" 119 | assert loudness >= 0 120 | loudness = max(0, int(loudness)) 121 | loudness = min(30, int(loudness)) 122 | return f"<|loudness_value_{loudness}|>" 123 | 124 | @staticmethod 125 | def loudness_level(level: str): 126 | """Turn special token of loudness level.""" 127 | level_tag = LEVELS_MAP[level] 128 | return f"<|loudness_label_{level_tag}|>" 129 | 130 | @staticmethod 131 | def speed_value(speed: int): 132 | """Turn special token of speed value.""" 133 | speed = max(0, int(speed)) 134 | speed = min(10, int(speed)) 135 | return f"<|speed_value_{speed}|>" 136 | 137 | @staticmethod 138 | def speed_level(level: str): 139 | """Turn special token of speed level.""" 140 | level_tag = LEVELS_MAP[level] 141 | return f"<|speed_label_{level_tag}|>" 142 | 143 | @staticmethod 144 | def task(task: str) -> str: 145 | """Turn special token of task.""" 146 | assert task in TASK_TOKEN_MAP.keys() 147 | 148 | return TASK_TOKEN_MAP[task] 149 | 150 | @staticmethod 151 | def emotion(emotion: str): 152 | emo_id = EMO_MAP[emotion] 153 | 154 | return f"<|emotion_{emo_id}|>" 155 | 156 | 157 | # test 158 | if __name__ == "__main__": 159 | from transformers import AutoTokenizer 160 | 161 | tokenizer = AutoTokenizer.from_pretrained( 162 | "/aifs4su/xinshengwang/code/StyleCraft/tokenizer/stylecraft-bicodec-pitch-loudness-speed-emotion-tokenizer" 163 | ) 164 | 165 | tasks = ["tts", "tts", "understand", "controllable_tts", "prompt_tts"] 166 | ages = ["Child", "Teenager", "Youth-Adult", "Middle-aged", "Elderly"] 167 | genders = ["female", "female", "female", "male", "male"] 168 | mels = [100, 200, 300, 400, 500] 169 | mel_levels = ["very_low", "low", "moderate", "high", "very_high"] 170 | loudnesses = [1, 10, 23, 19, 30] 171 | loudness_levels = ["very_low", "low", "moderate", "high", "very_high"] 172 | emotions = ["UNKNOWN", "NEUTRAL", "ANGRY", "HAPPY", "SAD"] 173 | 174 | for i in range(5): 175 | task = TokenParser.task(tasks[i]) 176 | age = TokenParser.age(ages[i]) 177 | gender = TokenParser.gender(genders[i]) 178 | mel = TokenParser.mel_value(mels[i]) 179 | mel_level = TokenParser.mel_level(mel_levels[i]) 180 | loudness = TokenParser.loudness_value(loudnesses[i]) 181 | loudness_level = TokenParser.loudness_level(loudness_levels[i]) 182 | emotion = TokenParser.emotion(emotions[i]) 183 | inputs = [task, age, gender, mel, mel_level, loudness, loudness_level, emotion] 184 | inputs = "".join(inputs) 185 | ids = tokenizer.encode(inputs, add_special_tokens=False) 186 | print(ids) 187 | print("decode", tokenizer.decode(ids)) 188 | -------------------------------------------------------------------------------- /src/demos/trump/trump_en.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SparkAudio/Spark-TTS/2f1ea9082400547242641f5271b6f941c9f439d1/src/demos/trump/trump_en.wav 
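Usage note for sparktts/utils/token_parser.py above: the snippet below is a minimal, illustrative sketch of how the TokenParser static methods map attribute labels onto the special control tokens defined by TASK_TOKEN_MAP, GENDER_MAP, AGE_MAP, LEVELS_MAP, and EMO_MAP. The concatenation order is an assumption for demonstration only; the actual prompt assembly used at inference time lives elsewhere in the repo (presumably cli/SparkTTS.py).

# Illustrative sketch only: composes control tokens with TokenParser.
# The token ordering is an assumption, not the prompt layout required by the model.
from sparktts.utils.token_parser import TokenParser

control_tokens = [
    TokenParser.task("controllable_tts"),  # -> "<|task_controllable_tts|>"
    TokenParser.gender("female"),          # -> "<|gender_0|>"
    TokenParser.age("Youth-Adult"),        # -> "<|age_2|>"
    TokenParser.mel_level("moderate"),     # -> "<|pitch_label_2|>"
    TokenParser.speed_level("high"),       # -> "<|speed_label_3|>"
    TokenParser.emotion("HAPPY"),          # -> "<|emotion_3|>"
]
print("".join(control_tokens))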
-------------------------------------------------------------------------------- /src/demos/zhongli/zhongli_en.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SparkAudio/Spark-TTS/2f1ea9082400547242641f5271b6f941c9f439d1/src/demos/zhongli/zhongli_en.wav -------------------------------------------------------------------------------- /src/demos/余承东/yuchengdong_zh.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SparkAudio/Spark-TTS/2f1ea9082400547242641f5271b6f941c9f439d1/src/demos/余承东/yuchengdong_zh.wav -------------------------------------------------------------------------------- /src/demos/刘德华/dehua_zh.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SparkAudio/Spark-TTS/2f1ea9082400547242641f5271b6f941c9f439d1/src/demos/刘德华/dehua_zh.wav -------------------------------------------------------------------------------- /src/demos/哪吒/nezha_zh.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SparkAudio/Spark-TTS/2f1ea9082400547242641f5271b6f941c9f439d1/src/demos/哪吒/nezha_zh.wav -------------------------------------------------------------------------------- /src/demos/徐志胜/zhisheng_zh.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SparkAudio/Spark-TTS/2f1ea9082400547242641f5271b6f941c9f439d1/src/demos/徐志胜/zhisheng_zh.wav -------------------------------------------------------------------------------- /src/demos/李靖/lijing_zh.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SparkAudio/Spark-TTS/2f1ea9082400547242641f5271b6f941c9f439d1/src/demos/李靖/lijing_zh.wav -------------------------------------------------------------------------------- /src/demos/杨澜/yanglan_zh.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SparkAudio/Spark-TTS/2f1ea9082400547242641f5271b6f941c9f439d1/src/demos/杨澜/yanglan_zh.wav -------------------------------------------------------------------------------- /src/demos/马云/mayun_zh.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SparkAudio/Spark-TTS/2f1ea9082400547242641f5271b6f941c9f439d1/src/demos/马云/mayun_zh.wav -------------------------------------------------------------------------------- /src/demos/鲁豫/luyu_zh.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SparkAudio/Spark-TTS/2f1ea9082400547242641f5271b6f941c9f439d1/src/demos/鲁豫/luyu_zh.wav -------------------------------------------------------------------------------- /src/figures/gradio_TTS.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SparkAudio/Spark-TTS/2f1ea9082400547242641f5271b6f941c9f439d1/src/figures/gradio_TTS.png -------------------------------------------------------------------------------- /src/figures/gradio_control.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SparkAudio/Spark-TTS/2f1ea9082400547242641f5271b6f941c9f439d1/src/figures/gradio_control.png 
-------------------------------------------------------------------------------- /src/figures/infer_control.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SparkAudio/Spark-TTS/2f1ea9082400547242641f5271b6f941c9f439d1/src/figures/infer_control.png -------------------------------------------------------------------------------- /src/figures/infer_voice_cloning.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SparkAudio/Spark-TTS/2f1ea9082400547242641f5271b6f941c9f439d1/src/figures/infer_voice_cloning.png -------------------------------------------------------------------------------- /src/logo/HKUST.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SparkAudio/Spark-TTS/2f1ea9082400547242641f5271b6f941c9f439d1/src/logo/HKUST.jpg -------------------------------------------------------------------------------- /src/logo/NPU.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SparkAudio/Spark-TTS/2f1ea9082400547242641f5271b6f941c9f439d1/src/logo/NPU.jpg -------------------------------------------------------------------------------- /src/logo/NTU.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SparkAudio/Spark-TTS/2f1ea9082400547242641f5271b6f941c9f439d1/src/logo/NTU.jpg -------------------------------------------------------------------------------- /src/logo/SJU.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SparkAudio/Spark-TTS/2f1ea9082400547242641f5271b6f941c9f439d1/src/logo/SJU.jpg -------------------------------------------------------------------------------- /src/logo/SparkAudio.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SparkAudio/Spark-TTS/2f1ea9082400547242641f5271b6f941c9f439d1/src/logo/SparkAudio.jpg -------------------------------------------------------------------------------- /src/logo/SparkAudio2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SparkAudio/Spark-TTS/2f1ea9082400547242641f5271b6f941c9f439d1/src/logo/SparkAudio2.jpg -------------------------------------------------------------------------------- /src/logo/SparkTTS.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SparkAudio/Spark-TTS/2f1ea9082400547242641f5271b6f941c9f439d1/src/logo/SparkTTS.jpg -------------------------------------------------------------------------------- /src/logo/SparkTTS.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SparkAudio/Spark-TTS/2f1ea9082400547242641f5271b6f941c9f439d1/src/logo/SparkTTS.png -------------------------------------------------------------------------------- /src/logo/mobvoi.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SparkAudio/Spark-TTS/2f1ea9082400547242641f5271b6f941c9f439d1/src/logo/mobvoi.jpg -------------------------------------------------------------------------------- /src/logo/mobvoi.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/SparkAudio/Spark-TTS/2f1ea9082400547242641f5271b6f941c9f439d1/src/logo/mobvoi.png -------------------------------------------------------------------------------- /webui.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 SparkAudio 2 | # 2025 Xinsheng Wang (w.xinshawn@gmail.com) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import os 17 | import torch 18 | import soundfile as sf 19 | import logging 20 | import argparse 21 | import gradio as gr 22 | import platform 23 | 24 | from datetime import datetime 25 | from cli.SparkTTS import SparkTTS 26 | from sparktts.utils.token_parser import LEVELS_MAP_UI 27 | 28 | 29 | def initialize_model(model_dir="pretrained_models/Spark-TTS-0.5B", device=0): 30 | """Load the model once at the beginning.""" 31 | logging.info(f"Loading model from: {model_dir}") 32 | 33 | # Determine appropriate device based on platform and availability 34 | if platform.system() == "Darwin": 35 | # macOS with MPS support (Apple Silicon) 36 | device = torch.device(f"mps:{device}") 37 | logging.info(f"Using MPS device: {device}") 38 | elif torch.cuda.is_available(): 39 | # System with CUDA support 40 | device = torch.device(f"cuda:{device}") 41 | logging.info(f"Using CUDA device: {device}") 42 | else: 43 | # Fall back to CPU 44 | device = torch.device("cpu") 45 | logging.info("GPU acceleration not available, using CPU") 46 | 47 | model = SparkTTS(model_dir, device) 48 | return model 49 | 50 | 51 | def run_tts( 52 | text, 53 | model, 54 | prompt_text=None, 55 | prompt_speech=None, 56 | gender=None, 57 | pitch=None, 58 | speed=None, 59 | save_dir="example/results", 60 | ): 61 | """Perform TTS inference and save the generated audio.""" 62 | logging.info(f"Saving audio to: {save_dir}") 63 | 64 | if prompt_text is not None: 65 | prompt_text = None if len(prompt_text) <= 1 else prompt_text 66 | 67 | # Ensure the save directory exists 68 | os.makedirs(save_dir, exist_ok=True) 69 | 70 | # Generate unique filename using timestamp 71 | timestamp = datetime.now().strftime("%Y%m%d%H%M%S") 72 | save_path = os.path.join(save_dir, f"{timestamp}.wav") 73 | 74 | logging.info("Starting inference...") 75 | 76 | # Perform inference and save the output audio 77 | with torch.no_grad(): 78 | wav = model.inference( 79 | text, 80 | prompt_speech, 81 | prompt_text, 82 | gender, 83 | pitch, 84 | speed, 85 | ) 86 | 87 | sf.write(save_path, wav, samplerate=16000) 88 | 89 | logging.info(f"Audio saved at: {save_path}") 90 | 91 | return save_path 92 | 93 | 94 | def build_ui(model_dir, device=0): 95 | 96 | # Initialize model 97 | model = initialize_model(model_dir, device=device) 98 | 99 | # Define callback function for voice cloning 100 | def voice_clone(text, prompt_text, prompt_wav_upload, prompt_wav_record): 101 | """ 102 | Gradio callback to clone voice using text and optional prompt speech. 103 | - text: The input text to be synthesised. 
104 | - prompt_text: Additional textual info for the prompt (optional). 105 | - prompt_wav_upload/prompt_wav_record: Audio files used as reference. 106 | """ 107 | prompt_speech = prompt_wav_upload if prompt_wav_upload else prompt_wav_record 108 | prompt_text_clean = None if len(prompt_text) < 2 else prompt_text 109 | 110 | audio_output_path = run_tts( 111 | text, 112 | model, 113 | prompt_text=prompt_text_clean, 114 | prompt_speech=prompt_speech 115 | ) 116 | return audio_output_path 117 | 118 | # Define callback function for creating new voices 119 | def voice_creation(text, gender, pitch, speed): 120 | """ 121 | Gradio callback to create a synthetic voice with adjustable parameters. 122 | - text: The input text for synthesis. 123 | - gender: 'male' or 'female'. 124 | - pitch/speed: Ranges mapped by LEVELS_MAP_UI. 125 | """ 126 | pitch_val = LEVELS_MAP_UI[int(pitch)] 127 | speed_val = LEVELS_MAP_UI[int(speed)] 128 | audio_output_path = run_tts( 129 | text, 130 | model, 131 | gender=gender, 132 | pitch=pitch_val, 133 | speed=speed_val 134 | ) 135 | return audio_output_path 136 | 137 | with gr.Blocks() as demo: 138 | # Use HTML for centered title 139 | gr.HTML('

<h1 style="text-align: center;">Spark-TTS by SparkAudio</h1>
') 140 | with gr.Tabs(): 141 | # Voice Clone Tab 142 | with gr.TabItem("Voice Clone"): 143 | gr.Markdown( 144 | "### Upload reference audio or recording (上传参考音频或者录音)" 145 | ) 146 | 147 | with gr.Row(): 148 | prompt_wav_upload = gr.Audio( 149 | sources="upload", 150 | type="filepath", 151 | label="Choose the prompt audio file, ensuring the sampling rate is no lower than 16kHz.", 152 | ) 153 | prompt_wav_record = gr.Audio( 154 | sources="microphone", 155 | type="filepath", 156 | label="Record the prompt audio file.", 157 | ) 158 | 159 | with gr.Row(): 160 | text_input = gr.Textbox( 161 | label="Text", lines=3, placeholder="Enter text here" 162 | ) 163 | prompt_text_input = gr.Textbox( 164 | label="Text of prompt speech (Optional; recommended for cloning in the same language.)", 165 | lines=3, 166 | placeholder="Enter text of the prompt speech.", 167 | ) 168 | 169 | audio_output = gr.Audio( 170 | label="Generated Audio", autoplay=True, streaming=True 171 | ) 172 | 173 | generate_buttom_clone = gr.Button("Generate") 174 | 175 | generate_buttom_clone.click( 176 | voice_clone, 177 | inputs=[ 178 | text_input, 179 | prompt_text_input, 180 | prompt_wav_upload, 181 | prompt_wav_record, 182 | ], 183 | outputs=[audio_output], 184 | ) 185 | 186 | # Voice Creation Tab 187 | with gr.TabItem("Voice Creation"): 188 | gr.Markdown( 189 | "### Create your own voice based on the following parameters" 190 | ) 191 | 192 | with gr.Row(): 193 | with gr.Column(): 194 | gender = gr.Radio( 195 | choices=["male", "female"], value="male", label="Gender" 196 | ) 197 | pitch = gr.Slider( 198 | minimum=1, maximum=5, step=1, value=3, label="Pitch" 199 | ) 200 | speed = gr.Slider( 201 | minimum=1, maximum=5, step=1, value=3, label="Speed" 202 | ) 203 | with gr.Column(): 204 | text_input_creation = gr.Textbox( 205 | label="Input Text", 206 | lines=3, 207 | placeholder="Enter text here", 208 | value="You can generate a customized voice by adjusting parameters such as pitch and speed.", 209 | ) 210 | create_button = gr.Button("Create Voice") 211 | 212 | audio_output = gr.Audio( 213 | label="Generated Audio", autoplay=True, streaming=True 214 | ) 215 | create_button.click( 216 | voice_creation, 217 | inputs=[text_input_creation, gender, pitch, speed], 218 | outputs=[audio_output], 219 | ) 220 | 221 | return demo 222 | 223 | 224 | def parse_arguments(): 225 | """ 226 | Parse command-line arguments such as model directory and device ID. 227 | """ 228 | parser = argparse.ArgumentParser(description="Spark TTS Gradio server.") 229 | parser.add_argument( 230 | "--model_dir", 231 | type=str, 232 | default="pretrained_models/Spark-TTS-0.5B", 233 | help="Path to the model directory." 234 | ) 235 | parser.add_argument( 236 | "--device", 237 | type=int, 238 | default=0, 239 | help="ID of the GPU device to use (e.g., 0 for cuda:0)." 240 | ) 241 | parser.add_argument( 242 | "--server_name", 243 | type=str, 244 | default="0.0.0.0", 245 | help="Server host/IP for Gradio app." 246 | ) 247 | parser.add_argument( 248 | "--server_port", 249 | type=int, 250 | default=7860, 251 | help="Server port for Gradio app." 
252 | ) 253 | return parser.parse_args() 254 | 255 | if __name__ == "__main__": 256 | # Parse command-line arguments 257 | args = parse_arguments() 258 | 259 | # Build the Gradio demo by specifying the model directory and GPU device 260 | demo = build_ui( 261 | model_dir=args.model_dir, 262 | device=args.device 263 | ) 264 | 265 | # Launch Gradio with the specified server name and port 266 | demo.launch( 267 | server_name=args.server_name, 268 | server_port=args.server_port 269 | ) --------------------------------------------------------------------------------
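Based on the argparse defaults in webui.py above, the Gradio demo can be started from the repository root with, for example, python webui.py --model_dir pretrained_models/Spark-TTS-0.5B --device 0 --server_name 0.0.0.0 --server_port 7860. The snippet below is a minimal programmatic equivalent that mirrors the __main__ block; the model path is simply the argparse default and assumes the checkpoint has already been downloaded there.

# Minimal programmatic launch, mirroring webui.py's __main__ block.
# Assumes the Spark-TTS checkpoint is available at the default path.
from webui import build_ui

demo = build_ui(model_dir="pretrained_models/Spark-TTS-0.5B", device=0)
demo.launch(server_name="0.0.0.0", server_port=7860)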