├── .gitattributes ├── .github └── workflows │ ├── docker-image.yml │ ├── python-package.yml │ └── python-publish.yml ├── .gitignore ├── Dockerfile ├── LICENSE ├── README.md ├── README_VN.md ├── assets ├── viet-tts-large.png ├── viet-tts-medium.png └── viet-tts.png ├── docker-compose.yaml ├── pyproject.toml ├── samples ├── cdteam.wav ├── cross_lingual_prompt.wav ├── diep-chi.wav ├── doremon.mp3 ├── jack-sparrow.mp3 ├── nguyen-ngoc-ngan.wav ├── nsnd-le-chuc.mp3 ├── nu-nhe-nhang.wav ├── quynh.wav ├── son-tung-mtp.wav ├── speechify_1.wav ├── speechify_10.wav ├── speechify_11.wav ├── speechify_12.wav ├── speechify_2.wav ├── speechify_3.wav ├── speechify_4.wav ├── speechify_5.wav ├── speechify_6.wav ├── speechify_7.wav ├── speechify_8.wav ├── speechify_9.wav └── zero_shot_prompt.wav ├── viettts ├── __init__.py ├── cli.py ├── flow │ ├── decoder.py │ ├── flow.py │ ├── flow_matching.py │ └── length_regulator.py ├── frontend.py ├── hifigan │ ├── f0_predictor.py │ └── generator.py ├── llm │ └── llm.py ├── model.py ├── server.py ├── tokenizer │ ├── multilingual.tiktoken │ └── tokenizer.py ├── transformer │ ├── __init__.py │ ├── activation.py │ ├── attention.py │ ├── convolution.py │ ├── decoder.py │ ├── decoder_layer.py │ ├── embedding.py │ ├── encoder.py │ ├── encoder_layer.py │ ├── label_smoothing_loss.py │ ├── positionwise_feed_forward.py │ ├── subsampling.py │ └── transformer.py ├── tts.py └── utils │ ├── __init__.py │ ├── class_utils.py │ ├── common.py │ ├── file_utils.py │ ├── frontend_utils.py │ ├── mask.py │ └── vad.py └── web └── .gitkeep /.gitattributes: -------------------------------------------------------------------------------- 1 | *.7z filter=lfs diff=lfs merge=lfs -text 2 | *.arrow filter=lfs diff=lfs merge=lfs -text 3 | *.bin filter=lfs diff=lfs merge=lfs -text 4 | *.bz2 filter=lfs diff=lfs merge=lfs -text 5 | *.ckpt filter=lfs diff=lfs merge=lfs -text 6 | *.ftz filter=lfs diff=lfs merge=lfs -text 7 | *.gz filter=lfs diff=lfs merge=lfs -text 8 | *.h5 filter=lfs diff=lfs merge=lfs -text 9 | *.joblib filter=lfs diff=lfs merge=lfs -text 10 | *.lfs.* filter=lfs diff=lfs merge=lfs -text 11 | *.mlmodel filter=lfs diff=lfs merge=lfs -text 12 | *.model filter=lfs diff=lfs merge=lfs -text 13 | *.msgpack filter=lfs diff=lfs merge=lfs -text 14 | *.npy filter=lfs diff=lfs merge=lfs -text 15 | *.npz filter=lfs diff=lfs merge=lfs -text 16 | *.onnx filter=lfs diff=lfs merge=lfs -text 17 | *.ot filter=lfs diff=lfs merge=lfs -text 18 | *.parquet filter=lfs diff=lfs merge=lfs -text 19 | *.pb filter=lfs diff=lfs merge=lfs -text 20 | *.pickle filter=lfs diff=lfs merge=lfs -text 21 | *.pkl filter=lfs diff=lfs merge=lfs -text 22 | *.pt filter=lfs diff=lfs merge=lfs -text 23 | *.pth filter=lfs diff=lfs merge=lfs -text 24 | *.rar filter=lfs diff=lfs merge=lfs -text 25 | *.safetensors filter=lfs diff=lfs merge=lfs -text 26 | saved_model/**/* filter=lfs diff=lfs merge=lfs -text 27 | *.tar.* filter=lfs diff=lfs merge=lfs -text 28 | *.tar filter=lfs diff=lfs merge=lfs -text 29 | *.tflite filter=lfs diff=lfs merge=lfs -text 30 | *.tgz filter=lfs diff=lfs merge=lfs -text 31 | *.wasm filter=lfs diff=lfs merge=lfs -text 32 | *.xz filter=lfs diff=lfs merge=lfs -text 33 | *.zip filter=lfs diff=lfs merge=lfs -text 34 | *.zst filter=lfs diff=lfs merge=lfs -text 35 | *tfevents* filter=lfs diff=lfs merge=lfs -text 36 | -------------------------------------------------------------------------------- /.github/workflows/docker-image.yml: -------------------------------------------------------------------------------- 1 | 
name: Docker Image CI 2 | 3 | on: 4 | push: 5 | branches: [ "main" ] 6 | pull_request: 7 | branches: [ "main" ] 8 | 9 | jobs: 10 | 11 | build: 12 | 13 | runs-on: ubuntu-latest 14 | 15 | steps: 16 | - uses: actions/checkout@v4 17 | - name: Build the Docker image 18 | run: docker build . --file Dockerfile --tag my-image-name:$(date +%s) 19 | -------------------------------------------------------------------------------- /.github/workflows/python-package.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python 3 | 4 | name: Python package 5 | 6 | on: 7 | push: 8 | branches: [ "main" ] 9 | pull_request: 10 | branches: [ "main" ] 11 | 12 | jobs: 13 | build: 14 | 15 | runs-on: ubuntu-latest 16 | strategy: 17 | fail-fast: false 18 | matrix: 19 | python-version: ["3.10", "3.11"] 20 | 21 | steps: 22 | - uses: actions/checkout@v4 23 | - name: Set up Python ${{ matrix.python-version }} 24 | uses: actions/setup-python@v3 25 | with: 26 | python-version: ${{ matrix.python-version }} 27 | - name: Install Python environment with conda 28 | run: | 29 | pip install -e . && pip cache purge 30 | - name: Test viettts CLI 31 | run: | 32 | viettts show-voices 33 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package to PyPI when a release is created 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries 3 | 4 | # This workflow uses actions that are not certified by GitHub. 5 | # They are provided by a third-party and are governed by 6 | # separate terms of service, privacy policy, and support 7 | # documentation. 8 | 9 | name: Upload Python Package 10 | 11 | on: 12 | release: 13 | types: [published] 14 | 15 | permissions: 16 | contents: read 17 | 18 | jobs: 19 | release-build: 20 | runs-on: ubuntu-latest 21 | 22 | steps: 23 | - uses: actions/checkout@v4 24 | 25 | - uses: actions/setup-python@v5 26 | with: 27 | python-version: "3.x" 28 | 29 | - name: Build release distributions 30 | run: | 31 | # NOTE: put your own distribution build steps here. 32 | python -m pip install build 33 | python -m build 34 | 35 | - name: Upload distributions 36 | uses: actions/upload-artifact@v4 37 | with: 38 | name: release-dists 39 | path: dist/ 40 | 41 | pypi-publish: 42 | runs-on: ubuntu-latest 43 | needs: 44 | - release-build 45 | permissions: 46 | # IMPORTANT: this permission is mandatory for trusted publishing 47 | id-token: write 48 | 49 | # Dedicated environments with protections for publishing are strongly recommended. 
50 | # For more information, see: https://docs.github.com/en/actions/deployment/targeting-different-environments/using-environments-for-deployment#deployment-protection-rules 51 | environment: 52 | name: pypi 53 | # OPTIONAL: uncomment and update to include your PyPI project URL in the deployment status: 54 | # url: https://pypi.org/p/YOURPROJECT 55 | # 56 | # ALTERNATIVE: if your GitHub Release name is the PyPI project version string 57 | # ALTERNATIVE: exactly, uncomment the following line instead: 58 | # url: https://pypi.org/project/YOURPROJECT/${{ github.event.release.name }} 59 | 60 | steps: 61 | - name: Retrieve release distributions 62 | uses: actions/download-artifact@v4 63 | with: 64 | name: release-dists 65 | path: dist/ 66 | 67 | - name: Publish release distributions to PyPI 68 | uses: pypa/gh-action-pypi-publish@release/v1 69 | with: 70 | packages-dir: dist/ 71 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # Visual Studio Code files 7 | .vscode 8 | .vs 9 | 10 | # PyCharm files 11 | .idea 12 | 13 | # Eclipse Project settings 14 | *.*project 15 | .settings 16 | 17 | # Sublime Text settings 18 | *.sublime-workspace 19 | *.sublime-project 20 | 21 | # Editor temporaries 22 | *.swn 23 | *.swo 24 | *.swp 25 | *.swm 26 | *~ 27 | 28 | # IPython notebook checkpoints 29 | .ipynb_checkpoints 30 | 31 | # macOS dir files 32 | .DS_Store 33 | 34 | exp 35 | data* 36 | raw_wav 37 | tensorboard 38 | **/*build* 39 | 40 | # Clangd files 41 | .cache 42 | .github 43 | compile_commands.json 44 | node_modules 45 | 46 | # train/inference files 47 | *.m4a 48 | *.aac 49 | *.pt 50 | *.egg-info 51 | *dist 52 | *parcel-cache 53 | pretrained-models/* 54 | *_pb2_grpc.py 55 | *_pb2.py 56 | poetry.lock -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM pytorch/pytorch:2.0.1-cuda11.7-cudnn8-runtime 2 | 3 | ENV TZ=UTC DEBIAN_FRONTEND=noninteractive 4 | 5 | WORKDIR /app 6 | 7 | ENV POETRY_VERSION=1.8.3 8 | 9 | RUN pip install --no-cache-dir poetry==${POETRY_VERSION} 10 | 11 | ENV POETRY_CACHE_DIR=/tmp/poetry_cache 12 | ENV POETRY_NO_INTERACTION=1 13 | ENV POETRY_VIRTUALENVS_IN_PROJECT=true 14 | ENV POETRY_VIRTUALENVS_CREATE=true 15 | ENV POETRY_REQUESTS_TIMEOUT=15 16 | 17 | RUN apt-get update && apt-get install -y --no-install-recommends \ 18 | build-essential \ 19 | gcc g++ libc-dev libffi-dev libgmp-dev libmpfr-dev libmpc-dev \ 20 | ffmpeg \ 21 | sox 22 | 23 | RUN apt clean && rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/* 24 | 25 | COPY ./viettts /app/viettts 26 | COPY ./samples /app/samples 27 | COPY ./web /app/web 28 | COPY ./pyproject.toml /app/ 29 | COPY ./README.md /app/ 30 | 31 | RUN pip install -e . && pip cache purge -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 
11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 |

3 | 4 |

VietTTS: An Open-Source Vietnamese Text to Speech

5 |

6 |

7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 |
15 | 16 | 17 |

18 | 
19 | **VietTTS** is an open-source toolkit providing the community with a powerful Vietnamese TTS model, capable of natural voice synthesis and robust voice cloning. Designed for effective experimentation, **VietTTS** supports research and application in Vietnamese voice technologies. 
20 | 
21 | ## ⭐ Key Features 
22 | - **TTS**: Text-to-Speech generation with any voice via prompt audio 
23 | - **OpenAI-API-compatible**: Compatible with OpenAI's Text-to-Speech API format 
24 | 
25 | ## 🛠️ Installation 
26 | 
27 | VietTTS can be installed via a Python installer (Linux only, with Windows and macOS support coming soon) or Docker. 
28 | 
29 | ### Python Installer (Python>=3.10) 
30 | ```bash 
31 | git clone https://github.com/dangvansam/viet-tts.git 
32 | cd viet-tts 
33 | 
34 | # (Optional) Install Python environment with conda; you could also use virtualenv 
35 | conda create --name viettts python=3.10 
36 | conda activate viettts 
37 | 
38 | # Install 
39 | pip install -e . && pip cache purge 
40 | ``` 
41 | 
42 | ### Docker 
43 | 
44 | 1. Install [Docker](https://docs.docker.com/get-docker/), [NVIDIA Driver](https://www.nvidia.com/download/index.aspx), [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html), and [CUDA](https://developer.nvidia.com/cuda-downloads). 
45 | 
46 | 2. Run the following commands: 
47 | ```bash 
48 | git clone https://github.com/dangvansam/viet-tts.git 
49 | cd viet-tts 
50 | 
51 | # Build docker images 
52 | docker compose build 
53 | 
54 | # Run with docker-compose - will create server at: http://localhost:8298 
55 | docker compose up -d 
56 | 
57 | # Or run with docker run - will create server at: http://localhost:8298 
58 | docker run -itd --gpus all -p 8298:8298 -v ./pretrained-models:/app/pretrained-models --name viet-tts-service viet-tts:latest viettts server --host 0.0.0.0 --port 8298 
59 | ``` 
60 | 
61 | ## 🚀 Usage 
62 | 
63 | ### Built-in Voices 🤠 
64 | You can use the available voices below to synthesize speech. 
65 | 
66 | Expand 67 | 68 | | ID | Voice | Gender | Play Audio | 69 | |-----|-----------------------|--------|--------------------------------------------------| 70 | | 1 | nsnd-le-chuc | 👨 | | 71 | | 2 | speechify_10 | 👩 | | 72 | | 3 | atuan | 👨 | | 73 | | 4 | speechify_11 | 👩 | | 74 | | 5 | cdteam | 👨 | | 75 | | 6 | speechify_12 | 👩 | | 76 | | 7 | cross_lingual_prompt | 👩 | | 77 | | 8 | speechify_2 | 👩 | | 78 | | 9 | diep-chi | 👨 | | 79 | | 10 | speechify_3 | 👩 | | 80 | | 11 | doremon | 👨 | | 81 | | 12 | speechify_4 | 👩 | | 82 | | 13 | jack-sparrow | 👨 | | 83 | | 14 | speechify_5 | 👩 | | 84 | | 15 | nguyen-ngoc-ngan | 👩 | | 85 | | 16 | speechify_6 | 👩 | | 86 | | 17 | nu-nhe-nhang | 👩 | | 87 | | 18 | speechify_7 | 👩 | | 88 | | 19 | quynh | 👩 | | 89 | | 20 | speechify_8 | 👩 | | 90 | | 21 | speechify_9 | 👩 | | 91 | | 22 | son-tung-mtp | 👨 | | 92 | | 23 | zero_shot_prompt | 👩 | | 93 | | 24 | speechify_1 | 👩 | | 94 | 95 |
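Each built-in voice listed above is simply a prompt audio file shipped in the `samples/` directory of this repository; the CLI resolves built-in voice names from that folder. As a quick local check, you can inspect one of the prompts with `soundfile`, which is already a project dependency. This is only an illustrative sketch and assumes you run it from the repository root:

```python
import soundfile as sf

# Built-in voice names map to prompt audio files under samples/.
audio, sample_rate = sf.read("samples/quynh.wav")
print(f"quynh: {len(audio) / sample_rate:.1f}s at {sample_rate} Hz")
```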
96 |
97 |
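Once the API server is running (see the CLI and API client sections below), the same built-in voices can also be listed and used over HTTP. The following is a minimal `requests`-based sketch mirroring the documented curl examples; it assumes the server is reachable at `http://localhost:8298`, that `requests` is installed, and that you inspect the `/v1/voices` payload yourself since its exact JSON schema is not documented here:

```python
import requests

BASE_URL = "http://localhost:8298"

# List the built-in voices exposed by the server.
voices = requests.get(f"{BASE_URL}/v1/voices", timeout=30)
voices.raise_for_status()
print(voices.json())  # inspect the payload to see the available voice names/IDs

# Synthesize speech via the OpenAI-compatible endpoint (same fields as the curl example below).
resp = requests.post(
    f"{BASE_URL}/v1/audio/speech",
    headers={"Authorization": "Bearer viet-tts"},
    json={"model": "tts-1", "input": "Xin chào Việt Nam.", "voice": "son-tung-mtp"},
    timeout=300,
)
resp.raise_for_status()
with open("speech.wav", "wb") as f:
    f.write(resp.content)
```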
98 | 
99 | ### Command Line Interface (CLI) 
100 | The VietTTS Command Line Interface (CLI) allows you to quickly generate speech directly from the terminal. Here's how to use it: 
101 | ```bash 
102 | # Usage 
103 | viettts --help 
104 | 
105 | # Start API Server 
106 | viettts server --host 0.0.0.0 --port 8298 
107 | 
108 | # List all built-in voices 
109 | viettts show-voices 
110 | 
111 | # Synthesize speech from text with built-in voices 
112 | viettts synthesis --text "Xin chào" --voice 0 --output test.wav 
113 | 
114 | # Clone voice from a local audio file 
115 | viettts synthesis --text "Xin chào" --voice Download/voice.wav --output cloned.wav 
116 | ``` 
117 | 
118 | ### API Client 
119 | #### Python (OpenAI Client) 
120 | You need to set environment variables for the OpenAI Client: 
121 | ```bash 
122 | # Set base_url and API key as environment variables 
123 | export OPENAI_BASE_URL=http://localhost:8298 
124 | export OPENAI_API_KEY=viet-tts # not used in the current version 
125 | ``` 
126 | To create speech from input text: 
127 | ```python 
128 | from pathlib import Path 
129 | from openai import OpenAI 
130 | 
131 | client = OpenAI() 
132 | 
133 | output_file_path = Path(__file__).parent / "speech.wav" 
134 | 
135 | with client.audio.speech.with_streaming_response.create( 
136 |     model='tts-1', 
137 |     voice='cdteam', 
138 |     input='Xin chào Việt Nam.', 
139 |     speed=1.0, 
140 |     response_format='wav' 
141 | ) as response: 
142 |     response.stream_to_file(output_file_path) 
143 | ``` 
144 | 
145 | #### CURL 
146 | ```bash 
147 | # Get all built-in voices 
148 | curl --location http://0.0.0.0:8298/v1/voices 
149 | 
150 | # OpenAI format (built-in voices) 
151 | curl http://localhost:8298/v1/audio/speech \ 
152 |   -H "Authorization: Bearer viet-tts" \ 
153 |   -H "Content-Type: application/json" \ 
154 |   -d '{ 
155 |     "model": "tts-1", 
156 |     "input": "Xin chào Việt Nam.", 
157 |     "voice": "son-tung-mtp" 
158 |   }' \ 
159 |   --output speech.wav 
160 | 
161 | # API with voice from local file 
162 | curl --location http://0.0.0.0:8298/v1/tts \ 
163 | --form 'text="xin chào"' \ 
164 | --form 'audio_file=@"/home/viettts/Downloads/voice.mp4"' \ 
165 | --output speech.wav 
166 | ``` 
167 | 
168 | #### Node 
169 | ```js 
170 | import fs from "fs"; 
171 | import path from "path"; 
172 | import OpenAI from "openai"; 
173 | 
174 | const openai = new OpenAI(); 
175 | 
176 | const speechFile = path.resolve("./speech.wav"); 
177 | 
178 | async function main() { 
179 |   const wav = await openai.audio.speech.create({ 
180 |     model: "tts-1", 
181 |     voice: "1", 
182 |     input: "Xin chào Việt Nam.", 
183 |   }); 
184 |   console.log(speechFile); 
185 |   const buffer = Buffer.from(await wav.arrayBuffer()); 
186 |   await fs.promises.writeFile(speechFile, buffer); 
187 | } 
188 | main(); 
189 | ``` 
190 | 
191 | ## 🙏 Acknowledgement 
192 | - 💡 Borrowed code from [Cosyvoice](https://github.com/FunAudioLLM/CosyVoice) 
193 | - 🎙️ VAD model from [silero-vad](https://github.com/snakers4/silero-vad) 
194 | - 📝 Text normalization with [Vinorm](https://github.com/v-nhandt21/Vinorm) 
195 | 
196 | ## 📜 License 
197 | The **VietTTS** source code is released under the **Apache 2.0 License**. Pre-trained models and audio samples are licensed under the **CC BY-NC License**, based on an in-the-wild dataset. We apologize for any inconvenience this may cause. 
198 | 
199 | ## ⚠️ Disclaimer 
200 | The content provided above is for academic purposes only and is intended to demonstrate technical capabilities. Some examples are sourced from the internet.
If any content infringes on your rights, please contact us to request its removal. 201 | 202 | ## 💬 Contact 203 | - Facebook: https://fb.com/sam.rngd 204 | - GitHub: https://github.com/dangvansam 205 | - Email: dangvansam98@gmail.com -------------------------------------------------------------------------------- /README_VN.md: -------------------------------------------------------------------------------- 1 |

2 | 3 |  

VietTTS: Công cụ chuyển văn bản thành giọng nói tiếng Việt mã nguồn mở

4 |

5 |

6 | 7 | 8 | 9 | 10 | 11 | 12 |
13 | 14 | 15 |

16 | 
17 | **VietTTS** là một bộ công cụ mã nguồn mở cung cấp mô hình TTS tiếng Việt mạnh mẽ, cho phép tổng hợp giọng nói tự nhiên và tạo giọng nói mới. **VietTTS** hỗ trợ nghiên cứu và ứng dụng trong công nghệ giọng nói tiếng Việt. 
18 | 
19 | ## ⭐ Tính năng nổi bật 
20 | - **TTS**: Tổng hợp giọng nói từ văn bản với bất kỳ giọng nào qua audio mẫu 
21 | - **OpenAI-API-compatible**: Tương thích với API Text to Speech OpenAI 
22 | 
23 | ## 🛠️ Cài đặt 
24 | VietTTS có thể được cài đặt qua trình cài đặt Python (chỉ hỗ trợ Linux, Windows và macOS sẽ có trong tương lai) hoặc Docker. 
25 | 
26 | ### Trình cài đặt Python (Python>=3.10) 
27 | 
28 | ```bash 
29 | git clone https://github.com/dangvansam/viet-tts.git 
30 | cd viet-tts 
31 | 
32 | # (Tùy chọn) Tạo môi trường Python với conda hoặc dùng virtualenv 
33 | conda create --name viettts python=3.10 
34 | conda activate viettts 
35 | 
36 | # Cài đặt 
37 | pip install -e . && pip cache purge 
38 | ``` 
39 | 
40 | ### Docker 
41 | 1. Cài đặt [Docker](https://docs.docker.com/get-docker/), [NVIDIA Driver](https://www.nvidia.com/download/index.aspx), [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html), và [CUDA](https://developer.nvidia.com/cuda-downloads). 
42 | 
43 | 2. Chạy các lệnh sau: 
44 | ```bash 
45 | git clone https://github.com/dangvansam/viet-tts.git 
46 | cd viet-tts 
47 | 
48 | # Xây dựng hình ảnh docker 
49 | docker compose build 
50 | 
51 | # Chạy bằng docker-compose - tạo server tại: http://localhost:8298 
52 | docker compose up -d 
53 | 
54 | # Chạy bằng docker run - tạo server tại: http://localhost:8298 
55 | docker run -itd --gpus all -p 8298:8298 -v ./pretrained-models:/app/pretrained-models --name viet-tts-service viet-tts:latest viettts server --host 0.0.0.0 --port 8298 
56 | ``` 
57 | 
58 | ## 🚀 Sử dụng 
59 | 
60 | ### Giọng nói tích hợp 🤠 
61 | Bạn có thể sử dụng các giọng nói có sẵn dưới đây để tổng hợp giọng nói. 
62 | 
63 |   Mở rộng 64 | 65 | | ID  | Giọng                   | Giới tính | Phát âm thanh                                   | 66 | |-----|--------------------------|-----------|-------------------------------------------------| 67 | | 1   | nsnd-le-chuc             | 👨        | | 68 | | 2   | speechify_10             | 👩        | | 69 | | 3   | atuan                    | 👨        |        | 70 | | 4   | speechify_11             | 👩        | | 71 | | 5   | cdteam                   | 👨        |       | 72 | | 6   | speechify_12             | 👩        | | 73 | | 7   | cross_lingual_prompt     | 👩        | | 74 | | 8   | speechify_2              | 👩        |   | 75 | | 9   | diep-chi                 | 👨        |      | 76 | | 10  | speechify_3              | 👩        |   | 77 | | 11  | doremon                  | 👨        |       | 78 | | 12  | speechify_4              | 👩        |   | 79 | | 13  | jack-sparrow             | 👨        | | 80 | | 14  | speechify_5              | 👩        |   | 81 | | 15  | nguyen-ngoc-ngan         | 👩        | | 82 | | 16  | speechify_6              | 👩        |   | 83 | | 17  | nu-nhe-nhang             | 👩        | | 84 | | 18  | speechify_7              | 👩        |   | 85 | | 19  | quynh                    | 👩        |         | 86 | | 20  | speechify_8              | 👩        |   | 87 | | 21  | speechify_9              | 👩        |   | 88 | | 22  | son-tung-mtp             | 👨        | | 89 | | 23  | zero_shot_prompt         | 👩        | | 90 | | 24  | speechify_1              | 👩        |   | 91 | 92 |  
93 | 94 |  
95 | 96 |
97 | 
98 | ### Thực thi với lệnh (CLI) 
99 | 
100 | Giao diện dòng lệnh VietTTS cho phép bạn tạo giọng nói từ terminal. Cách sử dụng: 
101 | 
102 | ```bash 
103 | # Hướng dẫn sử dụng 
104 | viettts --help 
105 | 
106 | # Khởi động API Server 
107 | viettts server --host 0.0.0.0 --port 8298 
108 | 
109 | # Xem tất cả các giọng nói có sẵn 
110 | viettts show-voices 
111 | 
112 | # Tổng hợp giọng nói từ văn bản với giọng có sẵn 
113 | viettts synthesis --text "Xin chào" --voice 0 --output test.wav 
114 | 
115 | # Sao chép giọng từ audio file bất kì 
116 | viettts synthesis --text "Xin chào" --voice Download/voice.wav --output cloned.wav 
117 | ``` 
118 | 
119 | ### API Client 
120 | #### Python (OpenAI Client) 
121 | Thiết lập biến môi trường cho OpenAI Client: 
122 | 
123 | ```bash 
124 | # Thiết lập base_url và API key như biến môi trường 
125 | export OPENAI_BASE_URL=http://localhost:8298 
126 | export OPENAI_API_KEY=viet-tts # không dùng trong phiên bản hiện tại 
127 | ``` 
128 | 
129 | Để tạo giọng nói từ văn bản đầu vào: 
130 | 
131 | ```python 
132 | from pathlib import Path 
133 | from openai import OpenAI 
134 | 
135 | 
136 | 
137 | client = OpenAI() 
138 | output_file_path = Path(__file__).parent / "speech.wav" 
139 | 
140 | with client.audio.speech.with_streaming_response.create( 
141 |     model='tts-1', 
142 |     voice='cdteam', 
143 |     input='Xin chào Việt Nam.', 
144 |     speed=1.0, 
145 |     response_format='wav' 
146 | ) as response: 
147 |     response.stream_to_file(output_file_path) 
148 | ``` 
149 | 
150 | #### CURL 
151 | ```bash 
152 | # Lấy danh sách giọng có sẵn 
153 | curl --location http://0.0.0.0:8298/v1/voices 
154 | 
155 | # OpenAI API format 
156 | curl http://localhost:8298/v1/audio/speech \ 
157 |   -H "Authorization: Bearer viet-tts" \ 
158 |   -H "Content-Type: application/json" \ 
159 |   -d '{ 
160 |     "model": "tts-1", 
161 |     "input": "Xin chào Việt Nam.", 
162 |     "voice": "son-tung-mtp" 
163 |   }' \ 
164 |   --output speech.wav 
165 | 
166 | # API với giọng từ file local 
167 | curl --location http://0.0.0.0:8298/v1/tts \ 
168 | --form 'text="xin chào"' \ 
169 | --form 'audio_file=@"/home/viettts/Downloads/voice.mp4"' \ 
170 | --output speech.wav 
171 | ``` 
172 | 
173 | #### Node 
174 | ```js 
175 | import fs from "fs"; 
176 | import path from "path"; 
177 | import OpenAI from "openai"; 
178 | 
179 | const openai = new OpenAI(); 
180 | const speechFile = path.resolve("./speech.wav"); 
181 | 
182 | async function main() { 
183 |   const wav = await openai.audio.speech.create({ 
184 |     model: "tts-1", 
185 |     voice: "1", 
186 |     input: "Xin chào Việt Nam.", 
187 |   }); 
188 |   console.log(speechFile); 
189 |   const buffer = Buffer.from(await wav.arrayBuffer()); 
190 |   await fs.promises.writeFile(speechFile, buffer); 
191 | } 
192 | main(); 
193 | ``` 
194 | 
195 | ## 🙏 Mã liên quan 
196 | - 💡 Sử dụng mã từ [Cosyvoice](https://github.com/FunAudioLLM/CosyVoice) 
197 | - 🎙️ Mô hình VAD từ [silero-vad](https://github.com/snakers4/silero-vad) 
198 | - 📝 Chuẩn hóa văn bản với [Vinorm](https://github.com/v-nhandt21/Vinorm) 
199 | 
200 | ## 📜 Giấy phép 
201 | Mã nguồn của **VietTTS** được cấp phép theo **Apache 2.0 License**. Mô hình và mẫu âm thanh huấn luyện được cấp phép theo **CC BY-NC License**, dựa trên tập dữ liệu từ internet. Xin lỗi nếu điều này gây bất tiện. 
202 | 
203 | ## ⚠️ Tuyên bố miễn trừ trách nhiệm 
204 | Nội dung trên chỉ phục vụ mục đích học thuật và nhằm trình bày khả năng kỹ thuật. Một số ví dụ lấy từ internet. Nếu nội dung vi phạm quyền của bạn, vui lòng liên hệ để được gỡ bỏ.
205 | 206 | ## 💬 Liên hệ 207 | - Facebook: https://fb.com/sam.rngd 208 | - GitHub: https://github.com/dangvansam 209 | - Email: dangvansam98@gmail.com -------------------------------------------------------------------------------- /assets/viet-tts-large.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dangvansam/viet-tts/c29b7535d4a81aefbbef175d0db26ac9df4bc028/assets/viet-tts-large.png -------------------------------------------------------------------------------- /assets/viet-tts-medium.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dangvansam/viet-tts/c29b7535d4a81aefbbef175d0db26ac9df4bc028/assets/viet-tts-medium.png -------------------------------------------------------------------------------- /assets/viet-tts.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dangvansam/viet-tts/c29b7535d4a81aefbbef175d0db26ac9df4bc028/assets/viet-tts.png -------------------------------------------------------------------------------- /docker-compose.yaml: -------------------------------------------------------------------------------- 1 | services: 2 | api: 3 | image: viet-tts:latest 4 | build: . 5 | restart: always 6 | container_name: viet-tts-service 7 | ports: 8 | - 8298:8298 9 | deploy: 10 | resources: 11 | reservations: 12 | devices: 13 | - driver: nvidia 14 | count: 1 15 | capabilities: [gpu] 16 | volumes: 17 | - ./pretrained-models:/app/pretrained-models 18 | command: viettts server --host 0.0.0.0 --port 8298 -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "viettts" 3 | version = "0.1.0" 4 | description = "VietTTS: An Open-Source Vietnamese Text to Speech" 5 | authors = ["dangvansam "] 6 | readme = "README.md" 7 | 8 | [tool.poetry.dependencies] 9 | python = "^3.10" 10 | conformer = "0.3.2" 11 | diffusers = "0.27.2" 12 | gradio = "4.32.2" 13 | hydra-core = "1.3.2" 14 | hyperpyyaml = "1.2.2" 15 | librosa = "0.10.2" 16 | omegaconf = "2.3.0" 17 | onnx = "1.16.0" 18 | onnxruntime-gpu = "1.16.0" 19 | protobuf = "4.25" 20 | pydantic = "2.7.0" 21 | soundfile = "0.12.1" 22 | torch = "2.0.1" 23 | torchaudio = "2.0.2" 24 | uvicorn = "0.30.0" 25 | wget = "3.2" 26 | fastapi = "0.111.0" 27 | fastapi-cli = "0.0.4" 28 | loguru = "0.7.2" 29 | vinorm = "^2.0.7" 30 | huggingface-hub = "0.24.7" 31 | click = "^8.1.7" 32 | gunicorn = "^23.0.0" 33 | silero-vad = "^5.1.2" 34 | tiktoken = "^0.8.0" 35 | openai-whisper = "^20240930" 36 | 37 | [tool.poetry.scripts] 38 | viettts = "viettts.cli:cli" 39 | 40 | [build-system] 41 | requires = ["poetry-core"] 42 | build-backend = "poetry.core.masonry.api" 43 | 44 | [tool.setuptools] 45 | packages = ["viettts"] -------------------------------------------------------------------------------- /samples/cdteam.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dangvansam/viet-tts/c29b7535d4a81aefbbef175d0db26ac9df4bc028/samples/cdteam.wav -------------------------------------------------------------------------------- /samples/cross_lingual_prompt.wav: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/dangvansam/viet-tts/c29b7535d4a81aefbbef175d0db26ac9df4bc028/samples/cross_lingual_prompt.wav -------------------------------------------------------------------------------- /samples/diep-chi.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dangvansam/viet-tts/c29b7535d4a81aefbbef175d0db26ac9df4bc028/samples/diep-chi.wav -------------------------------------------------------------------------------- /samples/doremon.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dangvansam/viet-tts/c29b7535d4a81aefbbef175d0db26ac9df4bc028/samples/doremon.mp3 -------------------------------------------------------------------------------- /samples/jack-sparrow.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dangvansam/viet-tts/c29b7535d4a81aefbbef175d0db26ac9df4bc028/samples/jack-sparrow.mp3 -------------------------------------------------------------------------------- /samples/nguyen-ngoc-ngan.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dangvansam/viet-tts/c29b7535d4a81aefbbef175d0db26ac9df4bc028/samples/nguyen-ngoc-ngan.wav -------------------------------------------------------------------------------- /samples/nsnd-le-chuc.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dangvansam/viet-tts/c29b7535d4a81aefbbef175d0db26ac9df4bc028/samples/nsnd-le-chuc.mp3 -------------------------------------------------------------------------------- /samples/nu-nhe-nhang.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dangvansam/viet-tts/c29b7535d4a81aefbbef175d0db26ac9df4bc028/samples/nu-nhe-nhang.wav -------------------------------------------------------------------------------- /samples/quynh.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dangvansam/viet-tts/c29b7535d4a81aefbbef175d0db26ac9df4bc028/samples/quynh.wav -------------------------------------------------------------------------------- /samples/son-tung-mtp.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dangvansam/viet-tts/c29b7535d4a81aefbbef175d0db26ac9df4bc028/samples/son-tung-mtp.wav -------------------------------------------------------------------------------- /samples/speechify_1.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dangvansam/viet-tts/c29b7535d4a81aefbbef175d0db26ac9df4bc028/samples/speechify_1.wav -------------------------------------------------------------------------------- /samples/speechify_10.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dangvansam/viet-tts/c29b7535d4a81aefbbef175d0db26ac9df4bc028/samples/speechify_10.wav -------------------------------------------------------------------------------- /samples/speechify_11.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dangvansam/viet-tts/c29b7535d4a81aefbbef175d0db26ac9df4bc028/samples/speechify_11.wav 
-------------------------------------------------------------------------------- /samples/speechify_12.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dangvansam/viet-tts/c29b7535d4a81aefbbef175d0db26ac9df4bc028/samples/speechify_12.wav -------------------------------------------------------------------------------- /samples/speechify_2.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dangvansam/viet-tts/c29b7535d4a81aefbbef175d0db26ac9df4bc028/samples/speechify_2.wav -------------------------------------------------------------------------------- /samples/speechify_3.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dangvansam/viet-tts/c29b7535d4a81aefbbef175d0db26ac9df4bc028/samples/speechify_3.wav -------------------------------------------------------------------------------- /samples/speechify_4.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dangvansam/viet-tts/c29b7535d4a81aefbbef175d0db26ac9df4bc028/samples/speechify_4.wav -------------------------------------------------------------------------------- /samples/speechify_5.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dangvansam/viet-tts/c29b7535d4a81aefbbef175d0db26ac9df4bc028/samples/speechify_5.wav -------------------------------------------------------------------------------- /samples/speechify_6.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dangvansam/viet-tts/c29b7535d4a81aefbbef175d0db26ac9df4bc028/samples/speechify_6.wav -------------------------------------------------------------------------------- /samples/speechify_7.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dangvansam/viet-tts/c29b7535d4a81aefbbef175d0db26ac9df4bc028/samples/speechify_7.wav -------------------------------------------------------------------------------- /samples/speechify_8.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dangvansam/viet-tts/c29b7535d4a81aefbbef175d0db26ac9df4bc028/samples/speechify_8.wav -------------------------------------------------------------------------------- /samples/speechify_9.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dangvansam/viet-tts/c29b7535d4a81aefbbef175d0db26ac9df4bc028/samples/speechify_9.wav -------------------------------------------------------------------------------- /samples/zero_shot_prompt.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dangvansam/viet-tts/c29b7535d4a81aefbbef175d0db26ac9df4bc028/samples/zero_shot_prompt.wav -------------------------------------------------------------------------------- /viettts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dangvansam/viet-tts/c29b7535d4a81aefbbef175d0db26ac9df4bc028/viettts/__init__.py -------------------------------------------------------------------------------- /viettts/cli.py: 
-------------------------------------------------------------------------------- 
 1 | import os 
 2 | import sys 
 3 | import time 
 4 | import click 
 5 | import subprocess 
 6 | from loguru import logger 
 7 | from rich.table import Table 
 8 | from rich.console import Console 
 9 | from viettts.tts import TTS 
10 | from viettts.utils.file_utils import load_prompt_speech_from_file, load_voices 
11 | 
12 | 
13 | AUDIO_DIR = 'samples' 
14 | MODEL_DIR = 'pretrained-models' 
15 | 
16 | @click.command('server') 
17 | @click.option('-h', '--host', type=str, default='0.0.0.0', help="The host address to bind the server to. Default is '0.0.0.0'.") 
18 | @click.option('-p', '--port', type=int, default=8298, help="The port number to bind the server to. Default is 8298.") 
19 | @click.option('-w', '--workers', type=int, default=1, help="The number of worker processes to handle requests. Default is 1.") 
20 | def start_server(host: str, port: int, workers: int): 
21 |     """Start the API server (OpenAI TTS API compatible). 
22 | 
23 |     Usage: viettts server --host 0.0.0.0 --port 8298 -w 4 
24 |     """ 
25 |     logger.info("Starting server") 
26 |     cmd = f'gunicorn viettts.server:app \ 
27 |         -k uvicorn.workers.UvicornWorker \ 
28 |         --bind {host}:{port} \ 
29 |         --workers {workers} \ 
30 |         --max-requests 1000 \ 
31 |         --max-requests-jitter 50 \ 
32 |         --timeout 300 \ 
33 |         --keep-alive 75 \ 
34 |         --graceful-timeout 60' 
35 | 
36 |     subprocess.call(cmd, shell=True, stdout=sys.stdout) 
37 | 
38 | 
39 | @click.command('synthesis') 
40 | @click.option('-t', "--text", type=str, required=True, help="The input text to synthesize into speech.") 
41 | @click.option('-v', "--voice", type=str, default='1', help="The voice ID or file path to clone the voice from. Default is '1'.") 
42 | @click.option('-s', "--speed", type=float, default=1, help="The speed multiplier for the speech. Default is 1 (normal speed).") 
43 | @click.option('-o', "--output", type=str, default='output.wav', help="The file path to save the synthesized audio. Default is 'output.wav'.") 
44 | def synthesis(text: str, voice: str, speed: float, output: str): 
45 |     """Synthesize audio from text and save it to a file. 
46 | 
47 |     Usage: viettts synthesis --text 'Xin chào VietTTS' --voice nu-nhe-nhang --speed 1.2 --output test_nu-nhe-nhang.wav 
48 |     """ 
49 |     logger.info("Starting synthesis") 
50 |     st = time.perf_counter() 
51 |     if not text: 
52 |         logger.error('text must not be empty') 
53 |         return 
54 | 
55 |     if speed > 2 or speed < 0.5: 
56 |         logger.error('speed must be in range 0.5-2.0') 
57 |         return 
58 | 
59 |     if not os.path.exists(voice): 
60 |         voice_map = load_voices(AUDIO_DIR) 
61 |         if voice.isdigit(): 
62 |             voice = list(voice_map.values())[int(voice)] 
63 |         else: 
64 |             voice = voice_map.get(voice) 
65 | 
66 |     if not voice or not os.path.exists(voice): 
67 |         logger.error('voice is not available. Use --voice or run `viettts show-voices` to get available voices.') 
68 |         return 
69 | 
70 |     logger.info('Loading model') 
71 |     tts = TTS(model_dir=MODEL_DIR) 
72 | 
73 |     logger.info('Loading voice') 
74 |     voice = load_prompt_speech_from_file(voice) 
75 | 
76 |     logger.info('Processing') 
77 |     tts.tts_to_file(text, voice, speed, output) 
78 | 
79 |     et = time.perf_counter() 
80 |     logger.success(f"Saved to: {output} [time cost={et-st:.2f}s]") 
81 | 
82 | 
83 | @click.command('show-voices') 
84 | def show_voice(): 
85 |     """Print all available voices.
86 | 87 | Usage: viettts show-voices 88 | """ 89 | voice_map = load_voices(AUDIO_DIR) 90 | console = Console() 91 | table = Table(show_header=True, header_style="green", show_lines=False) 92 | table.add_column("Voice ID", width=10) 93 | table.add_column("Voice Name", width=30) 94 | table.add_column("File", justify="left") 95 | 96 | for i, (voice_name, voice_path) in enumerate(voice_map.items()): 97 | table.add_row(str(i+1), voice_name, voice_path) 98 | 99 | console.print(table) 100 | 101 | 102 | @click.group() 103 | def cli(): 104 | """ 105 | VietTTS CLI v0.1.0 106 | 107 | Vietnamese Text To Speech and Voice Clone 108 | License: Apache 2.0 - Author: 109 | """ 110 | pass 111 | 112 | cli.add_command(start_server) 113 | cli.add_command(synthesis) 114 | cli.add_command(show_voice) -------------------------------------------------------------------------------- /viettts/flow/flow.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import random 3 | from typing import Dict, Optional 4 | import torch 5 | import torch.nn as nn 6 | from torch.nn import functional as F 7 | from omegaconf import DictConfig 8 | from viettts.utils.mask import make_pad_mask 9 | 10 | 11 | class MaskedDiffWithXvec(torch.nn.Module): 12 | def __init__(self, 13 | input_size: int = 512, 14 | output_size: int = 80, 15 | spk_embed_dim: int = 192, 16 | output_type: str = "mel", 17 | vocab_size: int = 4096, 18 | input_frame_rate: int = 50, 19 | only_mask_loss: bool = True, 20 | encoder: torch.nn.Module = None, 21 | length_regulator: torch.nn.Module = None, 22 | decoder: torch.nn.Module = None, 23 | decoder_conf: Dict = { 24 | 'in_channels': 240, 25 | 'out_channel': 80, 26 | 'spk_emb_dim': 80, 27 | 'n_spks': 1, 28 | 'cfm_params': DictConfig({ 29 | 'sigma_min': 1e-06, 30 | 'solver': 'euler', 31 | 't_scheduler': 'cosine', 32 | 'training_cfg_rate': 0.2, 33 | 'inference_cfg_rate': 0.7, 34 | 'reg_loss_type': 'l1' 35 | }), 36 | 'decoder_params': { 37 | 'channels': [256, 256], 38 | 'dropout': 0.0, 39 | 'attention_head_dim': 64, 40 | 'n_blocks': 4, 41 | 'num_mid_blocks': 12, 42 | 'num_heads': 8, 43 | 'act_fn': 'gelu' 44 | } 45 | }, 46 | mel_feat_conf: Dict = { 47 | 'n_fft': 1024, 48 | 'num_mels': 80, 49 | 'sampling_rate': 22050, 50 | 'hop_size': 256, 51 | 'win_size': 1024, 52 | 'fmin': 0, 53 | 'fmax': 8000 54 | } 55 | ): 56 | super().__init__() 57 | self.input_size = input_size 58 | self.output_size = output_size 59 | self.decoder_conf = decoder_conf 60 | self.mel_feat_conf = mel_feat_conf 61 | self.vocab_size = vocab_size 62 | self.output_type = output_type 63 | self.input_frame_rate = input_frame_rate 64 | logging.info(f"input frame rate={self.input_frame_rate}") 65 | self.input_embedding = nn.Embedding(vocab_size, input_size) 66 | self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size) 67 | self.encoder = encoder 68 | self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size) 69 | self.decoder = decoder 70 | self.length_regulator = length_regulator 71 | self.only_mask_loss = only_mask_loss 72 | 73 | def forward( 74 | self, 75 | batch: dict, 76 | device: torch.device, 77 | ) -> Dict[str, Optional[torch.Tensor]]: 78 | token = batch['speech_token'].to(device) 79 | token_len = batch['speech_token_len'].to(device) 80 | feat = batch['speech_feat'].to(device) 81 | feat_len = batch['speech_feat_len'].to(device) 82 | embedding = batch['embedding'].to(device) 83 | 84 | # xvec projection 85 | embedding = F.normalize(embedding, dim=1) 86 | embedding = 
self.spk_embed_affine_layer(embedding) 87 | 88 | # concat text and prompt_text 89 | mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device) 90 | token = self.input_embedding(torch.clamp(token, min=0)) * mask 91 | 92 | # text encode 93 | h, h_lengths = self.encoder(token, token_len) 94 | h = self.encoder_proj(h) 95 | h, h_lengths = self.length_regulator(h, feat_len) 96 | 97 | # get conditions 98 | conds = torch.zeros(feat.shape, device=token.device) 99 | for i, j in enumerate(feat_len): 100 | if random.random() < 0.5: 101 | continue 102 | index = random.randint(0, int(0.3 * j)) 103 | conds[i, :index] = feat[i, :index] 104 | conds = conds.transpose(1, 2) 105 | 106 | mask = (~make_pad_mask(feat_len)).to(h) 107 | feat = F.interpolate(feat.unsqueeze(dim=1), size=h.shape[1:], mode="nearest").squeeze(dim=1) 108 | loss, _ = self.decoder.compute_loss( 109 | feat.transpose(1, 2).contiguous(), 110 | mask.unsqueeze(1), 111 | h.transpose(1, 2).contiguous(), 112 | embedding, 113 | cond=conds 114 | ) 115 | return {'loss': loss} 116 | 117 | @torch.inference_mode() 118 | def inference(self, 119 | token, 120 | token_len, 121 | prompt_token, 122 | prompt_token_len, 123 | prompt_feat, 124 | prompt_feat_len, 125 | embedding): 126 | assert token.shape[0] == 1 127 | # xvec projection 128 | embedding = F.normalize(embedding, dim=1) 129 | embedding = self.spk_embed_affine_layer(embedding) 130 | 131 | # concat text and prompt_text 132 | token_len1, token_len2 = prompt_token.shape[1], token.shape[1] 133 | token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len 134 | mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding) 135 | token = self.input_embedding(torch.clamp(token, min=0)) * mask 136 | 137 | # text encode 138 | h, h_lengths = self.encoder(token, token_len) 139 | h = self.encoder_proj(h) 140 | mel_len1, mel_len2 = prompt_feat.shape[1], int(token_len2 / self.input_frame_rate * 22050 / 256) 141 | h, h_lengths = self.length_regulator.inference(h[:, :token_len1], h[:, token_len1:], mel_len1, mel_len2, self.input_frame_rate) 142 | 143 | # get conditions 144 | conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device) 145 | conds[:, :mel_len1] = prompt_feat 146 | conds = conds.transpose(1, 2) 147 | 148 | mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h) 149 | feat = self.decoder( 150 | mu=h.transpose(1, 2).contiguous(), 151 | mask=mask.unsqueeze(1), 152 | spks=embedding, 153 | cond=conds, 154 | n_timesteps=10 155 | ) 156 | feat = feat[:, :, mel_len1:] 157 | assert feat.shape[2] == mel_len2 158 | return feat 159 | -------------------------------------------------------------------------------- /viettts/flow/flow_matching.py: -------------------------------------------------------------------------------- 1 | from abc import ABC 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | from viettts.flow.decoder import Decoder 7 | 8 | 9 | class BASECFM(torch.nn.Module, ABC): 10 | def __init__( 11 | self, 12 | n_feats, 13 | cfm_params, 14 | n_spks=1, 15 | spk_emb_dim=128, 16 | ): 17 | super().__init__() 18 | self.n_feats = n_feats 19 | self.n_spks = n_spks 20 | self.spk_emb_dim = spk_emb_dim 21 | self.solver = cfm_params.solver 22 | if hasattr(cfm_params, "sigma_min"): 23 | self.sigma_min = cfm_params.sigma_min 24 | else: 25 | self.sigma_min = 1e-4 26 | 27 | self.estimator = None 28 | 29 | @torch.inference_mode() 30 | def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None): 31 | """Forward 
diffusion 32 | 33 | Args: 34 | mu (torch.Tensor): output of encoder 35 | shape: (batch_size, n_feats, mel_timesteps) 36 | mask (torch.Tensor): output_mask 37 | shape: (batch_size, 1, mel_timesteps) 38 | n_timesteps (int): number of diffusion steps 39 | temperature (float, optional): temperature for scaling noise. Defaults to 1.0. 40 | spks (torch.Tensor, optional): speaker ids. Defaults to None. 41 | shape: (batch_size, spk_emb_dim) 42 | cond: Not used but kept for future purposes 43 | 44 | Returns: 45 | sample: generated mel-spectrogram 46 | shape: (batch_size, n_feats, mel_timesteps) 47 | """ 48 | z = torch.randn_like(mu) * temperature 49 | t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device) 50 | return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond) 51 | 52 | def solve_euler(self, x, t_span, mu, mask, spks, cond): 53 | """ 54 | Fixed euler solver for ODEs. 55 | Args: 56 | x (torch.Tensor): random noise 57 | t_span (torch.Tensor): n_timesteps interpolated 58 | shape: (n_timesteps + 1,) 59 | mu (torch.Tensor): output of encoder 60 | shape: (batch_size, n_feats, mel_timesteps) 61 | mask (torch.Tensor): output_mask 62 | shape: (batch_size, 1, mel_timesteps) 63 | spks (torch.Tensor, optional): speaker ids. Defaults to None. 64 | shape: (batch_size, spk_emb_dim) 65 | cond: Not used but kept for future purposes 66 | """ 67 | t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0] 68 | 69 | # I am storing this because I can later plot it by putting a debugger here and saving it to a file 70 | # Or in future might add like a return_all_steps flag 71 | sol = [] 72 | 73 | for step in range(1, len(t_span)): 74 | dphi_dt = self.estimator(x, mask, mu, t, spks, cond) 75 | 76 | x = x + dt * dphi_dt 77 | t = t + dt 78 | sol.append(x) 79 | if step < len(t_span) - 1: 80 | dt = t_span[step + 1] - t 81 | 82 | return sol[-1] 83 | 84 | def compute_loss(self, x1, mask, mu, spks=None, cond=None): 85 | """Computes diffusion loss 86 | 87 | Args: 88 | x1 (torch.Tensor): Target 89 | shape: (batch_size, n_feats, mel_timesteps) 90 | mask (torch.Tensor): target mask 91 | shape: (batch_size, 1, mel_timesteps) 92 | mu (torch.Tensor): output of encoder 93 | shape: (batch_size, n_feats, mel_timesteps) 94 | spks (torch.Tensor, optional): speaker embedding. Defaults to None. 
95 | shape: (batch_size, spk_emb_dim) 96 | 97 | Returns: 98 | loss: conditional flow matching loss 99 | y: conditional flow 100 | shape: (batch_size, n_feats, mel_timesteps) 101 | """ 102 | b, _, t = mu.shape 103 | 104 | # random timestep 105 | t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype) 106 | # sample noise p(x_0) 107 | z = torch.randn_like(x1) 108 | 109 | y = (1 - (1 - self.sigma_min) * t) * z + t * x1 110 | u = x1 - (1 - self.sigma_min) * z 111 | 112 | loss = F.mse_loss(self.estimator(y, mask, mu, t.squeeze(), spks), u, reduction="sum") / ( 113 | torch.sum(mask) * u.shape[1] 114 | ) 115 | return loss, y 116 | 117 | 118 | class CFM(BASECFM): 119 | def __init__(self, in_channels, out_channel, cfm_params, decoder_params, n_spks=1, spk_emb_dim=64): 120 | super().__init__( 121 | n_feats=in_channels, 122 | cfm_params=cfm_params, 123 | n_spks=n_spks, 124 | spk_emb_dim=spk_emb_dim, 125 | ) 126 | 127 | in_channels = in_channels + (spk_emb_dim if n_spks > 1 else 0) 128 | # Just change the architecture of the estimator here 129 | self.estimator = Decoder(in_channels=in_channels, out_channels=out_channel, **decoder_params) 130 | 131 | 132 | class ConditionalCFM(BASECFM): 133 | def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None): 134 | super().__init__( 135 | n_feats=in_channels, 136 | cfm_params=cfm_params, 137 | n_spks=n_spks, 138 | spk_emb_dim=spk_emb_dim, 139 | ) 140 | self.t_scheduler = cfm_params.t_scheduler 141 | self.training_cfg_rate = cfm_params.training_cfg_rate 142 | self.inference_cfg_rate = cfm_params.inference_cfg_rate 143 | in_channels = in_channels + (spk_emb_dim if n_spks > 0 else 0) 144 | # Just change the architecture of the estimator here 145 | self.estimator = estimator 146 | 147 | @torch.inference_mode() 148 | def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None): 149 | """Forward diffusion 150 | 151 | Args: 152 | mu (torch.Tensor): output of encoder 153 | shape: (batch_size, n_feats, mel_timesteps) 154 | mask (torch.Tensor): output_mask 155 | shape: (batch_size, 1, mel_timesteps) 156 | n_timesteps (int): number of diffusion steps 157 | temperature (float, optional): temperature for scaling noise. Defaults to 1.0. 158 | spks (torch.Tensor, optional): speaker ids. Defaults to None. 159 | shape: (batch_size, spk_emb_dim) 160 | cond: Not used but kept for future purposes 161 | 162 | Returns: 163 | sample: generated mel-spectrogram 164 | shape: (batch_size, n_feats, mel_timesteps) 165 | """ 166 | z = torch.randn_like(mu) * temperature 167 | t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype) 168 | if self.t_scheduler == 'cosine': 169 | t_span = 1 - torch.cos(t_span * 0.5 * torch.pi) 170 | return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond) 171 | 172 | def solve_euler(self, x, t_span, mu, mask, spks, cond): 173 | """ 174 | Fixed euler solver for ODEs. 175 | Args: 176 | x (torch.Tensor): random noise 177 | t_span (torch.Tensor): n_timesteps interpolated 178 | shape: (n_timesteps + 1,) 179 | mu (torch.Tensor): output of encoder 180 | shape: (batch_size, n_feats, mel_timesteps) 181 | mask (torch.Tensor): output_mask 182 | shape: (batch_size, 1, mel_timesteps) 183 | spks (torch.Tensor, optional): speaker ids. Defaults to None. 
184 | shape: (batch_size, spk_emb_dim) 185 | cond: Not used but kept for future purposes 186 | """ 187 | t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0] 188 | t = t.unsqueeze(dim=0) 189 | 190 | # I am storing this because I can later plot it by putting a debugger here and saving it to a file 191 | # Or in future might add like a return_all_steps flag 192 | sol = [] 193 | 194 | for step in range(1, len(t_span)): 195 | dphi_dt = self.forward_estimator(x, mask, mu, t, spks, cond) 196 | # Classifier-Free Guidance inference introduced in VoiceBox 197 | if self.inference_cfg_rate > 0: 198 | cfg_dphi_dt = self.forward_estimator( 199 | x, mask, 200 | torch.zeros_like(mu), t, 201 | torch.zeros_like(spks) if spks is not None else None, 202 | torch.zeros_like(cond) 203 | ) 204 | dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - 205 | self.inference_cfg_rate * cfg_dphi_dt) 206 | x = x + dt * dphi_dt 207 | t = t + dt 208 | sol.append(x) 209 | if step < len(t_span) - 1: 210 | dt = t_span[step + 1] - t 211 | 212 | return sol[-1] 213 | 214 | def forward_estimator(self, x, mask, mu, t, spks, cond): 215 | if isinstance(self.estimator, torch.nn.Module): 216 | return self.estimator.forward(x, mask, mu, t, spks, cond) 217 | else: 218 | ort_inputs = { 219 | 'x': x.cpu().numpy(), 220 | 'mask': mask.cpu().numpy(), 221 | 'mu': mu.cpu().numpy(), 222 | 't': t.cpu().numpy(), 223 | 'spks': spks.cpu().numpy(), 224 | 'cond': cond.cpu().numpy() 225 | } 226 | output = self.estimator.run(None, ort_inputs)[0] 227 | return torch.tensor(output, dtype=x.dtype, device=x.device) 228 | 229 | def compute_loss(self, x1, mask, mu, spks=None, cond=None): 230 | """Computes diffusion loss 231 | 232 | Args: 233 | x1 (torch.Tensor): Target 234 | shape: (batch_size, n_feats, mel_timesteps) 235 | mask (torch.Tensor): target mask 236 | shape: (batch_size, 1, mel_timesteps) 237 | mu (torch.Tensor): output of encoder 238 | shape: (batch_size, n_feats, mel_timesteps) 239 | spks (torch.Tensor, optional): speaker embedding. Defaults to None. 
240 | shape: (batch_size, spk_emb_dim) 241 | 242 | Returns: 243 | loss: conditional flow matching loss 244 | y: conditional flow 245 | shape: (batch_size, n_feats, mel_timesteps) 246 | """ 247 | b, _, t = mu.shape 248 | 249 | # random timestep 250 | t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype) 251 | if self.t_scheduler == 'cosine': 252 | t = 1 - torch.cos(t * 0.5 * torch.pi) 253 | # sample noise p(x_0) 254 | z = torch.randn_like(x1) 255 | 256 | y = (1 - (1 - self.sigma_min) * t) * z + t * x1 257 | u = x1 - (1 - self.sigma_min) * z 258 | 259 | # during training, we randomly drop condition to trade off mode coverage and sample fidelity 260 | if self.training_cfg_rate > 0: 261 | cfg_mask = torch.rand(b, device=x1.device) > self.training_cfg_rate 262 | mu = mu * cfg_mask.view(-1, 1, 1) 263 | spks = spks * cfg_mask.view(-1, 1) 264 | cond = cond * cfg_mask.view(-1, 1, 1) 265 | 266 | pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond) 267 | loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1]) 268 | return loss, y 269 | -------------------------------------------------------------------------------- /viettts/flow/length_regulator.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | import torch.nn as nn 3 | import torch 4 | from torch.nn import functional as F 5 | from viettts.utils.mask import make_pad_mask 6 | 7 | 8 | class InterpolateRegulator(nn.Module): 9 | def __init__( 10 | self, 11 | channels: int, 12 | sampling_ratios: Tuple, 13 | out_channels: int = None, 14 | groups: int = 1, 15 | ): 16 | super().__init__() 17 | self.sampling_ratios = sampling_ratios 18 | out_channels = out_channels or channels 19 | model = nn.ModuleList([]) 20 | if len(sampling_ratios) > 0: 21 | for _ in sampling_ratios: 22 | module = nn.Conv1d(channels, channels, 3, 1, 1) 23 | norm = nn.GroupNorm(groups, channels) 24 | act = nn.Mish() 25 | model.extend([module, norm, act]) 26 | model.append( 27 | nn.Conv1d(channels, out_channels, 1, 1) 28 | ) 29 | self.model = nn.Sequential(*model) 30 | 31 | def forward(self, x, ylens=None): 32 | # x in (B, T, D) 33 | mask = (~make_pad_mask(ylens)).to(x).unsqueeze(-1) 34 | x = F.interpolate(x.transpose(1, 2).contiguous(), size=ylens.max(), mode='linear') 35 | out = self.model(x).transpose(1, 2).contiguous() 36 | olens = ylens 37 | return out * mask, olens 38 | 39 | def inference(self, x1, x2, mel_len1, mel_len2, input_frame_rate=50): 40 | # in inference mode, interpolate the prompt tokens and the generated tokens (head/mid/tail) separately, so we get a clear separation point in the mel 41 | # x in (B, T, D) 42 | if x2.shape[1] > 40: 43 | x2_head = F.interpolate(x2[:, :20].transpose(1, 2).contiguous(), size=int(20 / input_frame_rate * 22050 / 256), mode='linear') 44 | x2_mid = F.interpolate(x2[:, 20:-20].transpose(1, 2).contiguous(), size=mel_len2 - int(20 / input_frame_rate * 22050 / 256) * 2, 45 | mode='linear') 46 | x2_tail = F.interpolate(x2[:, -20:].transpose(1, 2).contiguous(), size=int(20 / input_frame_rate * 22050 / 256), mode='linear') 47 | x2 = torch.concat([x2_head, x2_mid, x2_tail], dim=2) 48 | else: 49 | x2 = F.interpolate(x2.transpose(1, 2).contiguous(), size=mel_len2, mode='linear') 50 | if x1.shape[1] != 0: 51 | x1 = F.interpolate(x1.transpose(1, 2).contiguous(), size=mel_len1, mode='linear') 52 | x = torch.concat([x1, x2], dim=2) 53 | else: 54 | x = x2 55 | out = self.model(x).transpose(1, 2).contiguous() 56 | return out, mel_len1 + mel_len2 57 | 
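A minimal usage sketch (not part of the repository) of how this interpolation-based length regulator maps 50 Hz encoder frames onto the 22.05 kHz / 256-hop mel grid, mirroring the mel-length formula used in flow.py's inference path; the channel count (80), the sampling_ratios value, and the frame counts below are illustrative assumptions, not values taken from the project config:

import torch
from viettts.flow.length_regulator import InterpolateRegulator

# assumed hyper-parameters, for illustration only
regulator = InterpolateRegulator(channels=80, sampling_ratios=[1, 1, 1, 1])

token_len = 100                                # 100 encoder frames at an assumed 50 Hz (~2 s of speech)
mel_len = int(token_len / 50 * 22050 / 256)    # ~172 mel frames, same formula as flow.py inference
x = torch.randn(1, token_len, 80)              # (B, T, D) encoder output
ylens = torch.tensor([mel_len])                # target mel lengths per batch item

out, olens = regulator(x, ylens)               # out: (1, mel_len, 80), upsampled to the mel frame rate
print(out.shape, olens)

The 22050 / 256 factor is simply samples-per-second divided by hop size, i.e. the mel frame rate the decoder and HiFi-GAN stages expect.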
-------------------------------------------------------------------------------- /viettts/frontend.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torchaudio 4 | import whisper 5 | import onnxruntime 6 | import numpy as np 7 | import torchaudio.compliance.kaldi as kaldi 8 | from typing import Callable, List, Union 9 | from functools import partial 10 | from loguru import logger 11 | 12 | from viettts.utils.frontend_utils import split_text, normalize_text, mel_spectrogram 13 | from viettts.tokenizer.tokenizer import get_tokenizer 14 | 15 | class TTSFrontEnd: 16 | def __init__( 17 | self, 18 | speech_embedding_model: str, 19 | speech_tokenizer_model: str, 20 | ): 21 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 22 | self.tokenizer = get_tokenizer() 23 | option = onnxruntime.SessionOptions() 24 | option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL 25 | option.intra_op_num_threads = 1 26 | self.speech_embedding_session = onnxruntime.InferenceSession( 27 | speech_embedding_model, 28 | sess_options=option, 29 | providers=["CPUExecutionProvider"] 30 | ) 31 | self.speech_tokenizer_session = onnxruntime.InferenceSession( 32 | speech_tokenizer_model, 33 | sess_options=option, 34 | providers=["CUDAExecutionProvider" if torch.cuda.is_available() else "CPUExecutionProvider"] 35 | ) 36 | self.spk2info = {} 37 | 38 | def _extract_text_token(self, text: str): 39 | text_token = self.tokenizer.encode(text, allowed_special='all') 40 | text_token = torch.tensor([text_token], dtype=torch.int32).to(self.device) 41 | text_token_len = torch.tensor([text_token.shape[1]], dtype=torch.int32).to(self.device) 42 | return text_token, text_token_len 43 | 44 | def _extract_speech_token(self, speech: torch.Tensor): 45 | if speech.shape[1] / 16000 > 30: 46 | speech = speech[:, :int(16000 * 30)] 47 | feat = whisper.log_mel_spectrogram(speech, n_mels=128) 48 | speech_token = self.speech_tokenizer_session.run( 49 | None, 50 | {self.speech_tokenizer_session.get_inputs()[0].name: feat.detach().cpu().numpy(), 51 | self.speech_tokenizer_session.get_inputs()[1].name: np.array([feat.shape[2]], dtype=np.int32)} 52 | )[0].flatten().tolist() 53 | speech_token = torch.tensor([speech_token], dtype=torch.int32).to(self.device) 54 | speech_token_len = torch.tensor([speech_token.shape[1]], dtype=torch.int32).to(self.device) 55 | return speech_token, speech_token_len 56 | 57 | def _extract_spk_embedding(self, speech: torch.Tensor): 58 | feat = kaldi.fbank( 59 | waveform=speech, 60 | num_mel_bins=80, 61 | dither=0, 62 | sample_frequency=16000 63 | ) 64 | feat = feat - feat.mean(dim=0, keepdim=True) 65 | embedding = self.speech_embedding_session.run( 66 | None, 67 | {self.speech_embedding_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()} 68 | )[0].flatten().tolist() 69 | embedding = torch.tensor([embedding]).to(self.device) 70 | return embedding 71 | 72 | def _extract_speech_feat(self, speech: torch.Tensor): 73 | speech_feat = mel_spectrogram( 74 | y=speech, 75 | n_fft=1024, 76 | num_mels=80, 77 | sampling_rate=22050, 78 | hop_size=256, 79 | win_size=1024, 80 | fmin=0, 81 | fmax=8000, 82 | center=False 83 | ).squeeze(dim=0).transpose(0, 1).to(self.device) 84 | speech_feat = speech_feat.unsqueeze(dim=0) 85 | speech_feat_len = torch.tensor([speech_feat.shape[1]], dtype=torch.int32).to(self.device) 86 | return speech_feat, speech_feat_len 87 | 88 | def preprocess_text(self, text, 
split=True) -> Union[str, List[str]]: 89 | text = normalize_text(text) 90 | if split: 91 | text = list(split_text( 92 | text=text, 93 | tokenize=partial(self.tokenizer.encode, allowed_special='all'), 94 | token_max_n=30, 95 | token_min_n=10, 96 | merge_len=5, 97 | comma_split=False 98 | )) 99 | return text 100 | 101 | def frontend_tts( 102 | self, 103 | text: str, 104 | prompt_speech_16k: Union[np.ndarray, torch.Tensor] 105 | ) -> dict: 106 | if isinstance(prompt_speech_16k, np.ndarray): 107 | prompt_speech_16k = torch.from_numpy(prompt_speech_16k) 108 | 109 | text_token, text_token_len = self._extract_text_token(text) 110 | speech_token, speech_token_len = self._extract_speech_token(prompt_speech_16k) 111 | prompt_speech_22050 = torchaudio.transforms.Resample(orig_freq=16000, new_freq=22050)(prompt_speech_16k) 112 | speech_feat, speech_feat_len = self._extract_speech_feat(prompt_speech_22050) 113 | embedding = self._extract_spk_embedding(prompt_speech_16k) 114 | 115 | model_input = { 116 | 'text': text_token, 117 | 'text_len': text_token_len, 118 | 'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len, 119 | 'prompt_speech_feat': speech_feat, 120 | 'prompt_speech_feat_len': speech_feat_len, 121 | 'llm_embedding': embedding, 122 | 'flow_embedding': embedding 123 | } 124 | return model_input 125 | 126 | 127 | def frontend_vc( 128 | self, 129 | source_speech_16k: Union[np.ndarray, torch.Tensor], 130 | prompt_speech_16k: Union[np.ndarray, torch.Tensor] 131 | ) -> dict: 132 | if isinstance(source_speech_16k, np.ndarray): 133 | source_speech_16k = torch.from_numpy(source_speech_16k) 134 | if isinstance(prompt_speech_16k, np.ndarray): 135 | prompt_speech_16k = torch.from_numpy(prompt_speech_16k) 136 | 137 | prompt_speech_token, prompt_speech_token_len = self._extract_speech_token(prompt_speech_16k) 138 | prompt_speech_22050 = torchaudio.transforms.Resample(orig_freq=16000, new_freq=22050)(prompt_speech_16k) 139 | prompt_speech_feat, prompt_speech_feat_len = self._extract_speech_feat(prompt_speech_22050) 140 | embedding = self._extract_spk_embedding(prompt_speech_16k) 141 | source_speech_token, source_speech_token_len = self._extract_speech_token(source_speech_16k) 142 | model_input = { 143 | 'source_speech_token': source_speech_token, 144 | 'source_speech_token_len': source_speech_token_len, 145 | 'flow_prompt_speech_token': prompt_speech_token, 146 | 'flow_prompt_speech_token_len': prompt_speech_token_len, 147 | 'prompt_speech_feat': prompt_speech_feat, 148 | 'prompt_speech_feat_len': prompt_speech_feat_len, 149 | 'flow_embedding': embedding 150 | } 151 | return model_input 152 | -------------------------------------------------------------------------------- /viettts/hifigan/f0_predictor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.utils import weight_norm 4 | 5 | 6 | class ConvRNNF0Predictor(nn.Module): 7 | def __init__(self, 8 | num_class: int = 1, 9 | in_channels: int = 80, 10 | cond_channels: int = 512 11 | ): 12 | super().__init__() 13 | 14 | self.num_class = num_class 15 | self.condnet = nn.Sequential( 16 | weight_norm( 17 | nn.Conv1d(in_channels, cond_channels, kernel_size=3, padding=1) 18 | ), 19 | nn.ELU(), 20 | weight_norm( 21 | nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1) 22 | ), 23 | nn.ELU(), 24 | weight_norm( 25 | nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1) 26 | ), 27 | nn.ELU(), 28 | weight_norm( 29 | 
nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1) 30 | ), 31 | nn.ELU(), 32 | weight_norm( 33 | nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1) 34 | ), 35 | nn.ELU(), 36 | ) 37 | self.classifier = nn.Linear(in_features=cond_channels, out_features=self.num_class) 38 | 39 | def forward(self, x: torch.Tensor) -> torch.Tensor: 40 | x = self.condnet(x) 41 | x = x.transpose(1, 2) 42 | return torch.abs(self.classifier(x).squeeze(-1)) 43 | -------------------------------------------------------------------------------- /viettts/llm/llm.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Optional, Callable, List, Generator 2 | import torch 3 | from torch import nn 4 | import torch.nn.functional as F 5 | from torch.nn.utils.rnn import pad_sequence, unpad_sequence 6 | from viettts.utils.common import IGNORE_ID 7 | from viettts.transformer.label_smoothing_loss import LabelSmoothingLoss 8 | from viettts.utils.common import th_accuracy 9 | 10 | 11 | class TransformerLM(torch.nn.Module): 12 | def __init__( 13 | self, 14 | text_encoder_input_size: int, 15 | llm_input_size: int, 16 | llm_output_size: int, 17 | text_token_size: int, 18 | speech_token_size: int, 19 | text_encoder: torch.nn.Module, 20 | llm: torch.nn.Module, 21 | sampling: Callable, 22 | length_normalized_loss: bool = True, 23 | lsm_weight: float = 0.0, 24 | spk_embed_dim: int = 192, 25 | ): 26 | super().__init__() 27 | self.llm_input_size = llm_input_size 28 | self.speech_token_size = speech_token_size 29 | # 1. build text token inputs related modules 30 | self.text_embedding = torch.nn.Embedding(text_token_size, text_encoder_input_size) 31 | self.text_encoder = text_encoder 32 | self.text_encoder_affine_layer = nn.Linear( 33 | self.text_encoder.output_size(), 34 | llm_input_size 35 | ) 36 | 37 | # 2. build speech token language model related modules 38 | self.sos_eos = 0 39 | self.task_id = 1 40 | self.llm_embedding = torch.nn.Embedding(2, llm_input_size) 41 | self.llm = llm 42 | self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 1) 43 | self.criterion_ce = LabelSmoothingLoss( 44 | size=speech_token_size + 1, 45 | padding_idx=IGNORE_ID, 46 | smoothing=lsm_weight, 47 | normalize_length=length_normalized_loss, 48 | ) 49 | 50 | # 3. [Optional] build speech token related modules 51 | self.speech_embedding = torch.nn.Embedding(speech_token_size, llm_input_size) 52 | self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, llm_input_size) 53 | 54 | # 4. 
sampling method 55 | self.sampling = sampling 56 | 57 | def encode( 58 | self, 59 | text: torch.Tensor, 60 | text_lengths: torch.Tensor, 61 | ): 62 | encoder_out, encoder_mask = self.text_encoder(text, text_lengths, decoding_chunk_size=1, num_decoding_left_chunks=-1) 63 | encoder_out_lens = encoder_mask.squeeze(1).sum(1) 64 | encoder_out = self.text_encoder_affine_layer(encoder_out) 65 | return encoder_out, encoder_out_lens 66 | 67 | def pad_unpad_sequence(self, sos_eos_emb, embedding, text_token, text_token_len, task_id_emb, speech_token, speech_token_len): 68 | text_token = unpad_sequence(text_token, text_token_len.cpu(), batch_first=True) 69 | speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True) 70 | lm_input = [torch.concat([sos_eos_emb.squeeze(dim=0), embedding[i], text_token[i], task_id_emb.squeeze(dim=0), speech_token[i]], dim=0) 71 | for i in range(len(text_token))] 72 | lm_input_len = torch.tensor([i.size(0) for i in lm_input], dtype=torch.int32) 73 | lm_input = pad_sequence(lm_input, batch_first=True, padding_value=IGNORE_ID) 74 | return lm_input, lm_input_len 75 | 76 | def forward( 77 | self, 78 | batch: dict, 79 | device: torch.device, 80 | ) -> Dict[str, Optional[torch.Tensor]]: 81 | """ 82 | Args: 83 | text: (B, L, D) 84 | text_lengths: (B,) 85 | audio: (B, T, N) or (B, T) 86 | audio_lengths: (B,) 87 | """ 88 | text_token = batch['text_token'].to(device) 89 | text_token_len = batch['text_token_len'].to(device) 90 | speech_token = batch['speech_token'].to(device) 91 | speech_token_len = batch['speech_token_len'].to(device) 92 | embedding = batch['embedding'].to(device) 93 | 94 | # 1. prepare llm_target 95 | lm_target = [torch.tensor([IGNORE_ID] * (2 + text_token_len[i]) + speech_token[i, :speech_token_len[i]].tolist() + 96 | [self.speech_token_size]) for i in range(text_token.size(0))] 97 | lm_target = pad_sequence(lm_target, batch_first=True, padding_value=IGNORE_ID).to(device) 98 | 99 | # 1. encode text_token 100 | text_token = self.text_embedding(text_token) 101 | text_token, text_token_len = self.encode(text_token, text_token_len) 102 | 103 | # 2. embedding projection 104 | embedding = F.normalize(embedding, dim=1) 105 | embedding = self.spk_embed_affine_layer(embedding) 106 | embedding = embedding.unsqueeze(1) 107 | 108 | # 3. eos and task_id 109 | sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1) 110 | task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1) 111 | 112 | # 4. encode speech_token 113 | speech_token = self.speech_embedding(speech_token) 114 | 115 | # 5. unpad and pad 116 | lm_input, lm_input_len = self.pad_unpad_sequence(sos_eos_emb, embedding, text_token, text_token_len, 117 | task_id_emb, speech_token, speech_token_len) 118 | 119 | # 6. 
run lm forward 120 | lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device)) 121 | logits = self.llm_decoder(lm_output) 122 | loss = self.criterion_ce(logits, lm_target) 123 | acc = th_accuracy(logits.view(-1, self.speech_token_size + 1), lm_target, ignore_label=IGNORE_ID) 124 | return {'loss': loss, 'acc': acc} 125 | 126 | def sampling_ids( 127 | self, 128 | weighted_scores: torch.Tensor, 129 | decoded_tokens: List, 130 | sampling: int, 131 | ignore_eos: bool = True, 132 | ): 133 | while True: 134 | top_ids = self.sampling(weighted_scores, decoded_tokens, sampling) 135 | if (not ignore_eos) or (self.speech_token_size not in top_ids): 136 | break 137 | return top_ids 138 | 139 | @torch.inference_mode() 140 | def inference( 141 | self, 142 | text: torch.Tensor, 143 | text_len: torch.Tensor, 144 | prompt_text: torch.Tensor, 145 | prompt_text_len: torch.Tensor, 146 | prompt_speech_token: torch.Tensor, 147 | prompt_speech_token_len: torch.Tensor, 148 | embedding: torch.Tensor, 149 | sampling: int = 25, 150 | max_token_text_ratio: float = 20, 151 | min_token_text_ratio: float = 2, 152 | ) -> Generator[torch.Tensor, None, None]: 153 | device = text.device 154 | text = torch.concat([prompt_text, text], dim=1) 155 | text_len += prompt_text_len 156 | text = self.text_embedding(text) 157 | 158 | # 1. encode text 159 | text, text_len = self.encode(text, text_len) 160 | 161 | # 2. encode embedding 162 | if embedding.shape[0] != 0: 163 | embedding = F.normalize(embedding, dim=1) 164 | embedding = self.spk_embed_affine_layer(embedding) 165 | embedding = embedding.unsqueeze(dim=1) 166 | else: 167 | embedding = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device) 168 | 169 | # 3. concat llm_input 170 | sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1) 171 | task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1) 172 | if prompt_speech_token_len != 0: 173 | prompt_speech_token_emb = self.speech_embedding(prompt_speech_token) 174 | else: 175 | prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device) 176 | lm_input = torch.concat([sos_eos_emb, embedding, text, task_id_emb, prompt_speech_token_emb], dim=1) 177 | 178 | # 4. cal min/max_length 179 | min_len = int((text_len - prompt_text_len) * min_token_text_ratio) 180 | max_len = int((text_len - prompt_text_len) * max_token_text_ratio) 181 | 182 | # 5. 
step by step decode 183 | out_tokens = [] 184 | offset = 0 185 | att_cache, cnn_cache = torch.zeros((0, 0, 0, 0), device=lm_input.device), torch.zeros((0, 0, 0, 0), device=lm_input.device) 186 | for i in range(max_len): 187 | y_pred, att_cache, cnn_cache = self.llm.forward_chunk(lm_input, offset=offset, required_cache_size=-1, 188 | att_cache=att_cache, cnn_cache=cnn_cache, 189 | att_mask=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), 190 | device=lm_input.device)).to(torch.bool)) 191 | logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1) 192 | top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item() 193 | if top_ids == self.speech_token_size: 194 | break 195 | # in stream mode, yield token one by one 196 | yield top_ids 197 | out_tokens.append(top_ids) 198 | offset += lm_input.size(1) 199 | lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1) 200 | -------------------------------------------------------------------------------- /viettts/model.py: -------------------------------------------------------------------------------- 1 | from loguru import logger 2 | import torch 3 | import numpy as np 4 | import threading 5 | import time 6 | from torch.nn import functional as F 7 | from contextlib import nullcontext 8 | import uuid 9 | from viettts.utils.common import fade_in_out_audio 10 | 11 | class TTSModel: 12 | def __init__( 13 | self, 14 | llm: torch.nn.Module, 15 | flow: torch.nn.Module, 16 | hift: torch.nn.Module 17 | ): 18 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 19 | self.llm = llm 20 | self.flow = flow 21 | self.hift = hift 22 | self.token_min_hop_len = 2 * self.flow.input_frame_rate 23 | self.token_max_hop_len = 4 * self.flow.input_frame_rate 24 | self.token_overlap_len = 20 25 | # mel fade in out 26 | self.mel_overlap_len = int(self.token_overlap_len / self.flow.input_frame_rate * 22050 / 256) 27 | self.mel_window = np.hamming(2 * self.mel_overlap_len) 28 | # hift cache 29 | self.mel_cache_len = 20 30 | self.source_cache_len = int(self.mel_cache_len * 256) 31 | # speech fade in out 32 | self.speech_window = np.hamming(2 * self.source_cache_len) 33 | # rtf and decoding related 34 | self.stream_scale_factor = 1 35 | assert self.stream_scale_factor >= 1, 'stream_scale_factor should be greater than 1, change it according to your actual rtf' 36 | self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext() 37 | self.lock = threading.Lock() 38 | # dict used to store session related variable 39 | self.tts_speech_token_dict = {} 40 | self.llm_end_dict = {} 41 | self.mel_overlap_dict = {} 42 | self.hift_cache_dict = {} 43 | 44 | def load(self, llm_model, flow_model, hift_model): 45 | self.llm.load_state_dict(torch.load(llm_model, map_location=self.device)) 46 | self.llm.to(self.device).eval() 47 | self.llm.half() 48 | self.flow.load_state_dict(torch.load(flow_model, map_location=self.device)) 49 | self.flow.to(self.device).eval() 50 | self.hift.load_state_dict(torch.load(hift_model, map_location=self.device)) 51 | self.hift.to(self.device).eval() 52 | 53 | def load_jit(self, llm_text_encoder_model, llm_llm_model, flow_encoder_model): 54 | llm_text_encoder = torch.jit.load(llm_text_encoder_model, map_location=self.device) 55 | self.llm.text_encoder = llm_text_encoder 56 | llm_llm = torch.jit.load(llm_llm_model, map_location=self.device) 57 | self.llm.llm = llm_llm 58 | flow_encoder = 
torch.jit.load(flow_encoder_model, map_location=self.device) 59 | self.flow.encoder = flow_encoder 60 | 61 | def load_onnx(self, flow_decoder_estimator_model): 62 | import onnxruntime 63 | option = onnxruntime.SessionOptions() 64 | option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL 65 | option.intra_op_num_threads = 1 66 | providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider'] 67 | del self.flow.decoder.estimator 68 | self.flow.decoder.estimator = onnxruntime.InferenceSession(flow_decoder_estimator_model, sess_options=option, providers=providers) 69 | 70 | def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid): 71 | with self.llm_context: 72 | for i in self.llm.inference( 73 | text=text.to(self.device), 74 | text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device), 75 | prompt_text=prompt_text.to(self.device), 76 | prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device), 77 | prompt_speech_token=llm_prompt_speech_token.to(self.device), 78 | prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device), 79 | embedding=llm_embedding.to(self.device).half() 80 | ): 81 | self.tts_speech_token_dict[uuid].append(i) 82 | self.llm_end_dict[uuid] = True 83 | 84 | def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False, speed=1.0): 85 | tts_mel = self.flow.inference( 86 | token=token.to(self.device), 87 | token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device), 88 | prompt_token=prompt_token.to(self.device), 89 | prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device), 90 | prompt_feat=prompt_feat.to(self.device), 91 | prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device), 92 | embedding=embedding.to(self.device) 93 | ) 94 | 95 | if self.hift_cache_dict[uuid] is not None: 96 | hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source'] 97 | tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2) 98 | else: 99 | hift_cache_source = torch.zeros(1, 1, 0) 100 | 101 | if finalize is False: 102 | self.mel_overlap_dict[uuid] = tts_mel[:, :, -self.mel_overlap_len:] 103 | tts_mel = tts_mel[:, :, :-self.mel_overlap_len] 104 | tts_speech, tts_source = self.hift.inference(mel=tts_mel, cache_source=hift_cache_source) 105 | self.hift_cache_dict[uuid] = { 106 | 'mel': tts_mel[:, :, -self.mel_cache_len:], 107 | 'source': tts_source[:, :, -self.source_cache_len:], 108 | 'speech': tts_speech[:, -self.source_cache_len:] 109 | } 110 | tts_speech = tts_speech[:, :-self.source_cache_len] 111 | else: 112 | if speed != 1.0: 113 | assert self.hift_cache_dict[uuid] is None, 'speed change only support non-stream inference mode' 114 | tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear') 115 | tts_speech, tts_source = self.hift.inference(mel=tts_mel, cache_source=hift_cache_source) 116 | 117 | tts_speech = fade_in_out_audio(tts_speech) 118 | return tts_speech 119 | 120 | def tts( 121 | self, 122 | text: str, 123 | flow_embedding: torch.Tensor, 124 | llm_embedding: torch.Tensor=torch.zeros(0, 192), 125 | prompt_text: torch.Tensor=torch.zeros(1, 0, dtype=torch.int32), 126 | llm_prompt_speech_token: torch.Tensor=torch.zeros(1, 0, dtype=torch.int32), 127 | flow_prompt_speech_token: torch.Tensor=torch.zeros(1, 0, dtype=torch.int32), 128 | 
prompt_speech_feat: torch.Tensor=torch.zeros(1, 0, 80), 129 | stream: bool=False, 130 | speed: float=1.0, 131 | **kwargs 132 | ): 133 | # this_uuid is used to track variables related to this inference thread 134 | this_uuid = str(uuid.uuid1()) 135 | with self.lock: 136 | self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False 137 | self.mel_overlap_dict[this_uuid], self.hift_cache_dict[this_uuid] = None, None 138 | 139 | p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid)) 140 | p.start() 141 | 142 | if stream: 143 | token_hop_len = self.token_min_hop_len 144 | while True: 145 | time.sleep(0.01) 146 | if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len: 147 | this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len]).unsqueeze(dim=0) 148 | this_tts_speech = self.token2wav( 149 | token=this_tts_speech_token, 150 | prompt_token=flow_prompt_speech_token, 151 | prompt_feat=prompt_speech_feat, 152 | embedding=flow_embedding, 153 | uuid=this_uuid, 154 | finalize=False 155 | ) 156 | yield {'tts_speech': this_tts_speech.cpu()} 157 | with self.lock: 158 | self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:] 159 | # increase token_hop_len for better speech quality 160 | token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor)) 161 | if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len: 162 | break 163 | p.join() 164 | this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0) 165 | this_tts_speech = self.token2wav( 166 | token=this_tts_speech_token, 167 | prompt_token=flow_prompt_speech_token, 168 | prompt_feat=prompt_speech_feat, 169 | embedding=flow_embedding, 170 | uuid=this_uuid, 171 | finalize=True 172 | ) 173 | yield {'tts_speech': this_tts_speech.cpu()} 174 | else: 175 | p.join() 176 | this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0) 177 | this_tts_speech = self.token2wav( 178 | token=this_tts_speech_token, 179 | prompt_token=flow_prompt_speech_token, 180 | prompt_feat=prompt_speech_feat, 181 | embedding=flow_embedding, 182 | uuid=this_uuid, 183 | finalize=True, 184 | speed=speed 185 | ) 186 | yield {'tts_speech': this_tts_speech.cpu()} 187 | 188 | with self.lock: 189 | self.tts_speech_token_dict.pop(this_uuid) 190 | self.llm_end_dict.pop(this_uuid) 191 | self.mel_overlap_dict.pop(this_uuid) 192 | self.hift_cache_dict.pop(this_uuid) 193 | 194 | def vc( 195 | self, 196 | source_speech_token: torch.Tensor, 197 | flow_prompt_speech_token: torch.Tensor, 198 | prompt_speech_feat: torch.Tensor, 199 | flow_embedding: torch.Tensor, 200 | stream: bool=False, 201 | speed: float=1.0, 202 | **kwargs 203 | ): 204 | this_uuid = str(uuid.uuid1()) 205 | with self.lock: 206 | self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = source_speech_token.flatten().tolist(), True 207 | self.mel_overlap_dict[this_uuid], self.hift_cache_dict[this_uuid] = None, None 208 | 209 | if stream: 210 | token_hop_len = self.token_min_hop_len 211 | while True: 212 | if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len: 213 | this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len]) \ 214 | .unsqueeze(dim=0) 215 | 
this_tts_speech = self.token2wav( 216 | token=this_tts_speech_token, 217 | prompt_token=flow_prompt_speech_token, 218 | prompt_feat=prompt_speech_feat, 219 | embedding=flow_embedding, 220 | uuid=this_uuid, 221 | finalize=False 222 | ) 223 | yield {'tts_speech': this_tts_speech.cpu()} 224 | with self.lock: 225 | self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:] 226 | # increase token_hop_len for better speech quality 227 | token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor)) 228 | if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len: 229 | break 230 | 231 | # deal with remaining tokens, make sure the remaining token length at inference equals token_hop_len when cache_speech is not None 232 | this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0) 233 | this_tts_speech = self.token2wav( 234 | token=this_tts_speech_token, 235 | prompt_token=flow_prompt_speech_token, 236 | prompt_feat=prompt_speech_feat, 237 | embedding=flow_embedding, 238 | uuid=this_uuid, 239 | finalize=True 240 | ) 241 | yield {'tts_speech': this_tts_speech.cpu()} 242 | else: 243 | # deal with all tokens 244 | this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0) 245 | this_tts_speech = self.token2wav( 246 | token=this_tts_speech_token, 247 | prompt_token=flow_prompt_speech_token, 248 | prompt_feat=prompt_speech_feat, 249 | embedding=flow_embedding, 250 | uuid=this_uuid, 251 | finalize=True, 252 | speed=speed 253 | ) 254 | yield {'tts_speech': this_tts_speech.cpu()} 255 | 256 | with self.lock: 257 | self.tts_speech_token_dict.pop(this_uuid) 258 | self.llm_end_dict.pop(this_uuid) 259 | self.mel_overlap_dict.pop(this_uuid) 260 | self.hift_cache_dict.pop(this_uuid) 261 | -------------------------------------------------------------------------------- /viettts/server.py: -------------------------------------------------------------------------------- 1 | import io 2 | import os 3 | import queue 4 | import random 5 | import subprocess 6 | import threading 7 | import wave 8 | 9 | import tempfile 10 | import shutil 11 | import requests 12 | import numpy as np 13 | from loguru import logger 14 | from datetime import datetime 15 | from typing import Any, List, Optional 16 | from pydantic import BaseModel 17 | from anyio import CapacityLimiter 18 | from anyio.lowlevel import RunVar 19 | from fastapi import FastAPI, UploadFile, Form, File, HTTPException 20 | from fastapi.responses import StreamingResponse, JSONResponse, PlainTextResponse, FileResponse 21 | from fastapi.middleware.cors import CORSMiddleware 22 | 23 | from viettts.tts import TTS 24 | from viettts.utils.file_utils import load_prompt_speech_from_file, load_voices 25 | 26 | 27 | VOICE_DIR = 'samples' 28 | VOICE_MAP = load_voices(VOICE_DIR) 29 | 30 | global tts_obj 31 | tts_obj = None 32 | 33 | 34 | app = FastAPI( 35 | title="VietTTS API", 36 | description=""" 37 | VietTTS API (https://github.com/dangvansam/viet-tts) 38 | Vietnamese Text To Speech and Voice Clone 39 | License: Apache 2.0 - Author: 40 | """ 41 | ) 42 | app.add_middleware( 43 | CORSMiddleware, 44 | allow_origins=["*"], 45 | allow_credentials=True, 46 | allow_methods=["*"], 47 | allow_headers=["*"]) 48 | 49 | 50 | def generate_data(model_output): 51 | audio = wav_chunk_header() 52 | for i in model_output: 53 | tts_audio = (i['tts_speech'].numpy() * (2 ** 15)).astype(np.int16) 54 | tts_audio 
= tts_audio.tobytes() 55 | audio += tts_audio 56 | yield audio 57 | 58 | 59 | class OpenAITTSRequest(BaseModel): 60 | input: str 61 | model: str = "tts-1" 62 | voice: str = random.choice(list(VOICE_MAP)) 63 | response_format: str = "wav" 64 | speed: float = 1.0 65 | 66 | class TTSRequest(BaseModel): 67 | text: str 68 | voice: str = random.choice(list(VOICE_MAP)) 69 | speed: float = 1.0 70 | 71 | def wav_chunk_header(sample_rate=22050, bit_depth=16, channels=1): 72 | buffer = io.BytesIO() 73 | with wave.open(buffer, "wb") as wav_file: 74 | wav_file.setnchannels(channels) 75 | wav_file.setsampwidth(bit_depth // 8) 76 | wav_file.setframerate(sample_rate) 77 | 78 | wav_header_bytes = buffer.getvalue() 79 | buffer.close() 80 | return wav_header_bytes 81 | 82 | 83 | @app.get("/", response_class=PlainTextResponse) 84 | async def root(): 85 | return 'VietTTS API' 86 | 87 | @app.get("/health", response_class=PlainTextResponse) 88 | async def health(): 89 | return 'VietTTS API is running...' 90 | 91 | @app.get("/voices") 92 | @app.get("/v1/voices") 93 | async def show_voices(): 94 | return list(VOICE_MAP.keys()) 95 | 96 | @app.post("/audio/speech") 97 | @app.post("/v1/audio/speech") 98 | async def openai_api_tts(tts_request: OpenAITTSRequest): 99 | logger.info(f"Received TTS request: {tts_request.dict()}") 100 | 101 | if tts_request.voice.isdigit(): 102 | voice_file = list(VOICE_MAP.values())[int(tts_request.voice)] 103 | else: 104 | voice_file = VOICE_MAP.get(tts_request.voice) 105 | 106 | if not voice_file: 107 | logger.error(f"Voice {tts_request.voice} not found") 108 | return PlainTextResponse(content="Voice not found", status_code=404) 109 | 110 | prompt_speech_16k = load_prompt_speech_from_file( 111 | filepath=voice_file, 112 | min_duration=3, 113 | max_duration=5 114 | ) 115 | # prompt_speech_16k = fade_in_out_audio(prompt_speech_16k) 116 | 117 | def build_ffmpeg_args(response_format, input_format, sample_rate=24000): 118 | if input_format == 'WAV': 119 | ffmpeg_args = ["ffmpeg", "-loglevel", "error", "-f", "WAV", "-i", "-"] 120 | else: 121 | ffmpeg_args = ["ffmpeg", "-loglevel", "error", "-f", input_format, "-ar", sample_rate, "-ac", "1", "-i", "-"] 122 | if response_format == "mp3": 123 | ffmpeg_args.extend(["-f", "mp3", "-c:a", "libmp3lame", "-ab", "64k"]) 124 | elif response_format == "opus": 125 | ffmpeg_args.extend(["-f", "ogg", "-c:a", "libopus"]) 126 | elif response_format == "aac": 127 | ffmpeg_args.extend(["-f", "adts", "-c:a", "aac", "-ab", "64k"]) 128 | elif response_format == "flac": 129 | ffmpeg_args.extend(["-f", "flac", "-c:a", "flac"]) 130 | elif response_format == "wav": 131 | ffmpeg_args.extend(["-f", "wav", "-c:a", "pcm_s16le"]) 132 | elif response_format == "pcm": # even though pcm is technically 'raw', we still use ffmpeg to adjust the speed 133 | ffmpeg_args.extend(["-f", "s16le", "-c:a", "pcm_s16le"]) 134 | return ffmpeg_args 135 | 136 | def exception_check(exq: queue.Queue): 137 | try: 138 | e = exq.get_nowait() 139 | except queue.Empty: 140 | return 141 | raise e 142 | 143 | if tts_request.response_format == "mp3": 144 | media_type = "audio/mpeg" 145 | elif tts_request.response_format == "opus": 146 | media_type = "audio/ogg;codec=opus" # codecs? 
147 | elif tts_request.response_format == "aac": 148 | media_type = "audio/aac" 149 | elif tts_request.response_format == "flac": 150 | media_type = "audio/x-flac" 151 | elif tts_request.response_format == "wav": 152 | media_type = "audio/wav" 153 | elif tts_request.response_format == "pcm": 154 | media_type = "audio/pcm;rate=24000" 155 | else: 156 | raise ValueError(f"Invalid response_format: '{tts_request.response_format}'", param='response_format') 157 | 158 | ffmpeg_args = None 159 | ffmpeg_args = build_ffmpeg_args(tts_request.response_format, input_format="f32le", sample_rate="24000") 160 | ffmpeg_args.extend(["-"]) 161 | ffmpeg_proc = subprocess.Popen(ffmpeg_args, stdin=subprocess.PIPE, stdout=subprocess.PIPE) 162 | 163 | in_q = queue.Queue() 164 | ex_q = queue.Queue() 165 | 166 | def generator(): 167 | # text -> in_q 168 | try: 169 | model_output = tts_obj.inference_tts( 170 | tts_text=tts_request.input, 171 | prompt_speech_16k=prompt_speech_16k, 172 | speed=tts_request.speed, 173 | stream=False 174 | ) 175 | for chunk in model_output: 176 | exception_check(ex_q) 177 | chunk = chunk['tts_speech'].numpy().tobytes() 178 | in_q.put(chunk) 179 | except BrokenPipeError as e: 180 | logger.info("Client disconnected - 'Broken pipe'") 181 | except Exception as e: 182 | logger.error(f"Exception: {repr(e)}") 183 | raise e 184 | finally: 185 | in_q.put(None) # sentinel 186 | 187 | def out_writer(): 188 | try: 189 | while True: 190 | chunk = in_q.get() 191 | if chunk is None: # sentinel 192 | break 193 | ffmpeg_proc.stdin.write(chunk) # BrokenPipeError from here on client disconnect 194 | except Exception as e: # BrokenPipeError 195 | ex_q.put(e) # we need to get this exception into the generation loop 196 | ffmpeg_proc.kill() 197 | return 198 | finally: 199 | ffmpeg_proc.stdin.close() 200 | 201 | generator_worker = threading.Thread(target=generator, daemon=True) 202 | generator_worker.start() 203 | out_writer_worker = threading.Thread(target=out_writer, daemon=True) 204 | out_writer_worker.start() 205 | 206 | async def cleanup(): 207 | try: 208 | ffmpeg_proc.kill() 209 | # del generator_worker 210 | # del out_writer_worker 211 | except Exception as e: 212 | logger.error(f"Exception: {repr(e)}") 213 | 214 | return StreamingResponse( 215 | content=ffmpeg_proc.stdout, 216 | media_type=media_type, 217 | background=cleanup 218 | ) 219 | 220 | @app.post("/tts") 221 | @app.post("/v1/tts") 222 | async def tts( 223 | text: str = Form(...), 224 | voice: str = Form("0"), 225 | speed: float = Form(1.0), 226 | audio_url: str = Form(None), 227 | audio_file: UploadFile = File(None) 228 | ): 229 | logger.info(f"Received TTS request: text={text}, voice={voice}, speed={speed}, audio_url={audio_url}") 230 | voice_file = None 231 | 232 | # Case 1: Uploaded audio file 233 | if audio_file: 234 | temp_audio_file = tempfile.NamedTemporaryFile( 235 | delete=False, 236 | suffix=f'.{audio_file.filename.split(".")[-1]}' 237 | ) 238 | try: 239 | with open(temp_audio_file.name, "wb") as temp_file: 240 | shutil.copyfileobj(audio_file.file, temp_file) 241 | voice_file = temp_audio_file.name 242 | logger.info(f"Using uploaded audio file as voice: {voice_file}") 243 | finally: 244 | audio_file.file.close() 245 | 246 | # Case 2: Audio URL 247 | elif audio_url: 248 | temp_audio_file = tempfile.NamedTemporaryFile( 249 | delete=False, 250 | suffix=f'.{audio_url.lower().split(".")[-1]}' 251 | ) 252 | try: 253 | response = requests.get(audio_url, stream=True) 254 | if response.status_code != 200: 255 | raise 
HTTPException(status_code=400, detail="Failed to fetch audio from URL") 256 | with open(temp_audio_file.name, "wb") as temp_file: 257 | shutil.copyfileobj(response.raw, temp_file) 258 | voice_file = temp_audio_file.name 259 | logger.info(f"Using audio URL as voice: {voice_file}") 260 | finally: 261 | response.close() 262 | 263 | # Case 3: Predefined voice 264 | elif voice: 265 | if voice.isdigit(): 266 | voice_file = list(VOICE_MAP.values())[int(voice)] 267 | else: 268 | voice_file = VOICE_MAP.get(voice) 269 | 270 | if not voice_file: 271 | logger.error(f"Voice {voice} not found") 272 | raise HTTPException(status_code=404, detail="Voice not found") 273 | 274 | else: 275 | voice_file = random.choice(list(VOICE_MAP.values())) 276 | 277 | # Error if no voice file is available 278 | if not voice_file or not os.path.exists(voice_file): 279 | raise HTTPException(status_code=400, detail="No valid voice file provided") 280 | 281 | prompt_speech_16k = load_prompt_speech_from_file( 282 | filepath=voice_file, 283 | min_duration=3, 284 | max_duration=5 285 | ) 286 | 287 | temp_output_file = tempfile.NamedTemporaryFile( 288 | delete=False, 289 | suffix=f"_{datetime.now().strftime('%Y%m%d_%H%M%S')}.mp3" 290 | ) 291 | 292 | try: 293 | model_output = tts_obj.inference_tts( 294 | tts_text=text, 295 | prompt_speech_16k=prompt_speech_16k, 296 | speed=speed, 297 | stream=False 298 | ) 299 | 300 | raw_audio = b''.join(chunk['tts_speech'].numpy().tobytes() for chunk in model_output) 301 | ffmpeg_args = [ 302 | "ffmpeg", "-loglevel", "error", "-y", "-f", "f32le", "-ar", "24000", "-ac", "1", 303 | "-i", "-", "-f", "mp3", "-c:a", "libmp3lame", "-ab", "64k", temp_output_file.name 304 | ] 305 | ffmpeg_proc = subprocess.run( 306 | ffmpeg_args, 307 | input=raw_audio, 308 | stdout=subprocess.PIPE, 309 | stderr=subprocess.PIPE 310 | ) 311 | 312 | if ffmpeg_proc.returncode != 0: 313 | logger.error(f"FFmpeg error: {ffmpeg_proc.stderr.decode()}") 314 | raise HTTPException(status_code=500, detail="Error during audio processing") 315 | 316 | if not os.path.exists(temp_output_file.name): 317 | logger.error(f"FFmpeg did not create the output file: {temp_output_file.name}") 318 | raise HTTPException(status_code=500, detail="FFmpeg failed to produce the output file") 319 | 320 | return FileResponse( 321 | path=temp_output_file.name, 322 | media_type="audio/mpeg", 323 | filename=temp_output_file.name.split("/")[-1] 324 | ) 325 | 326 | finally: 327 | if audio_file or audio_url: 328 | if os.path.exists(temp_audio_file.name): 329 | os.unlink(temp_audio_file.name) 330 | 331 | 332 | @app.on_event("startup") 333 | async def startup(): 334 | global tts_obj 335 | RunVar("_default_thread_limiter").set(CapacityLimiter(os.cpu_count())) 336 | tts_obj = TTS('./pretrained-models') -------------------------------------------------------------------------------- /viettts/tokenizer/tokenizer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import base64 3 | import tiktoken 4 | from functools import lru_cache 5 | from whisper.tokenizer import Tokenizer 6 | 7 | LANGUAGES = { 8 | "en": "english", 9 | "zh": "chinese", 10 | "de": "german", 11 | "es": "spanish", 12 | "ru": "russian", 13 | "ko": "korean", 14 | "fr": "french", 15 | "ja": "japanese", 16 | "pt": "portuguese", 17 | "tr": "turkish", 18 | "pl": "polish", 19 | "ca": "catalan", 20 | "nl": "dutch", 21 | "ar": "arabic", 22 | "sv": "swedish", 23 | "it": "italian", 24 | "id": "indonesian", 25 | "hi": "hindi", 26 | "fi": "finnish", 27 | "vi": 
"vietnamese", 28 | "he": "hebrew", 29 | "uk": "ukrainian", 30 | "el": "greek", 31 | "ms": "malay", 32 | "cs": "czech", 33 | "ro": "romanian", 34 | "da": "danish", 35 | "hu": "hungarian", 36 | "ta": "tamil", 37 | "no": "norwegian", 38 | "th": "thai", 39 | "ur": "urdu", 40 | "hr": "croatian", 41 | "bg": "bulgarian", 42 | "lt": "lithuanian", 43 | "la": "latin", 44 | "mi": "maori", 45 | "ml": "malayalam", 46 | "cy": "welsh", 47 | "sk": "slovak", 48 | "te": "telugu", 49 | "fa": "persian", 50 | "lv": "latvian", 51 | "bn": "bengali", 52 | "sr": "serbian", 53 | "az": "azerbaijani", 54 | "sl": "slovenian", 55 | "kn": "kannada", 56 | "et": "estonian", 57 | "mk": "macedonian", 58 | "br": "breton", 59 | "eu": "basque", 60 | "is": "icelandic", 61 | "hy": "armenian", 62 | "ne": "nepali", 63 | "mn": "mongolian", 64 | "bs": "bosnian", 65 | "kk": "kazakh", 66 | "sq": "albanian", 67 | "sw": "swahili", 68 | "gl": "galician", 69 | "mr": "marathi", 70 | "pa": "punjabi", 71 | "si": "sinhala", 72 | "km": "khmer", 73 | "sn": "shona", 74 | "yo": "yoruba", 75 | "so": "somali", 76 | "af": "afrikaans", 77 | "oc": "occitan", 78 | "ka": "georgian", 79 | "be": "belarusian", 80 | "tg": "tajik", 81 | "sd": "sindhi", 82 | "gu": "gujarati", 83 | "am": "amharic", 84 | "yi": "yiddish", 85 | "lo": "lao", 86 | "uz": "uzbek", 87 | "fo": "faroese", 88 | "ht": "haitian creole", 89 | "ps": "pashto", 90 | "tk": "turkmen", 91 | "nn": "nynorsk", 92 | "mt": "maltese", 93 | "sa": "sanskrit", 94 | "lb": "luxembourgish", 95 | "my": "myanmar", 96 | "bo": "tibetan", 97 | "tl": "tagalog", 98 | "mg": "malagasy", 99 | "as": "assamese", 100 | "tt": "tatar", 101 | "haw": "hawaiian", 102 | "ln": "lingala", 103 | "ha": "hausa", 104 | "ba": "bashkir", 105 | "jw": "javanese", 106 | "su": "sundanese", 107 | "yue": "cantonese", 108 | "minnan": "minnan", 109 | "wuyu": "wuyu", 110 | "dialect": "dialect", 111 | "zh/en": "zh/en", 112 | "en/zh": "en/zh", 113 | } 114 | 115 | TO_LANGUAGE_CODE = { 116 | **{language: code for code, language in LANGUAGES.items()}, 117 | "burmese": "my", 118 | "valencian": "ca", 119 | "flemish": "nl", 120 | "haitian": "ht", 121 | "letzeburgesch": "lb", 122 | "pushto": "ps", 123 | "panjabi": "pa", 124 | "moldavian": "ro", 125 | "moldovan": "ro", 126 | "sinhalese": "si", 127 | "castilian": "es", 128 | "mandarin": "zh", 129 | } 130 | 131 | AUDIO_EVENT = { 132 | "ASR": "ASR", 133 | "AED": "AED", 134 | "SER": "SER", 135 | "Speech": "Speech", 136 | "/Speech": "/Speech", 137 | "BGM": "BGM", 138 | "/BGM": "/BGM", 139 | "Laughter": "Laughter", 140 | "/Laughter": "/Laughter", 141 | "Applause": "Applause", 142 | "/Applause": "/Applause", 143 | } 144 | 145 | EMOTION = { 146 | "HAPPY": "HAPPY", 147 | "SAD": "SAD", 148 | "ANGRY": "ANGRY", 149 | "NEUTRAL": "NEUTRAL", 150 | } 151 | 152 | TTS_Vocal_Token = { 153 | "TTS/B": "TTS/B", 154 | "TTS/O": "TTS/O", 155 | "TTS/Q": "TTS/Q", 156 | "TTS/A": "TTS/A", 157 | "TTS/CO": "TTS/CO", 158 | "TTS/CL": "TTS/CL", 159 | "TTS/H": "TTS/H", 160 | **{f"TTS/SP{i:02d}": f"TTS/SP{i:02d}" for i in range(1, 14)} 161 | } 162 | 163 | 164 | @lru_cache(maxsize=None) 165 | def get_encoding(name: str="gpt2", num_languages: int=99): 166 | vocab_path = os.path.join(os.path.dirname(__file__), f"{name}.tiktoken") 167 | lines = open(vocab_path).readlines() 168 | ranks = { 169 | base64.b64decode(token): int(rank) 170 | for token, rank in (line.split() for line in lines if line) 171 | } 172 | n_vocab = len(ranks) 173 | special_tokens = {} 174 | 175 | specials = [ 176 | "<|endoftext|>", 177 | "<|startoftranscript|>", 178 | 
*[f"<|{lang}|>" for lang in list(LANGUAGES.keys())[:num_languages]], 179 | *[f"<|{audio_event}|>" for audio_event in list(AUDIO_EVENT.keys())], 180 | *[f"<|{emotion}|>" for emotion in list(EMOTION.keys())], 181 | "<|translate|>", 182 | "<|transcribe|>", 183 | "<|startoflm|>", 184 | "<|startofprev|>", 185 | "<|nospeech|>", 186 | "<|notimestamps|>", 187 | *[f"<|SPECIAL_TOKEN_{i}|>" for i in range(1, 31)], # register special tokens for ASR 188 | *[f"<|{tts}|>" for tts in list(TTS_Vocal_Token.keys())], # register special tokens for TTS 189 | *[f"<|{i * 0.02:.2f}|>" for i in range(1501)], 190 | ] 191 | 192 | for token in specials: 193 | special_tokens[token] = n_vocab 194 | n_vocab += 1 195 | 196 | return tiktoken.Encoding( 197 | name=os.path.basename(vocab_path), 198 | explicit_n_vocab=n_vocab, 199 | pat_str=r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""", 200 | mergeable_ranks=ranks, 201 | special_tokens=special_tokens, 202 | ) 203 | 204 | 205 | @lru_cache(maxsize=None) 206 | def get_tokenizer() -> Tokenizer: 207 | encoding_name = "multilingual" 208 | num_languages = 100 209 | encoding = get_encoding(name=encoding_name, num_languages=num_languages) 210 | return Tokenizer( 211 | encoding=encoding, 212 | num_languages=num_languages, 213 | language='en', 214 | task='transcribe' 215 | ) 216 | 217 | if __name__ == "__main__": 218 | tokenizer = get_tokenizer() 219 | print(tokenizer.decode(tokenizer.encode("xin chào Việt Nam, tôi là nam, 1234 1 2?"))) -------------------------------------------------------------------------------- /viettts/transformer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dangvansam/viet-tts/c29b7535d4a81aefbbef175d0db26ac9df4bc028/viettts/transformer/__init__.py -------------------------------------------------------------------------------- /viettts/transformer/activation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, sin, pow 3 | from torch.nn import Parameter 4 | 5 | 6 | class Swish(torch.nn.Module): 7 | """Construct an Swish object.""" 8 | 9 | def forward(self, x: torch.Tensor) -> torch.Tensor: 10 | """Return Swish activation function.""" 11 | return x * torch.sigmoid(x) 12 | 13 | 14 | # Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license. 15 | # LICENSE is in incl_licenses directory. 16 | class Snake(nn.Module): 17 | ''' 18 | Implementation of a sine-based periodic activation function 19 | Shape: 20 | - Input: (B, C, T) 21 | - Output: (B, C, T), same shape as the input 22 | Parameters: 23 | - alpha - trainable parameter 24 | References: 25 | - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda: 26 | https://arxiv.org/abs/2006.08195 27 | Examples: 28 | >>> a1 = snake(256) 29 | >>> x = torch.randn(256) 30 | >>> x = a1(x) 31 | ''' 32 | def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False): 33 | ''' 34 | Initialization. 35 | INPUT: 36 | - in_features: shape of the input 37 | - alpha: trainable parameter 38 | alpha is initialized to 1 by default, higher values = higher-frequency. 39 | alpha will be trained along with the rest of your model. 
40 | ''' 41 | super(Snake, self).__init__() 42 | self.in_features = in_features 43 | 44 | # initialize alpha 45 | self.alpha_logscale = alpha_logscale 46 | if self.alpha_logscale: # log scale alphas initialized to zeros 47 | self.alpha = Parameter(torch.zeros(in_features) * alpha) 48 | else: # linear scale alphas initialized to ones 49 | self.alpha = Parameter(torch.ones(in_features) * alpha) 50 | 51 | self.alpha.requires_grad = alpha_trainable 52 | 53 | self.no_div_by_zero = 0.000000001 54 | 55 | def forward(self, x): 56 | ''' 57 | Forward pass of the function. 58 | Applies the function to the input elementwise. 59 | Snake ∶= x + 1/a * sin^2 (xa) 60 | ''' 61 | alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T] 62 | if self.alpha_logscale: 63 | alpha = torch.exp(alpha) 64 | x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2) 65 | 66 | return x 67 | -------------------------------------------------------------------------------- /viettts/transformer/convolution.py: -------------------------------------------------------------------------------- 1 | # Modified from ESPnet(https://github.com/espnet/espnet) 2 | """ConvolutionModule definition.""" 3 | 4 | from typing import Tuple 5 | 6 | import torch 7 | from torch import nn 8 | 9 | 10 | class ConvolutionModule(nn.Module): 11 | """ConvolutionModule in Conformer model.""" 12 | 13 | def __init__(self, 14 | channels: int, 15 | kernel_size: int = 15, 16 | activation: nn.Module = nn.ReLU(), 17 | norm: str = "batch_norm", 18 | causal: bool = False, 19 | bias: bool = True): 20 | """Construct an ConvolutionModule object. 21 | Args: 22 | channels (int): The number of channels of conv layers. 23 | kernel_size (int): Kernel size of conv layers. 24 | causal (int): Whether use causal convolution or not 25 | """ 26 | super().__init__() 27 | 28 | self.pointwise_conv1 = nn.Conv1d( 29 | channels, 30 | 2 * channels, 31 | kernel_size=1, 32 | stride=1, 33 | padding=0, 34 | bias=bias, 35 | ) 36 | # self.lorder is used to distinguish if it's a causal convolution, 37 | # if self.lorder > 0: it's a causal convolution, the input will be 38 | # padded with self.lorder frames on the left in forward. 39 | # else: it's a symmetrical convolution 40 | if causal: 41 | padding = 0 42 | self.lorder = kernel_size - 1 43 | else: 44 | # kernel_size should be an odd number for none causal convolution 45 | assert (kernel_size - 1) % 2 == 0 46 | padding = (kernel_size - 1) // 2 47 | self.lorder = 0 48 | self.depthwise_conv = nn.Conv1d( 49 | channels, 50 | channels, 51 | kernel_size, 52 | stride=1, 53 | padding=padding, 54 | groups=channels, 55 | bias=bias, 56 | ) 57 | 58 | assert norm in ['batch_norm', 'layer_norm'] 59 | if norm == "batch_norm": 60 | self.use_layer_norm = False 61 | self.norm = nn.BatchNorm1d(channels) 62 | else: 63 | self.use_layer_norm = True 64 | self.norm = nn.LayerNorm(channels) 65 | 66 | self.pointwise_conv2 = nn.Conv1d( 67 | channels, 68 | channels, 69 | kernel_size=1, 70 | stride=1, 71 | padding=0, 72 | bias=bias, 73 | ) 74 | self.activation = activation 75 | 76 | def forward( 77 | self, 78 | x: torch.Tensor, 79 | mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), 80 | cache: torch.Tensor = torch.zeros((0, 0, 0)), 81 | ) -> Tuple[torch.Tensor, torch.Tensor]: 82 | """Compute convolution module. 83 | Args: 84 | x (torch.Tensor): Input tensor (#batch, time, channels). 85 | mask_pad (torch.Tensor): used for batch padding (#batch, 1, time), 86 | (0, 0, 0) means fake mask. 
87 | cache (torch.Tensor): left context cache, it is only 88 | used in causal convolution (#batch, channels, cache_t), 89 | (0, 0, 0) meas fake cache. 90 | Returns: 91 | torch.Tensor: Output tensor (#batch, time, channels). 92 | """ 93 | # exchange the temporal dimension and the feature dimension 94 | x = x.transpose(1, 2) # (#batch, channels, time) 95 | 96 | # mask batch padding 97 | if mask_pad.size(2) > 0: # time > 0 98 | x.masked_fill_(~mask_pad, 0.0) 99 | 100 | if self.lorder > 0: 101 | if cache.size(2) == 0: # cache_t == 0 102 | x = nn.functional.pad(x, (self.lorder, 0), 'constant', 0.0) 103 | else: 104 | assert cache.size(0) == x.size(0) # equal batch 105 | assert cache.size(1) == x.size(1) # equal channel 106 | x = torch.cat((cache, x), dim=2) 107 | assert (x.size(2) > self.lorder) 108 | new_cache = x[:, :, -self.lorder:] 109 | else: 110 | # It's better we just return None if no cache is required, 111 | # However, for JIT export, here we just fake one tensor instead of 112 | # None. 113 | new_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device) 114 | 115 | # GLU mechanism 116 | x = self.pointwise_conv1(x) # (batch, 2*channel, dim) 117 | x = nn.functional.glu(x, dim=1) # (batch, channel, dim) 118 | 119 | # 1D Depthwise Conv 120 | x = self.depthwise_conv(x) 121 | if self.use_layer_norm: 122 | x = x.transpose(1, 2) 123 | x = self.activation(self.norm(x)) 124 | if self.use_layer_norm: 125 | x = x.transpose(1, 2) 126 | x = self.pointwise_conv2(x) 127 | # mask batch padding 128 | if mask_pad.size(2) > 0: # time > 0 129 | x.masked_fill_(~mask_pad, 0.0) 130 | 131 | return x.transpose(1, 2), new_cache 132 | -------------------------------------------------------------------------------- /viettts/transformer/decoder_layer.py: -------------------------------------------------------------------------------- 1 | """Decoder self-attention layer definition.""" 2 | from typing import Optional, Tuple 3 | 4 | import torch 5 | from torch import nn 6 | 7 | 8 | class DecoderLayer(nn.Module): 9 | """Single decoder layer module. 10 | 11 | Args: 12 | size (int): Input dimension. 13 | self_attn (torch.nn.Module): Self-attention module instance. 14 | `MultiHeadedAttention` instance can be used as the argument. 15 | src_attn (torch.nn.Module): Inter-attention module instance. 16 | `MultiHeadedAttention` instance can be used as the argument. 17 | If `None` is passed, Inter-attention is not used, such as 18 | CIF, GPT, and other decoder only model. 19 | feed_forward (torch.nn.Module): Feed-forward module instance. 20 | `PositionwiseFeedForward` instance can be used as the argument. 21 | dropout_rate (float): Dropout rate. 22 | normalize_before (bool): 23 | True: use layer_norm before each sub-block. 24 | False: to use layer_norm after each sub-block. 
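# Illustrative usage sketch for the causal ConvolutionModule from
# viettts/transformer/convolution.py above: it can be driven chunk-by-chunk,
# feeding the returned cache back in so each chunk sees kernel_size - 1 frames
# of left context. Chunk length and channel count here are arbitrary.
import torch
from viettts.transformer.convolution import ConvolutionModule

conv = ConvolutionModule(channels=64, kernel_size=15, causal=True).eval()
stream_cache = torch.zeros((0, 0, 0))               # empty fake cache for the first chunk
chunks = torch.randn(1, 40, 64).split(10, dim=1)    # four chunks of 10 frames each

with torch.no_grad():
    outputs = []
    for chunk in chunks:
        # mask_pad is left at its default fake value; the cache carries the
        # causal left context between consecutive chunks.
        y, stream_cache = conv(chunk, cache=stream_cache)
        outputs.append(y)
streamed = torch.cat(outputs, dim=1)                 # (1, 40, 64)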
25 | """ 26 | 27 | def __init__( 28 | self, 29 | size: int, 30 | self_attn: nn.Module, 31 | src_attn: Optional[nn.Module], 32 | feed_forward: nn.Module, 33 | dropout_rate: float, 34 | normalize_before: bool = True, 35 | ): 36 | """Construct an DecoderLayer object.""" 37 | super().__init__() 38 | self.size = size 39 | self.self_attn = self_attn 40 | self.src_attn = src_attn 41 | self.feed_forward = feed_forward 42 | self.norm1 = nn.LayerNorm(size, eps=1e-5) 43 | self.norm2 = nn.LayerNorm(size, eps=1e-5) 44 | self.norm3 = nn.LayerNorm(size, eps=1e-5) 45 | self.dropout = nn.Dropout(dropout_rate) 46 | self.normalize_before = normalize_before 47 | 48 | def forward( 49 | self, 50 | tgt: torch.Tensor, 51 | tgt_mask: torch.Tensor, 52 | memory: torch.Tensor, 53 | memory_mask: torch.Tensor, 54 | cache: Optional[torch.Tensor] = None 55 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 56 | """Compute decoded features. 57 | 58 | Args: 59 | tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size). 60 | tgt_mask (torch.Tensor): Mask for input tensor 61 | (#batch, maxlen_out). 62 | memory (torch.Tensor): Encoded memory 63 | (#batch, maxlen_in, size). 64 | memory_mask (torch.Tensor): Encoded memory mask 65 | (#batch, maxlen_in). 66 | cache (torch.Tensor): cached tensors. 67 | (#batch, maxlen_out - 1, size). 68 | 69 | Returns: 70 | torch.Tensor: Output tensor (#batch, maxlen_out, size). 71 | torch.Tensor: Mask for output tensor (#batch, maxlen_out). 72 | torch.Tensor: Encoded memory (#batch, maxlen_in, size). 73 | torch.Tensor: Encoded memory mask (#batch, maxlen_in). 74 | 75 | """ 76 | residual = tgt 77 | if self.normalize_before: 78 | tgt = self.norm1(tgt) 79 | 80 | if cache is None: 81 | tgt_q = tgt 82 | tgt_q_mask = tgt_mask 83 | else: 84 | # compute only the last frame query keeping dim: max_time_out -> 1 85 | assert cache.shape == ( 86 | tgt.shape[0], 87 | tgt.shape[1] - 1, 88 | self.size, 89 | ), "{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}" 90 | tgt_q = tgt[:, -1:, :] 91 | residual = residual[:, -1:, :] 92 | tgt_q_mask = tgt_mask[:, -1:, :] 93 | 94 | x = residual + self.dropout( 95 | self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)[0]) 96 | if not self.normalize_before: 97 | x = self.norm1(x) 98 | 99 | if self.src_attn is not None: 100 | residual = x 101 | if self.normalize_before: 102 | x = self.norm2(x) 103 | x = residual + self.dropout( 104 | self.src_attn(x, memory, memory, memory_mask)[0]) 105 | if not self.normalize_before: 106 | x = self.norm2(x) 107 | 108 | residual = x 109 | if self.normalize_before: 110 | x = self.norm3(x) 111 | x = residual + self.dropout(self.feed_forward(x)) 112 | if not self.normalize_before: 113 | x = self.norm3(x) 114 | 115 | if cache is not None: 116 | x = torch.cat([cache, x], dim=1) 117 | 118 | return x, tgt_mask, memory, memory_mask 119 | -------------------------------------------------------------------------------- /viettts/transformer/embedding.py: -------------------------------------------------------------------------------- 1 | # Modified from ESPnet(https://github.com/espnet/espnet) 2 | """Positonal Encoding Module.""" 3 | 4 | import math 5 | from typing import Tuple, Union 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | import numpy as np 10 | 11 | 12 | class PositionalEncoding(torch.nn.Module): 13 | """Positional encoding. 
14 | 15 | :param int d_model: embedding dim 16 | :param float dropout_rate: dropout rate 17 | :param int max_len: maximum input length 18 | 19 | PE(pos, 2i) = sin(pos/(10000^(2i/dmodel))) 20 | PE(pos, 2i+1) = cos(pos/(10000^(2i/dmodel))) 21 | """ 22 | 23 | def __init__(self, 24 | d_model: int, 25 | dropout_rate: float, 26 | max_len: int = 5000, 27 | reverse: bool = False): 28 | """Construct an PositionalEncoding object.""" 29 | super().__init__() 30 | self.d_model = d_model 31 | self.xscale = math.sqrt(self.d_model) 32 | self.dropout = torch.nn.Dropout(p=dropout_rate) 33 | self.max_len = max_len 34 | 35 | self.pe = torch.zeros(self.max_len, self.d_model) 36 | position = torch.arange(0, self.max_len, 37 | dtype=torch.float32).unsqueeze(1) 38 | div_term = torch.exp( 39 | torch.arange(0, self.d_model, 2, dtype=torch.float32) * 40 | -(math.log(10000.0) / self.d_model)) 41 | self.pe[:, 0::2] = torch.sin(position * div_term) 42 | self.pe[:, 1::2] = torch.cos(position * div_term) 43 | self.pe = self.pe.unsqueeze(0) 44 | 45 | def forward(self, 46 | x: torch.Tensor, 47 | offset: Union[int, torch.Tensor] = 0) \ 48 | -> Tuple[torch.Tensor, torch.Tensor]: 49 | """Add positional encoding. 50 | 51 | Args: 52 | x (torch.Tensor): Input. Its shape is (batch, time, ...) 53 | offset (int, torch.tensor): position offset 54 | 55 | Returns: 56 | torch.Tensor: Encoded tensor. Its shape is (batch, time, ...) 57 | torch.Tensor: for compatibility to RelPositionalEncoding 58 | """ 59 | 60 | self.pe = self.pe.to(x.device) 61 | pos_emb = self.position_encoding(offset, x.size(1), False) 62 | x = x * self.xscale + pos_emb 63 | return self.dropout(x), self.dropout(pos_emb) 64 | 65 | def position_encoding(self, 66 | offset: Union[int, torch.Tensor], 67 | size: int, 68 | apply_dropout: bool = True) -> torch.Tensor: 69 | """ For getting encoding in a streaming fashion 70 | 71 | Attention!!!!! 72 | we apply dropout only once at the whole utterance level in a none 73 | streaming way, but will call this function several times with 74 | increasing input size in a streaming scenario, so the dropout will 75 | be applied several times. 76 | 77 | Args: 78 | offset (int or torch.tensor): start offset 79 | size (int): required size of position encoding 80 | 81 | Returns: 82 | torch.Tensor: Corresponding encoding 83 | """ 84 | # How to subscript a Union type: 85 | # https://github.com/pytorch/pytorch/issues/69434 86 | if isinstance(offset, int): 87 | assert offset + size <= self.max_len 88 | pos_emb = self.pe[:, offset:offset + size] 89 | elif isinstance(offset, torch.Tensor) and offset.dim() == 0: # scalar 90 | assert offset + size <= self.max_len 91 | pos_emb = self.pe[:, offset:offset + size] 92 | else: # for batched streaming decoding on GPU 93 | assert torch.max(offset) + size <= self.max_len 94 | index = offset.unsqueeze(1) + \ 95 | torch.arange(0, size).to(offset.device) # B X T 96 | flag = index > 0 97 | # remove negative offset 98 | index = index * flag 99 | pos_emb = F.embedding(index, self.pe[0]) # B X T X d_model 100 | 101 | if apply_dropout: 102 | pos_emb = self.dropout(pos_emb) 103 | return pos_emb 104 | 105 | 106 | class RelPositionalEncoding(PositionalEncoding): 107 | """Relative positional encoding module. 108 | See : Appendix B in https://arxiv.org/abs/1901.02860 109 | Args: 110 | d_model (int): Embedding dimension. 111 | dropout_rate (float): Dropout rate. 112 | max_len (int): Maximum input length. 
113 | """ 114 | 115 | def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000): 116 | """Initialize class.""" 117 | super().__init__(d_model, dropout_rate, max_len, reverse=True) 118 | 119 | def forward(self, 120 | x: torch.Tensor, 121 | offset: Union[int, torch.Tensor] = 0) \ 122 | -> Tuple[torch.Tensor, torch.Tensor]: 123 | """Compute positional encoding. 124 | Args: 125 | x (torch.Tensor): Input tensor (batch, time, `*`). 126 | Returns: 127 | torch.Tensor: Encoded tensor (batch, time, `*`). 128 | torch.Tensor: Positional embedding tensor (1, time, `*`). 129 | """ 130 | self.pe = self.pe.to(x.device) 131 | x = x * self.xscale 132 | pos_emb = self.position_encoding(offset, x.size(1), False) 133 | return self.dropout(x), self.dropout(pos_emb) 134 | 135 | 136 | class WhisperPositionalEncoding(PositionalEncoding): 137 | """ Sinusoids position encoding used in openai-whisper.encoder 138 | """ 139 | 140 | def __init__(self, d_model: int, dropout_rate: float, max_len: int = 1500): 141 | super().__init__(d_model, dropout_rate, max_len) 142 | self.xscale = 1.0 143 | log_timescale_increment = np.log(10000) / (d_model // 2 - 1) 144 | inv_timescales = torch.exp(-log_timescale_increment * 145 | torch.arange(d_model // 2)) 146 | scaled_time = torch.arange(max_len)[:, np.newaxis] * \ 147 | inv_timescales[np.newaxis, :] 148 | pe = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1) 149 | delattr(self, "pe") 150 | self.register_buffer("pe", pe.unsqueeze(0)) 151 | 152 | 153 | class LearnablePositionalEncoding(PositionalEncoding): 154 | """ Learnable position encoding used in openai-whisper.decoder 155 | """ 156 | 157 | def __init__(self, d_model: int, dropout_rate: float, max_len: int = 448): 158 | super().__init__(d_model, dropout_rate, max_len) 159 | # NOTE(xcsong): overwrite self.pe & self.xscale 160 | self.pe = torch.nn.Parameter(torch.empty(1, max_len, d_model)) 161 | self.xscale = 1.0 162 | 163 | 164 | class NoPositionalEncoding(torch.nn.Module): 165 | """ No position encoding 166 | """ 167 | 168 | def __init__(self, d_model: int, dropout_rate: float): 169 | super().__init__() 170 | self.d_model = d_model 171 | self.dropout = torch.nn.Dropout(p=dropout_rate) 172 | 173 | def forward(self, 174 | x: torch.Tensor, 175 | offset: Union[int, torch.Tensor] = 0) \ 176 | -> Tuple[torch.Tensor, torch.Tensor]: 177 | """ Just return zero vector for interface compatibility 178 | """ 179 | pos_emb = torch.zeros(1, x.size(1), self.d_model).to(x.device) 180 | return self.dropout(x), pos_emb 181 | 182 | def position_encoding(self, offset: Union[int, torch.Tensor], 183 | size: int) -> torch.Tensor: 184 | return torch.zeros(1, size, self.d_model) 185 | 186 | 187 | class EspnetRelPositionalEncoding(torch.nn.Module): 188 | """Relative positional encoding module (new implementation). 189 | 190 | Details can be found in https://github.com/espnet/espnet/pull/2816. 191 | 192 | See : Appendix B in https://arxiv.org/abs/1901.02860 193 | 194 | Args: 195 | d_model (int): Embedding dimension. 196 | dropout_rate (float): Dropout rate. 197 | max_len (int): Maximum input length. 
198 | 199 | """ 200 | 201 | def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000): 202 | """Construct an PositionalEncoding object.""" 203 | super(EspnetRelPositionalEncoding, self).__init__() 204 | self.d_model = d_model 205 | self.xscale = math.sqrt(self.d_model) 206 | self.dropout = torch.nn.Dropout(p=dropout_rate) 207 | self.pe = None 208 | self.extend_pe(torch.tensor(0.0).expand(1, max_len)) 209 | 210 | def extend_pe(self, x: torch.Tensor): 211 | """Reset the positional encodings.""" 212 | if self.pe is not None: 213 | # self.pe contains both positive and negative parts 214 | # the length of self.pe is 2 * input_len - 1 215 | if self.pe.size(1) >= x.size(1) * 2 - 1: 216 | if self.pe.dtype != x.dtype or self.pe.device != x.device: 217 | self.pe = self.pe.to(dtype=x.dtype, device=x.device) 218 | return 219 | # Suppose `i` means to the position of query vecotr and `j` means the 220 | # position of key vector. We use position relative positions when keys 221 | # are to the left (i>j) and negative relative positions otherwise (i Tuple[torch.Tensor, torch.Tensor]: 244 | """Add positional encoding. 245 | 246 | Args: 247 | x (torch.Tensor): Input tensor (batch, time, `*`). 248 | 249 | Returns: 250 | torch.Tensor: Encoded tensor (batch, time, `*`). 251 | 252 | """ 253 | self.extend_pe(x) 254 | x = x * self.xscale 255 | pos_emb = self.position_encoding(size=x.size(1), offset=offset) 256 | return self.dropout(x), self.dropout(pos_emb) 257 | 258 | def position_encoding(self, 259 | offset: Union[int, torch.Tensor], 260 | size: int) -> torch.Tensor: 261 | """ For getting encoding in a streaming fashion 262 | 263 | Attention!!!!! 264 | we apply dropout only once at the whole utterance level in a none 265 | streaming way, but will call this function several times with 266 | increasing input size in a streaming scenario, so the dropout will 267 | be applied several times. 268 | 269 | Args: 270 | offset (int or torch.tensor): start offset 271 | size (int): required size of position encoding 272 | 273 | Returns: 274 | torch.Tensor: Corresponding encoding 275 | """ 276 | pos_emb = self.pe[ 277 | :, 278 | self.pe.size(1) // 2 - size + 1: self.pe.size(1) // 2 + size, 279 | ] 280 | return pos_emb 281 | -------------------------------------------------------------------------------- /viettts/transformer/encoder_layer.py: -------------------------------------------------------------------------------- 1 | # Modified from ESPnet(https://github.com/espnet/espnet) 2 | """Encoder self-attention layer definition.""" 3 | 4 | from typing import Optional, Tuple 5 | 6 | import torch 7 | from torch import nn 8 | 9 | 10 | class TransformerEncoderLayer(nn.Module): 11 | """Encoder layer module. 12 | 13 | Args: 14 | size (int): Input dimension. 15 | self_attn (torch.nn.Module): Self-attention module instance. 16 | `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` 17 | instance can be used as the argument. 18 | feed_forward (torch.nn.Module): Feed-forward module instance. 19 | `PositionwiseFeedForward`, instance can be used as the argument. 20 | dropout_rate (float): Dropout rate. 21 | normalize_before (bool): 22 | True: use layer_norm before each sub-block. 23 | False: to use layer_norm after each sub-block. 
24 | """ 25 | 26 | def __init__( 27 | self, 28 | size: int, 29 | self_attn: torch.nn.Module, 30 | feed_forward: torch.nn.Module, 31 | dropout_rate: float, 32 | normalize_before: bool = True, 33 | ): 34 | """Construct an EncoderLayer object.""" 35 | super().__init__() 36 | self.self_attn = self_attn 37 | self.feed_forward = feed_forward 38 | self.norm1 = nn.LayerNorm(size, eps=1e-5) 39 | self.norm2 = nn.LayerNorm(size, eps=1e-5) 40 | self.dropout = nn.Dropout(dropout_rate) 41 | self.size = size 42 | self.normalize_before = normalize_before 43 | 44 | def forward( 45 | self, 46 | x: torch.Tensor, 47 | mask: torch.Tensor, 48 | pos_emb: torch.Tensor, 49 | mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), 50 | att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)), 51 | cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)), 52 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 53 | """Compute encoded features. 54 | 55 | Args: 56 | x (torch.Tensor): (#batch, time, size) 57 | mask (torch.Tensor): Mask tensor for the input (#batch, time,time), 58 | (0, 0, 0) means fake mask. 59 | pos_emb (torch.Tensor): just for interface compatibility 60 | to ConformerEncoderLayer 61 | mask_pad (torch.Tensor): does not used in transformer layer, 62 | just for unified api with conformer. 63 | att_cache (torch.Tensor): Cache tensor of the KEY & VALUE 64 | (#batch=1, head, cache_t1, d_k * 2), head * d_k == size. 65 | cnn_cache (torch.Tensor): Convolution cache in conformer layer 66 | (#batch=1, size, cache_t2), not used here, it's for interface 67 | compatibility to ConformerEncoderLayer. 68 | Returns: 69 | torch.Tensor: Output tensor (#batch, time, size). 70 | torch.Tensor: Mask tensor (#batch, time, time). 71 | torch.Tensor: att_cache tensor, 72 | (#batch=1, head, cache_t1 + time, d_k * 2). 73 | torch.Tensor: cnn_cahce tensor (#batch=1, size, cache_t2). 74 | 75 | """ 76 | residual = x 77 | if self.normalize_before: 78 | x = self.norm1(x) 79 | x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb=pos_emb, cache=att_cache) 80 | x = residual + self.dropout(x_att) 81 | if not self.normalize_before: 82 | x = self.norm1(x) 83 | 84 | residual = x 85 | if self.normalize_before: 86 | x = self.norm2(x) 87 | x = residual + self.dropout(self.feed_forward(x)) 88 | if not self.normalize_before: 89 | x = self.norm2(x) 90 | 91 | fake_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device) 92 | return x, mask, new_att_cache, fake_cnn_cache 93 | 94 | 95 | class ConformerEncoderLayer(nn.Module): 96 | """Encoder layer module. 97 | Args: 98 | size (int): Input dimension. 99 | self_attn (torch.nn.Module): Self-attention module instance. 100 | `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` 101 | instance can be used as the argument. 102 | feed_forward (torch.nn.Module): Feed-forward module instance. 103 | `PositionwiseFeedForward` instance can be used as the argument. 104 | feed_forward_macaron (torch.nn.Module): Additional feed-forward module 105 | instance. 106 | `PositionwiseFeedForward` instance can be used as the argument. 107 | conv_module (torch.nn.Module): Convolution module instance. 108 | `ConvlutionModule` instance can be used as the argument. 109 | dropout_rate (float): Dropout rate. 110 | normalize_before (bool): 111 | True: use layer_norm before each sub-block. 112 | False: use layer_norm after each sub-block. 
113 | """ 114 | 115 | def __init__( 116 | self, 117 | size: int, 118 | self_attn: torch.nn.Module, 119 | feed_forward: Optional[nn.Module] = None, 120 | feed_forward_macaron: Optional[nn.Module] = None, 121 | conv_module: Optional[nn.Module] = None, 122 | dropout_rate: float = 0.1, 123 | normalize_before: bool = True, 124 | ): 125 | """Construct an EncoderLayer object.""" 126 | super().__init__() 127 | self.self_attn = self_attn 128 | self.feed_forward = feed_forward 129 | self.feed_forward_macaron = feed_forward_macaron 130 | self.conv_module = conv_module 131 | self.norm_ff = nn.LayerNorm(size, eps=1e-5) # for the FNN module 132 | self.norm_mha = nn.LayerNorm(size, eps=1e-5) # for the MHA module 133 | if feed_forward_macaron is not None: 134 | self.norm_ff_macaron = nn.LayerNorm(size, eps=1e-5) 135 | self.ff_scale = 0.5 136 | else: 137 | self.ff_scale = 1.0 138 | if self.conv_module is not None: 139 | self.norm_conv = nn.LayerNorm(size, eps=1e-5) # for the CNN module 140 | self.norm_final = nn.LayerNorm( 141 | size, eps=1e-5) # for the final output of the block 142 | self.dropout = nn.Dropout(dropout_rate) 143 | self.size = size 144 | self.normalize_before = normalize_before 145 | 146 | def forward( 147 | self, 148 | x: torch.Tensor, 149 | mask: torch.Tensor, 150 | pos_emb: torch.Tensor, 151 | mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), 152 | att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)), 153 | cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)), 154 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 155 | """Compute encoded features. 156 | 157 | Args: 158 | x (torch.Tensor): (#batch, time, size) 159 | mask (torch.Tensor): Mask tensor for the input (#batch, time,time), 160 | (0, 0, 0) means fake mask. 161 | pos_emb (torch.Tensor): positional encoding, must not be None 162 | for ConformerEncoderLayer. 163 | mask_pad (torch.Tensor): batch padding mask used for conv module. 164 | (#batch, 1,time), (0, 0, 0) means fake mask. 165 | att_cache (torch.Tensor): Cache tensor of the KEY & VALUE 166 | (#batch=1, head, cache_t1, d_k * 2), head * d_k == size. 167 | cnn_cache (torch.Tensor): Convolution cache in conformer layer 168 | (#batch=1, size, cache_t2) 169 | Returns: 170 | torch.Tensor: Output tensor (#batch, time, size). 171 | torch.Tensor: Mask tensor (#batch, time, time). 172 | torch.Tensor: att_cache tensor, 173 | (#batch=1, head, cache_t1 + time, d_k * 2). 174 | torch.Tensor: cnn_cahce tensor (#batch, size, cache_t2). 
175 | """ 176 | 177 | # whether to use macaron style 178 | if self.feed_forward_macaron is not None: 179 | residual = x 180 | if self.normalize_before: 181 | x = self.norm_ff_macaron(x) 182 | x = residual + self.ff_scale * self.dropout( 183 | self.feed_forward_macaron(x)) 184 | if not self.normalize_before: 185 | x = self.norm_ff_macaron(x) 186 | 187 | # multi-headed self-attention module 188 | residual = x 189 | if self.normalize_before: 190 | x = self.norm_mha(x) 191 | x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb, 192 | att_cache) 193 | x = residual + self.dropout(x_att) 194 | if not self.normalize_before: 195 | x = self.norm_mha(x) 196 | 197 | # convolution module 198 | # Fake new cnn cache here, and then change it in conv_module 199 | new_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device) 200 | if self.conv_module is not None: 201 | residual = x 202 | if self.normalize_before: 203 | x = self.norm_conv(x) 204 | x, new_cnn_cache = self.conv_module(x, mask_pad, cnn_cache) 205 | x = residual + self.dropout(x) 206 | 207 | if not self.normalize_before: 208 | x = self.norm_conv(x) 209 | 210 | # feed forward module 211 | residual = x 212 | if self.normalize_before: 213 | x = self.norm_ff(x) 214 | 215 | x = residual + self.ff_scale * self.dropout(self.feed_forward(x)) 216 | if not self.normalize_before: 217 | x = self.norm_ff(x) 218 | 219 | if self.conv_module is not None: 220 | x = self.norm_final(x) 221 | 222 | return x, mask, new_att_cache, new_cnn_cache 223 | -------------------------------------------------------------------------------- /viettts/transformer/label_smoothing_loss.py: -------------------------------------------------------------------------------- 1 | """Label smoothing module.""" 2 | 3 | import torch 4 | from torch import nn 5 | 6 | 7 | class LabelSmoothingLoss(nn.Module): 8 | """Label-smoothing loss. 9 | 10 | In a standard CE loss, the label's data distribution is: 11 | [0,1,2] -> 12 | [ 13 | [1.0, 0.0, 0.0], 14 | [0.0, 1.0, 0.0], 15 | [0.0, 0.0, 1.0], 16 | ] 17 | 18 | In the smoothing version CE Loss,some probabilities 19 | are taken from the true label prob (1.0) and are divided 20 | among other labels. 21 | 22 | e.g. 23 | smoothing=0.1 24 | [0,1,2] -> 25 | [ 26 | [0.9, 0.05, 0.05], 27 | [0.05, 0.9, 0.05], 28 | [0.05, 0.05, 0.9], 29 | ] 30 | 31 | Args: 32 | size (int): the number of class 33 | padding_idx (int): padding class id which will be ignored for loss 34 | smoothing (float): smoothing rate (0.0 means the conventional CE) 35 | normalize_length (bool): 36 | normalize loss by sequence length if True 37 | normalize loss by batch size if False 38 | """ 39 | 40 | def __init__(self, 41 | size: int, 42 | padding_idx: int, 43 | smoothing: float, 44 | normalize_length: bool = False): 45 | """Construct an LabelSmoothingLoss object.""" 46 | super(LabelSmoothingLoss, self).__init__() 47 | self.criterion = nn.KLDivLoss(reduction="none") 48 | self.padding_idx = padding_idx 49 | self.confidence = 1.0 - smoothing 50 | self.smoothing = smoothing 51 | self.size = size 52 | self.normalize_length = normalize_length 53 | 54 | def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor: 55 | """Compute loss between x and target. 56 | 57 | The model outputs and data labels tensors are flatten to 58 | (batch*seqlen, class) shape and a mask is applied to the 59 | padding part which should not be calculated for loss. 
60 | 61 | Args: 62 | x (torch.Tensor): prediction (batch, seqlen, class) 63 | target (torch.Tensor): 64 | target signal masked with self.padding_id (batch, seqlen) 65 | Returns: 66 | loss (torch.Tensor) : The KL loss, scalar float value 67 | """ 68 | assert x.size(2) == self.size 69 | batch_size = x.size(0) 70 | x = x.view(-1, self.size) 71 | target = target.view(-1) 72 | # use zeros_like instead of torch.no_grad() for true_dist, 73 | # since no_grad() can not be exported by JIT 74 | true_dist = torch.zeros_like(x) 75 | true_dist.fill_(self.smoothing / (self.size - 1)) 76 | ignore = target == self.padding_idx # (B,) 77 | total = len(target) - ignore.sum().item() 78 | target = target.masked_fill(ignore, 0) # avoid -1 index 79 | true_dist.scatter_(1, target.unsqueeze(1), self.confidence) 80 | kl = self.criterion(torch.log_softmax(x, dim=1), true_dist) 81 | denom = total if self.normalize_length else batch_size 82 | return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom 83 | -------------------------------------------------------------------------------- /viettts/transformer/positionwise_feed_forward.py: -------------------------------------------------------------------------------- 1 | """Positionwise feed forward layer definition.""" 2 | 3 | import torch 4 | 5 | 6 | class PositionwiseFeedForward(torch.nn.Module): 7 | """Positionwise feed forward layer. 8 | 9 | FeedForward are appied on each position of the sequence. 10 | The output dim is same with the input dim. 11 | 12 | Args: 13 | idim (int): Input dimenstion. 14 | hidden_units (int): The number of hidden units. 15 | dropout_rate (float): Dropout rate. 16 | activation (torch.nn.Module): Activation function 17 | """ 18 | 19 | def __init__( 20 | self, 21 | idim: int, 22 | hidden_units: int, 23 | dropout_rate: float, 24 | activation: torch.nn.Module = torch.nn.ReLU(), 25 | ): 26 | """Construct a PositionwiseFeedForward object.""" 27 | super(PositionwiseFeedForward, self).__init__() 28 | self.w_1 = torch.nn.Linear(idim, hidden_units) 29 | self.activation = activation 30 | self.dropout = torch.nn.Dropout(dropout_rate) 31 | self.w_2 = torch.nn.Linear(hidden_units, idim) 32 | 33 | def forward(self, xs: torch.Tensor) -> torch.Tensor: 34 | """Forward function. 35 | 36 | Args: 37 | xs: input tensor (B, L, D) 38 | Returns: 39 | output tensor, (B, L, D) 40 | """ 41 | return self.w_2(self.dropout(self.activation(self.w_1(xs)))) 42 | 43 | 44 | class MoEFFNLayer(torch.nn.Module): 45 | """ 46 | Mixture of expert with Positionwise feed forward layer 47 | See also figure 1 in https://arxiv.org/pdf/2305.15663.pdf 48 | The output dim is same with the input dim. 49 | 50 | Modified from https://github.com/Lightning-AI/lit-gpt/pull/823 51 | https://github.com/mistralai/mistral-src/blob/b46d6/moe_one_file_ref.py#L203-L219 52 | Args: 53 | n_expert: number of expert. 54 | n_expert_per_token: The actual number of experts used for each frame 55 | idim (int): Input dimenstion. 56 | hidden_units (int): The number of hidden units. 57 | dropout_rate (float): Dropout rate. 
58 | activation (torch.nn.Module): Activation function 59 | """ 60 | 61 | def __init__( 62 | self, 63 | n_expert: int, 64 | n_expert_per_token: int, 65 | idim: int, 66 | hidden_units: int, 67 | dropout_rate: float, 68 | activation: torch.nn.Module = torch.nn.ReLU(), 69 | ): 70 | super(MoEFFNLayer, self).__init__() 71 | self.gate = torch.nn.Linear(idim, n_expert, bias=False) 72 | self.experts = torch.nn.ModuleList( 73 | PositionwiseFeedForward(idim, hidden_units, dropout_rate, 74 | activation) for _ in range(n_expert)) 75 | self.n_expert_per_token = n_expert_per_token 76 | 77 | def forward(self, xs: torch.Tensor) -> torch.Tensor: 78 | """Foward function. 79 | Args: 80 | xs: input tensor (B, L, D) 81 | Returns: 82 | output tensor, (B, L, D) 83 | 84 | """ 85 | B, L, D = xs.size( 86 | ) # batch size, sequence length, embedding dimension (idim) 87 | xs = xs.view(-1, D) # (B*L, D) 88 | router = self.gate(xs) # (B*L, n_expert) 89 | logits, indices = torch.topk( 90 | router, self.n_expert_per_token 91 | ) # probs:(B*L, n_expert), indices: (B*L, n_expert) 92 | weights = torch.nn.functional.softmax( 93 | logits, dim=1, 94 | dtype=torch.float).to(dtype=xs.dtype) # (B*L, n_expert_per_token) 95 | output = torch.zeros_like(xs) # (B*L, D) 96 | for i, expert in enumerate(self.experts): 97 | mask = indices == i 98 | batch_idx, ith_expert = torch.where(mask) 99 | output[batch_idx] += weights[batch_idx, ith_expert, None] * expert( 100 | xs[batch_idx]) 101 | return output.view(B, L, D) 102 | -------------------------------------------------------------------------------- /viettts/transformer/subsampling.py: -------------------------------------------------------------------------------- 1 | # Modified from ESPnet(https://github.com/espnet/espnet) 2 | """Subsampling layer definition.""" 3 | 4 | from typing import Tuple, Union 5 | 6 | import torch 7 | 8 | 9 | class BaseSubsampling(torch.nn.Module): 10 | 11 | def __init__(self): 12 | super().__init__() 13 | self.right_context = 0 14 | self.subsampling_rate = 1 15 | 16 | def position_encoding(self, offset: Union[int, torch.Tensor], 17 | size: int) -> torch.Tensor: 18 | return self.pos_enc.position_encoding(offset, size) 19 | 20 | 21 | class EmbedinigNoSubsampling(BaseSubsampling): 22 | """Embedding input without subsampling 23 | """ 24 | 25 | def __init__(self, idim: int, odim: int, dropout_rate: float, 26 | pos_enc_class: torch.nn.Module): 27 | super().__init__() 28 | self.embed = torch.nn.Embedding(idim, odim) 29 | self.pos_enc = pos_enc_class 30 | 31 | def forward( 32 | self, 33 | x: torch.Tensor, 34 | x_mask: torch.Tensor, 35 | offset: Union[int, torch.Tensor] = 0 36 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 37 | """Input x. 38 | 39 | Args: 40 | x (torch.Tensor): Input tensor (#batch, time, idim). 41 | x_mask (torch.Tensor): Input mask (#batch, 1, time). 42 | 43 | Returns: 44 | torch.Tensor: linear input tensor (#batch, time', odim), 45 | where time' = time . 46 | torch.Tensor: linear input mask (#batch, 1, time'), 47 | where time' = time . 48 | 49 | """ 50 | x = self.embed(x) 51 | x, pos_emb = self.pos_enc(x, offset) 52 | return x, pos_emb, x_mask 53 | 54 | 55 | class LinearNoSubsampling(BaseSubsampling): 56 | """Linear transform the input without subsampling 57 | 58 | Args: 59 | idim (int): Input dimension. 60 | odim (int): Output dimension. 61 | dropout_rate (float): Dropout rate. 
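# Illustrative usage sketch: routing a batch through the MoEFFNLayer defined in
# viettts/transformer/positionwise_feed_forward.py above (expert counts and
# dimensions are arbitrary).
import torch
from viettts.transformer.positionwise_feed_forward import MoEFFNLayer

moe = MoEFFNLayer(
    n_expert=4,
    n_expert_per_token=2,
    idim=64,
    hidden_units=256,
    dropout_rate=0.1,
)
xs = torch.randn(3, 20, 64)   # (B, L, D)
ys = moe(xs)                  # (B, L, D); each frame is a weighted sum of its top-2 experts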
62 | 63 | """ 64 | 65 | def __init__(self, idim: int, odim: int, dropout_rate: float, 66 | pos_enc_class: torch.nn.Module): 67 | """Construct an linear object.""" 68 | super().__init__() 69 | self.out = torch.nn.Sequential( 70 | torch.nn.Linear(idim, odim), 71 | torch.nn.LayerNorm(odim, eps=1e-5), 72 | torch.nn.Dropout(dropout_rate), 73 | ) 74 | self.pos_enc = pos_enc_class 75 | self.right_context = 0 76 | self.subsampling_rate = 1 77 | 78 | def forward( 79 | self, 80 | x: torch.Tensor, 81 | x_mask: torch.Tensor, 82 | offset: Union[int, torch.Tensor] = 0 83 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 84 | """Input x. 85 | 86 | Args: 87 | x (torch.Tensor): Input tensor (#batch, time, idim). 88 | x_mask (torch.Tensor): Input mask (#batch, 1, time). 89 | 90 | Returns: 91 | torch.Tensor: linear input tensor (#batch, time', odim), 92 | where time' = time . 93 | torch.Tensor: linear input mask (#batch, 1, time'), 94 | where time' = time . 95 | 96 | """ 97 | x = self.out(x) 98 | x, pos_emb = self.pos_enc(x, offset) 99 | return x, pos_emb, x_mask 100 | 101 | 102 | class Conv1dSubsampling2(BaseSubsampling): 103 | """Convolutional 1D subsampling (to 1/2 length). 104 | It is designed for Whisper, ref: 105 | https://github.com/openai/whisper/blob/main/whisper/model.py 106 | 107 | Args: 108 | idim (int): Input dimension. 109 | odim (int): Output dimension. 110 | dropout_rate (float): Dropout rate. 111 | 112 | """ 113 | 114 | def __init__(self, idim: int, odim: int, dropout_rate: float, 115 | pos_enc_class: torch.nn.Module): 116 | """Construct an Conv1dSubsampling2 object.""" 117 | super().__init__() 118 | self.conv = torch.nn.Sequential( 119 | torch.nn.Conv1d(idim, odim, kernel_size=3, padding=1), 120 | torch.nn.GELU(), 121 | torch.nn.Conv1d(odim, odim, kernel_size=3, stride=2, padding=1), 122 | torch.nn.GELU(), 123 | ) 124 | self.pos_enc = pos_enc_class 125 | # The right context for every conv layer is computed by: 126 | # (kernel_size - 1) * frame_rate_of_this_layer 127 | self.subsampling_rate = 2 128 | # 4 = (3 - 1) * 1 + (3 - 1) * 1 129 | self.right_context = 4 130 | 131 | def forward( 132 | self, 133 | x: torch.Tensor, 134 | x_mask: torch.Tensor, 135 | offset: Union[int, torch.Tensor] = 0 136 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 137 | """Subsample x. 138 | 139 | Args: 140 | x (torch.Tensor): Input tensor (#batch, time, idim). 141 | x_mask (torch.Tensor): Input mask (#batch, 1, time). 142 | 143 | Returns: 144 | torch.Tensor: Subsampled tensor (#batch, time', odim), 145 | where time' = time // 2. 146 | torch.Tensor: Subsampled mask (#batch, 1, time'), 147 | where time' = time // 2. 148 | torch.Tensor: positional encoding 149 | 150 | """ 151 | time = x.size(1) 152 | x = x.transpose(1, 2) # (b, f, t) 153 | x = self.conv(x) 154 | x = x.transpose(1, 2) # (b, t, f) 155 | x, pos_emb = self.pos_enc(x, offset) 156 | return x, pos_emb, x_mask[:, :, (time + 1) % 2::2] 157 | 158 | 159 | class Conv2dSubsampling4(BaseSubsampling): 160 | """Convolutional 2D subsampling (to 1/4 length). 161 | 162 | Args: 163 | idim (int): Input dimension. 164 | odim (int): Output dimension. 165 | dropout_rate (float): Dropout rate. 
166 | 167 | """ 168 | 169 | def __init__(self, idim: int, odim: int, dropout_rate: float, 170 | pos_enc_class: torch.nn.Module): 171 | """Construct an Conv2dSubsampling4 object.""" 172 | super().__init__() 173 | self.conv = torch.nn.Sequential( 174 | torch.nn.Conv2d(1, odim, 3, 2), 175 | torch.nn.ReLU(), 176 | torch.nn.Conv2d(odim, odim, 3, 2), 177 | torch.nn.ReLU(), 178 | ) 179 | self.out = torch.nn.Sequential( 180 | torch.nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim)) 181 | self.pos_enc = pos_enc_class 182 | # The right context for every conv layer is computed by: 183 | # (kernel_size - 1) * frame_rate_of_this_layer 184 | self.subsampling_rate = 4 185 | # 6 = (3 - 1) * 1 + (3 - 1) * 2 186 | self.right_context = 6 187 | 188 | def forward( 189 | self, 190 | x: torch.Tensor, 191 | x_mask: torch.Tensor, 192 | offset: Union[int, torch.Tensor] = 0 193 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 194 | """Subsample x. 195 | 196 | Args: 197 | x (torch.Tensor): Input tensor (#batch, time, idim). 198 | x_mask (torch.Tensor): Input mask (#batch, 1, time). 199 | 200 | Returns: 201 | torch.Tensor: Subsampled tensor (#batch, time', odim), 202 | where time' = time // 4. 203 | torch.Tensor: Subsampled mask (#batch, 1, time'), 204 | where time' = time // 4. 205 | torch.Tensor: positional encoding 206 | 207 | """ 208 | x = x.unsqueeze(1) # (b, c=1, t, f) 209 | x = self.conv(x) 210 | b, c, t, f = x.size() 211 | x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) 212 | x, pos_emb = self.pos_enc(x, offset) 213 | return x, pos_emb, x_mask[:, :, 2::2][:, :, 2::2] 214 | 215 | 216 | class Conv2dSubsampling6(BaseSubsampling): 217 | """Convolutional 2D subsampling (to 1/6 length). 218 | Args: 219 | idim (int): Input dimension. 220 | odim (int): Output dimension. 221 | dropout_rate (float): Dropout rate. 222 | pos_enc (torch.nn.Module): Custom position encoding layer. 223 | """ 224 | 225 | def __init__(self, idim: int, odim: int, dropout_rate: float, 226 | pos_enc_class: torch.nn.Module): 227 | """Construct an Conv2dSubsampling6 object.""" 228 | super().__init__() 229 | self.conv = torch.nn.Sequential( 230 | torch.nn.Conv2d(1, odim, 3, 2), 231 | torch.nn.ReLU(), 232 | torch.nn.Conv2d(odim, odim, 5, 3), 233 | torch.nn.ReLU(), 234 | ) 235 | self.linear = torch.nn.Linear(odim * (((idim - 1) // 2 - 2) // 3), 236 | odim) 237 | self.pos_enc = pos_enc_class 238 | # 10 = (3 - 1) * 1 + (5 - 1) * 2 239 | self.subsampling_rate = 6 240 | self.right_context = 10 241 | 242 | def forward( 243 | self, 244 | x: torch.Tensor, 245 | x_mask: torch.Tensor, 246 | offset: Union[int, torch.Tensor] = 0 247 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 248 | """Subsample x. 249 | Args: 250 | x (torch.Tensor): Input tensor (#batch, time, idim). 251 | x_mask (torch.Tensor): Input mask (#batch, 1, time). 252 | 253 | Returns: 254 | torch.Tensor: Subsampled tensor (#batch, time', odim), 255 | where time' = time // 6. 256 | torch.Tensor: Subsampled mask (#batch, 1, time'), 257 | where time' = time // 6. 258 | torch.Tensor: positional encoding 259 | """ 260 | x = x.unsqueeze(1) # (b, c, t, f) 261 | x = self.conv(x) 262 | b, c, t, f = x.size() 263 | x = self.linear(x.transpose(1, 2).contiguous().view(b, t, c * f)) 264 | x, pos_emb = self.pos_enc(x, offset) 265 | return x, pos_emb, x_mask[:, :, 2::2][:, :, 4::3] 266 | 267 | 268 | class Conv2dSubsampling8(BaseSubsampling): 269 | """Convolutional 2D subsampling (to 1/8 length). 270 | 271 | Args: 272 | idim (int): Input dimension. 
273 | odim (int): Output dimension. 274 | dropout_rate (float): Dropout rate. 275 | 276 | """ 277 | 278 | def __init__(self, idim: int, odim: int, dropout_rate: float, 279 | pos_enc_class: torch.nn.Module): 280 | """Construct an Conv2dSubsampling8 object.""" 281 | super().__init__() 282 | self.conv = torch.nn.Sequential( 283 | torch.nn.Conv2d(1, odim, 3, 2), 284 | torch.nn.ReLU(), 285 | torch.nn.Conv2d(odim, odim, 3, 2), 286 | torch.nn.ReLU(), 287 | torch.nn.Conv2d(odim, odim, 3, 2), 288 | torch.nn.ReLU(), 289 | ) 290 | self.linear = torch.nn.Linear( 291 | odim * ((((idim - 1) // 2 - 1) // 2 - 1) // 2), odim) 292 | self.pos_enc = pos_enc_class 293 | self.subsampling_rate = 8 294 | # 14 = (3 - 1) * 1 + (3 - 1) * 2 + (3 - 1) * 4 295 | self.right_context = 14 296 | 297 | def forward( 298 | self, 299 | x: torch.Tensor, 300 | x_mask: torch.Tensor, 301 | offset: Union[int, torch.Tensor] = 0 302 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 303 | """Subsample x. 304 | 305 | Args: 306 | x (torch.Tensor): Input tensor (#batch, time, idim). 307 | x_mask (torch.Tensor): Input mask (#batch, 1, time). 308 | 309 | Returns: 310 | torch.Tensor: Subsampled tensor (#batch, time', odim), 311 | where time' = time // 8. 312 | torch.Tensor: Subsampled mask (#batch, 1, time'), 313 | where time' = time // 8. 314 | torch.Tensor: positional encoding 315 | """ 316 | x = x.unsqueeze(1) # (b, c, t, f) 317 | x = self.conv(x) 318 | b, c, t, f = x.size() 319 | x = self.linear(x.transpose(1, 2).contiguous().view(b, t, c * f)) 320 | x, pos_emb = self.pos_enc(x, offset) 321 | return x, pos_emb, x_mask[:, :, 2::2][:, :, 2::2][:, :, 2::2] 322 | 323 | 324 | class LegacyLinearNoSubsampling(BaseSubsampling): 325 | """Linear transform the input without subsampling 326 | 327 | Args: 328 | idim (int): Input dimension. 329 | odim (int): Output dimension. 330 | dropout_rate (float): Dropout rate. 331 | 332 | """ 333 | 334 | def __init__(self, idim: int, odim: int, dropout_rate: float, 335 | pos_enc_class: torch.nn.Module): 336 | """Construct an linear object.""" 337 | super().__init__() 338 | self.out = torch.nn.Sequential( 339 | torch.nn.Linear(idim, odim), 340 | torch.nn.LayerNorm(odim, eps=1e-5), 341 | torch.nn.Dropout(dropout_rate), 342 | torch.nn.ReLU(), 343 | ) 344 | self.pos_enc = pos_enc_class 345 | self.right_context = 0 346 | self.subsampling_rate = 1 347 | 348 | def forward( 349 | self, 350 | x: torch.Tensor, 351 | x_mask: torch.Tensor, 352 | offset: Union[int, torch.Tensor] = 0 353 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 354 | """Input x. 355 | 356 | Args: 357 | x (torch.Tensor): Input tensor (#batch, time, idim). 358 | x_mask (torch.Tensor): Input mask (#batch, 1, time). 359 | 360 | Returns: 361 | torch.Tensor: linear input tensor (#batch, time', odim), 362 | where time' = time . 363 | torch.Tensor: linear input mask (#batch, 1, time'), 364 | where time' = time . 
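# Illustrative usage sketch: the 4x temporal reduction performed by
# Conv2dSubsampling4 from viettts/transformer/subsampling.py above (feature
# sizes are arbitrary).
import torch
from viettts.transformer.embedding import PositionalEncoding
from viettts.transformer.subsampling import Conv2dSubsampling4

subsample = Conv2dSubsampling4(
    idim=80,
    odim=256,
    dropout_rate=0.1,
    pos_enc_class=PositionalEncoding(256, dropout_rate=0.1),
)
x = torch.randn(2, 100, 80)                        # (batch, time, idim)
x_mask = torch.ones(2, 1, 100, dtype=torch.bool)   # (batch, 1, time)
y, pos_emb, y_mask = subsample(x, x_mask)          # y: (2, 24, 256), y_mask: (2, 1, 24)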
365 | 366 | """ 367 | x = self.out(x) 368 | x, pos_emb = self.pos_enc(x, offset) 369 | return x, pos_emb, x_mask 370 | -------------------------------------------------------------------------------- /viettts/transformer/transformer.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Optional 2 | 3 | import torch 4 | import torch.nn as nn 5 | from diffusers.models.attention import ( 6 | GEGLU, 7 | GELU, 8 | AdaLayerNorm, 9 | AdaLayerNormZero, 10 | ApproximateGELU, 11 | ) 12 | from diffusers.models.attention_processor import Attention 13 | from diffusers.models.lora import LoRACompatibleLinear 14 | from diffusers.utils.torch_utils import maybe_allow_in_graph 15 | 16 | 17 | class SnakeBeta(nn.Module): 18 | """ 19 | A modified Snake function which uses separate parameters for the magnitude of the periodic components 20 | Shape: 21 | - Input: (B, C, T) 22 | - Output: (B, C, T), same shape as the input 23 | Parameters: 24 | - alpha - trainable parameter that controls frequency 25 | - beta - trainable parameter that controls magnitude 26 | References: 27 | - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda: 28 | https://arxiv.org/abs/2006.08195 29 | Examples: 30 | >>> a1 = snakebeta(256) 31 | >>> x = torch.randn(256) 32 | >>> x = a1(x) 33 | """ 34 | 35 | def __init__(self, in_features, out_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True): 36 | """ 37 | Initialization. 38 | INPUT: 39 | - in_features: shape of the input 40 | - alpha - trainable parameter that controls frequency 41 | - beta - trainable parameter that controls magnitude 42 | alpha is initialized to 1 by default, higher values = higher-frequency. 43 | beta is initialized to 1 by default, higher values = higher-magnitude. 44 | alpha will be trained along with the rest of your model. 45 | """ 46 | super().__init__() 47 | self.in_features = out_features if isinstance(out_features, list) else [out_features] 48 | self.proj = LoRACompatibleLinear(in_features, out_features) 49 | 50 | # initialize alpha 51 | self.alpha_logscale = alpha_logscale 52 | if self.alpha_logscale: # log scale alphas initialized to zeros 53 | self.alpha = nn.Parameter(torch.zeros(self.in_features) * alpha) 54 | self.beta = nn.Parameter(torch.zeros(self.in_features) * alpha) 55 | else: # linear scale alphas initialized to ones 56 | self.alpha = nn.Parameter(torch.ones(self.in_features) * alpha) 57 | self.beta = nn.Parameter(torch.ones(self.in_features) * alpha) 58 | 59 | self.alpha.requires_grad = alpha_trainable 60 | self.beta.requires_grad = alpha_trainable 61 | 62 | self.no_div_by_zero = 0.000000001 63 | 64 | def forward(self, x): 65 | """ 66 | Forward pass of the function. 67 | Applies the function to the input elementwise. 68 | SnakeBeta ∶= x + 1/b * sin^2 (xa) 69 | """ 70 | x = self.proj(x) 71 | if self.alpha_logscale: 72 | alpha = torch.exp(self.alpha) 73 | beta = torch.exp(self.beta) 74 | else: 75 | alpha = self.alpha 76 | beta = self.beta 77 | 78 | x = x + (1.0 / (beta + self.no_div_by_zero)) * torch.pow(torch.sin(x * alpha), 2) 79 | 80 | return x 81 | 82 | 83 | class FeedForward(nn.Module): 84 | r""" 85 | A feed-forward layer. 86 | 87 | Parameters: 88 | dim (`int`): The number of channels in the input. 89 | dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. 90 | mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. 
91 | dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. 92 | activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. 93 | final_dropout (`bool` *optional*, defaults to False): Apply a final dropout. 94 | """ 95 | 96 | def __init__( 97 | self, 98 | dim: int, 99 | dim_out: Optional[int] = None, 100 | mult: int = 4, 101 | dropout: float = 0.0, 102 | activation_fn: str = "geglu", 103 | final_dropout: bool = False, 104 | ): 105 | super().__init__() 106 | inner_dim = int(dim * mult) 107 | dim_out = dim_out if dim_out is not None else dim 108 | 109 | if activation_fn == "gelu": 110 | act_fn = GELU(dim, inner_dim) 111 | if activation_fn == "gelu-approximate": 112 | act_fn = GELU(dim, inner_dim, approximate="tanh") 113 | elif activation_fn == "geglu": 114 | act_fn = GEGLU(dim, inner_dim) 115 | elif activation_fn == "geglu-approximate": 116 | act_fn = ApproximateGELU(dim, inner_dim) 117 | elif activation_fn == "snakebeta": 118 | act_fn = SnakeBeta(dim, inner_dim) 119 | 120 | self.net = nn.ModuleList([]) 121 | # project in 122 | self.net.append(act_fn) 123 | # project dropout 124 | self.net.append(nn.Dropout(dropout)) 125 | # project out 126 | self.net.append(LoRACompatibleLinear(inner_dim, dim_out)) 127 | # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout 128 | if final_dropout: 129 | self.net.append(nn.Dropout(dropout)) 130 | 131 | def forward(self, hidden_states): 132 | for module in self.net: 133 | hidden_states = module(hidden_states) 134 | return hidden_states 135 | 136 | 137 | @maybe_allow_in_graph 138 | class BasicTransformerBlock(nn.Module): 139 | r""" 140 | A basic Transformer block. 141 | 142 | Parameters: 143 | dim (`int`): The number of channels in the input and output. 144 | num_attention_heads (`int`): The number of heads to use for multi-head attention. 145 | attention_head_dim (`int`): The number of channels in each head. 146 | dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. 147 | cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. 148 | only_cross_attention (`bool`, *optional*): 149 | Whether to use only cross-attention layers. In this case two cross attention layers are used. 150 | double_self_attention (`bool`, *optional*): 151 | Whether to use two self-attention layers. In this case no cross attention layers are used. 152 | activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. 153 | num_embeds_ada_norm (: 154 | obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. 155 | attention_bias (: 156 | obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. 
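# Illustrative usage sketch: the FeedForward block defined in
# viettts/transformer/transformer.py above, with its default GEGLU projection
# and with the SnakeBeta variant. Dimensions are arbitrary; the `diffusers`
# package imported at the top of this file must be installed.
import torch
from viettts.transformer.transformer import FeedForward

hidden = torch.randn(2, 50, 256)                   # (batch, time, dim)

ff_geglu = FeedForward(dim=256, mult=4, activation_fn="geglu")
out = ff_geglu(hidden)                             # (2, 50, 256)

ff_snake = FeedForward(dim=256, mult=4, activation_fn="snakebeta")
out_snake = ff_snake(hidden)                       # (2, 50, 256)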
157 | """ 158 | 159 | def __init__( 160 | self, 161 | dim: int, 162 | num_attention_heads: int, 163 | attention_head_dim: int, 164 | dropout=0.0, 165 | cross_attention_dim: Optional[int] = None, 166 | activation_fn: str = "geglu", 167 | num_embeds_ada_norm: Optional[int] = None, 168 | attention_bias: bool = False, 169 | only_cross_attention: bool = False, 170 | double_self_attention: bool = False, 171 | upcast_attention: bool = False, 172 | norm_elementwise_affine: bool = True, 173 | norm_type: str = "layer_norm", 174 | final_dropout: bool = False, 175 | ): 176 | super().__init__() 177 | self.only_cross_attention = only_cross_attention 178 | 179 | self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero" 180 | self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" 181 | 182 | if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: 183 | raise ValueError( 184 | f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" 185 | f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." 186 | ) 187 | 188 | # Define 3 blocks. Each block has its own normalization layer. 189 | # 1. Self-Attn 190 | if self.use_ada_layer_norm: 191 | self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) 192 | elif self.use_ada_layer_norm_zero: 193 | self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm) 194 | else: 195 | self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) 196 | self.attn1 = Attention( 197 | query_dim=dim, 198 | heads=num_attention_heads, 199 | dim_head=attention_head_dim, 200 | dropout=dropout, 201 | bias=attention_bias, 202 | cross_attention_dim=cross_attention_dim if only_cross_attention else None, 203 | upcast_attention=upcast_attention, 204 | ) 205 | 206 | # 2. Cross-Attn 207 | if cross_attention_dim is not None or double_self_attention: 208 | # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. 209 | # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during 210 | # the second cross attention block. 211 | self.norm2 = ( 212 | AdaLayerNorm(dim, num_embeds_ada_norm) 213 | if self.use_ada_layer_norm 214 | else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) 215 | ) 216 | self.attn2 = Attention( 217 | query_dim=dim, 218 | cross_attention_dim=cross_attention_dim if not double_self_attention else None, 219 | heads=num_attention_heads, 220 | dim_head=attention_head_dim, 221 | dropout=dropout, 222 | bias=attention_bias, 223 | upcast_attention=upcast_attention, 224 | # scale_qk=False, # uncomment this to not to use flash attention 225 | ) # is self-attn if encoder_hidden_states is none 226 | else: 227 | self.norm2 = None 228 | self.attn2 = None 229 | 230 | # 3. 
Feed-forward 231 | self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) 232 | self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout) 233 | 234 | # let chunk size default to None 235 | self._chunk_size = None 236 | self._chunk_dim = 0 237 | 238 | def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int): 239 | # Sets chunk feed-forward 240 | self._chunk_size = chunk_size 241 | self._chunk_dim = dim 242 | 243 | def forward( 244 | self, 245 | hidden_states: torch.FloatTensor, 246 | attention_mask: Optional[torch.FloatTensor] = None, 247 | encoder_hidden_states: Optional[torch.FloatTensor] = None, 248 | encoder_attention_mask: Optional[torch.FloatTensor] = None, 249 | timestep: Optional[torch.LongTensor] = None, 250 | cross_attention_kwargs: Dict[str, Any] = None, 251 | class_labels: Optional[torch.LongTensor] = None, 252 | ): 253 | # Notice that normalization is always applied before the real computation in the following blocks. 254 | # 1. Self-Attention 255 | if self.use_ada_layer_norm: 256 | norm_hidden_states = self.norm1(hidden_states, timestep) 257 | elif self.use_ada_layer_norm_zero: 258 | norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( 259 | hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype 260 | ) 261 | else: 262 | norm_hidden_states = self.norm1(hidden_states) 263 | 264 | cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} 265 | 266 | attn_output = self.attn1( 267 | norm_hidden_states, 268 | encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, 269 | attention_mask=encoder_attention_mask if self.only_cross_attention else attention_mask, 270 | **cross_attention_kwargs, 271 | ) 272 | if self.use_ada_layer_norm_zero: 273 | attn_output = gate_msa.unsqueeze(1) * attn_output 274 | hidden_states = attn_output + hidden_states 275 | 276 | # 2. Cross-Attention 277 | if self.attn2 is not None: 278 | norm_hidden_states = ( 279 | self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) 280 | ) 281 | 282 | attn_output = self.attn2( 283 | norm_hidden_states, 284 | encoder_hidden_states=encoder_hidden_states, 285 | attention_mask=encoder_attention_mask, 286 | **cross_attention_kwargs, 287 | ) 288 | hidden_states = attn_output + hidden_states 289 | 290 | # 3. Feed-forward 291 | norm_hidden_states = self.norm3(hidden_states) 292 | 293 | if self.use_ada_layer_norm_zero: 294 | norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] 295 | 296 | if self._chunk_size is not None: 297 | # "feed_forward_chunk_size" can be used to save memory 298 | if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0: 299 | raise ValueError( 300 | f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`." 
301 | ) 302 | 303 | num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size 304 | ff_output = torch.cat( 305 | [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)], 306 | dim=self._chunk_dim, 307 | ) 308 | else: 309 | ff_output = self.ff(norm_hidden_states) 310 | 311 | if self.use_ada_layer_norm_zero: 312 | ff_output = gate_mlp.unsqueeze(1) * ff_output 313 | 314 | hidden_states = ff_output + hidden_states 315 | 316 | return hidden_states -------------------------------------------------------------------------------- /viettts/tts.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from tqdm import tqdm 4 | from loguru import logger 5 | from hyperpyyaml import load_hyperpyyaml 6 | 7 | from viettts.model import TTSModel 8 | from viettts.frontend import TTSFrontEnd 9 | from viettts.utils.file_utils import download_model, save_wav 10 | 11 | 12 | class TTS: 13 | def __init__( 14 | self, 15 | model_dir, 16 | load_jit=False, 17 | load_onnx=False 18 | ): 19 | if not os.path.exists(model_dir): 20 | logger.info(f'Downloading model from huggingface [dangvansam/viet-tts]') 21 | download_model(model_dir) 22 | 23 | with open(f'{model_dir}/config.yaml', 'r') as f: 24 | configs = load_hyperpyyaml(f) 25 | self.frontend = TTSFrontEnd( 26 | speech_embedding_model=f'{model_dir}/speech_embedding.onnx', 27 | speech_tokenizer_model=f'{model_dir}/speech_tokenizer.onnx' 28 | ) 29 | self.model = TTSModel( 30 | llm=configs['llm'], 31 | flow=configs['flow'], 32 | hift=configs['hift'] 33 | ) 34 | self.model.load( 35 | llm_model=f'{model_dir}/llm.pt', 36 | flow_model=f'{model_dir}/flow.pt', 37 | hift_model=f'{model_dir}/hift.pt' 38 | ) 39 | if load_jit: 40 | self.model.load_jit('{}/llm.text_encoder.fp16.zip'.format(model_dir), 41 | '{}/llm.llm.fp16.zip'.format(model_dir), 42 | '{}/flow.encoder.fp32.zip'.format(model_dir)) 43 | logger.success('Loaded jit model from {}'.format(model_dir)) 44 | 45 | if load_onnx: 46 | self.model.load_onnx('{}/flow.decoder.estimator.fp32.onnx'.format(model_dir)) 47 | logger.success('Loaded onnx model from {}'.format(model_dir)) 48 | 49 | logger.success('Loaded model from {}'.format(model_dir)) 50 | self.model_dir = model_dir 51 | 52 | def list_avaliable_spks(self): 53 | spks = list(self.frontend.spk2info.keys()) 54 | return spks 55 | 56 | def inference_tts(self, tts_text, prompt_speech_16k, stream=False, speed=1.0): 57 | for i in tqdm(self.frontend.preprocess_text(tts_text, split=True)): 58 | model_input = self.frontend.frontend_tts(i, prompt_speech_16k) 59 | for model_output in self.model.tts(**model_input, stream=stream, speed=speed): 60 | yield model_output 61 | 62 | def inference_vc(self, source_speech_16k, prompt_speech_16k, stream=False, speed=1.0): 63 | model_input = self.frontend.frontend_vc(source_speech_16k, prompt_speech_16k) 64 | for model_output in self.model.vc(**model_input, stream=stream, speed=speed): 65 | yield model_output 66 | 67 | def tts_to_wav(self, text, prompt_speech_16k, speed=1.0): 68 | wavs = [] 69 | for output in self.inference_tts(text, prompt_speech_16k, stream=False, speed=speed): 70 | wavs.append(output['tts_speech'].squeeze(0).numpy()) 71 | return np.concatenate(wavs, axis=0) 72 | 73 | def tts_to_file(self, text, prompt_speech_16k, speed, output_path): 74 | wav = self.tts_to_wav(text, prompt_speech_16k, speed) 75 | save_wav(wav, 22050, output_path) 
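# Illustrative usage sketch for the TTS class above. The model directory name
# and the exact prompt format expected by TTSFrontEnd are assumptions; the
# sample wav path refers to the repository's samples folder.
import torchaudio

from viettts.tts import TTS

tts = TTS(model_dir='pretrained-models')  # downloads from dangvansam/viet-tts on HF if missing

# The prompt is assumed to be a mono waveform tensor resampled to 16 kHz.
prompt, sr = torchaudio.load('samples/nu-nhe-nhang.wav')
prompt_16k = torchaudio.functional.resample(prompt, sr, 16000).mean(dim=0, keepdim=True)

tts.tts_to_file(
    text='Xin chào, đây là một ví dụ tổng hợp tiếng nói.',
    prompt_speech_16k=prompt_16k,
    speed=1.0,
    output_path='output.wav',
)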
-------------------------------------------------------------------------------- /viettts/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dangvansam/viet-tts/c29b7535d4a81aefbbef175d0db26ac9df4bc028/viettts/utils/__init__.py -------------------------------------------------------------------------------- /viettts/utils/class_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from viettts.transformer.activation import Swish 4 | from viettts.transformer.subsampling import ( 5 | LinearNoSubsampling, 6 | EmbedinigNoSubsampling, 7 | Conv1dSubsampling2, 8 | Conv2dSubsampling4, 9 | Conv2dSubsampling6, 10 | Conv2dSubsampling8, 11 | ) 12 | from viettts.transformer.embedding import ( 13 | PositionalEncoding, 14 | RelPositionalEncoding, 15 | WhisperPositionalEncoding, 16 | LearnablePositionalEncoding, 17 | NoPositionalEncoding 18 | ) 19 | from viettts.transformer.attention import ( 20 | MultiHeadedAttention, 21 | RelPositionMultiHeadedAttention 22 | ) 23 | from viettts.transformer.embedding import EspnetRelPositionalEncoding 24 | from viettts.transformer.subsampling import LegacyLinearNoSubsampling 25 | 26 | 27 | ACTIVATION_CLASSES = { 28 | "hardtanh": torch.nn.Hardtanh, 29 | "tanh": torch.nn.Tanh, 30 | "relu": torch.nn.ReLU, 31 | "selu": torch.nn.SELU, 32 | "swish": getattr(torch.nn, "SiLU", Swish), 33 | "gelu": torch.nn.GELU, 34 | } 35 | 36 | SUBSAMPLE_CLASSES = { 37 | "linear": LinearNoSubsampling, 38 | "linear_legacy": LegacyLinearNoSubsampling, 39 | "embed": EmbedinigNoSubsampling, 40 | "conv1d2": Conv1dSubsampling2, 41 | "conv2d": Conv2dSubsampling4, 42 | "conv2d6": Conv2dSubsampling6, 43 | "conv2d8": Conv2dSubsampling8, 44 | 'paraformer_dummy': torch.nn.Identity 45 | } 46 | 47 | EMB_CLASSES = { 48 | "embed": PositionalEncoding, 49 | "abs_pos": PositionalEncoding, 50 | "rel_pos": RelPositionalEncoding, 51 | "rel_pos_espnet": EspnetRelPositionalEncoding, 52 | "no_pos": NoPositionalEncoding, 53 | "abs_pos_whisper": WhisperPositionalEncoding, 54 | "embed_learnable_pe": LearnablePositionalEncoding, 55 | } 56 | 57 | ATTENTION_CLASSES = { 58 | "selfattn": MultiHeadedAttention, 59 | "rel_selfattn": RelPositionMultiHeadedAttention, 60 | } 61 | -------------------------------------------------------------------------------- /viettts/utils/common.py: -------------------------------------------------------------------------------- 1 | # Modified from ESPnet(https://github.com/espnet/espnet) 2 | """Unility functions for Transformer.""" 3 | 4 | import random 5 | from typing import List 6 | 7 | import numpy as np 8 | import torch 9 | 10 | IGNORE_ID = -1 11 | 12 | 13 | def pad_list(xs: List[torch.Tensor], pad_value: int): 14 | """Perform padding for the list of tensors. 15 | 16 | Args: 17 | xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)]. 18 | pad_value (float): Value for padding. 19 | 20 | Returns: 21 | Tensor: Padded tensor (B, Tmax, `*`). 
22 | 23 | Examples: 24 | >>> x = [torch.ones(4), torch.ones(2), torch.ones(1)] 25 | >>> x 26 | [tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])] 27 | >>> pad_list(x, 0) 28 | tensor([[1., 1., 1., 1.], 29 | [1., 1., 0., 0.], 30 | [1., 0., 0., 0.]]) 31 | 32 | """ 33 | max_len = max([len(item) for item in xs]) 34 | batchs = len(xs) 35 | ndim = xs[0].ndim 36 | if ndim == 1: 37 | pad_res = torch.zeros(batchs, 38 | max_len, 39 | dtype=xs[0].dtype, 40 | device=xs[0].device) 41 | elif ndim == 2: 42 | pad_res = torch.zeros(batchs, 43 | max_len, 44 | xs[0].shape[1], 45 | dtype=xs[0].dtype, 46 | device=xs[0].device) 47 | elif ndim == 3: 48 | pad_res = torch.zeros(batchs, 49 | max_len, 50 | xs[0].shape[1], 51 | xs[0].shape[2], 52 | dtype=xs[0].dtype, 53 | device=xs[0].device) 54 | else: 55 | raise ValueError(f"Unsupported ndim: {ndim}") 56 | pad_res.fill_(pad_value) 57 | for i in range(batchs): 58 | pad_res[i, :len(xs[i])] = xs[i] 59 | return pad_res 60 | 61 | 62 | def th_accuracy(pad_outputs: torch.Tensor, pad_targets: torch.Tensor, 63 | ignore_label: int) -> torch.Tensor: 64 | """Calculate accuracy. 65 | 66 | Args: 67 | pad_outputs (Tensor): Prediction tensors (B * Lmax, D). 68 | pad_targets (LongTensor): Target label tensors (B, Lmax). 69 | ignore_label (int): Ignore label id. 70 | 71 | Returns: 72 | torch.Tensor: Accuracy value (0.0 - 1.0). 73 | 74 | """ 75 | pad_pred = pad_outputs.view(pad_targets.size(0), pad_targets.size(1), 76 | pad_outputs.size(1)).argmax(2) 77 | mask = pad_targets != ignore_label 78 | numerator = torch.sum( 79 | pad_pred.masked_select(mask) == pad_targets.masked_select(mask)) 80 | denominator = torch.sum(mask) 81 | return (numerator / denominator).detach() 82 | 83 | 84 | def get_padding(kernel_size, dilation=1): 85 | return int((kernel_size * dilation - dilation) / 2) 86 | 87 | 88 | def init_weights(m, mean=0.0, std=0.01): 89 | classname = m.__class__.__name__ 90 | if classname.find("Conv") != -1: 91 | m.weight.data.normal_(mean, std) 92 | 93 | 94 | # Repetition Aware Sampling in VALL-E 2 95 | def ras_sampling(weighted_scores, decoded_tokens, sampling, top_p=0.8, top_k=25, win_size=10, tau_r=0.1): 96 | top_ids = nucleus_sampling(weighted_scores, top_p=top_p, top_k=top_k) 97 | rep_num = (torch.tensor(decoded_tokens[-win_size:]).to(weighted_scores.device) == top_ids).sum().item() 98 | if rep_num >= win_size * tau_r: 99 | top_ids = random_sampling(weighted_scores, decoded_tokens, sampling) 100 | return top_ids 101 | 102 | 103 | def nucleus_sampling(weighted_scores, top_p=0.8, top_k=25): 104 | prob, indices = [], [] 105 | cum_prob = 0.0 106 | sorted_value, sorted_idx = weighted_scores.softmax(dim=0).sort(descending=True, stable=True) 107 | for i in range(len(sorted_idx)): 108 | # sampling both top-p and numbers. 109 | if (cum_prob < top_p or len(prob) <=1) and len(prob) str: 19 | """ 20 | Convert an input audio file to WAV format with the desired sample rate using FFmpeg. 21 | 22 | Args: 23 | input_filepath (str): Path to the input audio file. 24 | target_sr (int): Target sample rate. 25 | 26 | Returns: 27 | str: Path to the converted WAV file. 
28 | """ 29 | temp_wav_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav") 30 | temp_wav_filepath = temp_wav_file.name 31 | temp_wav_file.close() 32 | 33 | ffmpeg_command = [ 34 | "ffmpeg", "-y", 35 | "-loglevel", "error", 36 | "-i", input_filepath, 37 | "-ar", str(target_sr), 38 | "-ac", "1", 39 | temp_wav_filepath 40 | ] 41 | 42 | result = subprocess.run(ffmpeg_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE) 43 | if result.returncode != 0: 44 | os.unlink(temp_wav_filepath) 45 | raise RuntimeError(f"FFmpeg conversion failed: {result.stderr.decode()}") 46 | 47 | return temp_wav_filepath 48 | 49 | 50 | def load_wav(filepath: str, target_sr: int): 51 | """ 52 | Load an audio file in any supported format, convert it to WAV, and load as a tensor. 53 | 54 | Args: 55 | filepath (str): Path to the audio file in any format. 56 | target_sr (int): Target sample rate. 57 | 58 | Returns: 59 | Tensor: Loaded audio tensor resampled to the target sample rate. 60 | """ 61 | # Check if the file is already in WAV format 62 | if not filepath.lower().endswith(".wav"): 63 | logger.info(f"Converting {filepath} to WAV format") 64 | filepath = convert_to_wav(filepath, target_sr) 65 | 66 | # Load the WAV file 67 | speech, sample_rate = torchaudio.load(filepath) 68 | speech = speech.mean(dim=0, keepdim=True) # Convert to mono if not already 69 | if sample_rate != target_sr: 70 | assert sample_rate > target_sr, f'WAV sample rate {sample_rate} must be greater than {target_sr}' 71 | speech = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sr)(speech) 72 | 73 | return speech 74 | 75 | 76 | def save_wav(wav: np.ndarray, sr: int, filepath: str): 77 | soundfile.write(filepath, wav, sr) 78 | 79 | 80 | def load_prompt_speech_from_file(filepath: str, min_duration: float=3, max_duration: float=5, return_numpy: bool=False): 81 | wav = load_wav(filepath, 16000) 82 | 83 | if wav.abs().max() > 0.9: 84 | wav = wav / wav.abs().max() * 0.9 85 | 86 | wav = get_speech( 87 | audio_input=wav.squeeze(0), 88 | min_duration=min_duration, 89 | max_duration=max_duration, 90 | return_numpy=return_numpy 91 | ) 92 | return wav 93 | 94 | 95 | def load_voices(voice_dir: str): 96 | files = glob(os.path.join(voice_dir, '*.wav')) + glob(os.path.join(voice_dir, '*.mp3')) 97 | voice_name_map = { 98 | os.path.basename(f).split('.')[0]: f 99 | for f in files 100 | } 101 | return voice_name_map 102 | 103 | 104 | def download_model(save_dir: str): 105 | snapshot_download(repo_id="dangvansam/viet-tts", local_dir=save_dir) -------------------------------------------------------------------------------- /viettts/utils/frontend_utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | import torch 3 | import numpy as np 4 | import torch.utils.data 5 | from vinorm import TTSnorm 6 | from librosa.filters import mel as librosa_mel_fn 7 | from scipy.io.wavfile import read 8 | 9 | MAX_WAV_VALUE = 32768.0 10 | 11 | 12 | def remove_urls_and_links(text): 13 | url_pattern = r"http[s]?:\/\/(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\\(\\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+|www\.[a-zA-Z0-9.\/]+" 14 | markdown_image_pattern = r"!\[.*?\]\(http[s]?:\/\/.*?\)" 15 | text = re.sub(markdown_image_pattern, '', text, 0, re.MULTILINE) 16 | text = re.sub(url_pattern, '', text, 0, re.MULTILINE) 17 | return text 18 | 19 | 20 | def remove_emojis(text): 21 | emoji_pattern = re.compile( 22 | "[" 23 | "\U0001F600-\U0001F64F" # emoticons 24 | "\U0001F300-\U0001F5FF" # symbols & pictographs 25 | 
"\U0001F680-\U0001F6FF" # transport & map symbols 26 | "\U0001F1E0-\U0001F1FF" # flags (iOS) 27 | "\U00002702-\U000027B0" # other miscellaneous symbols 28 | "\U000024C2-\U0001F251" 29 | "\U0001F900-\U0001F9FF" # Supplemental Symbols and Pictographs 30 | "\U0001FA70-\U0001FAFF" # Symbols and Pictographs Extended-A 31 | "\U0001F004-\U0001F0CF" # Mahjong and Playing Cards 32 | "]+", 33 | flags=re.UNICODE 34 | ) 35 | return emoji_pattern.sub(r'', text) 36 | 37 | 38 | def remove_punc(text): 39 | text = (text 40 | .replace('', '') 41 | .replace("..", ".") 42 | .replace("!.", "!") 43 | .replace('!', ".") 44 | .replace("?.", "?") 45 | .replace("?", ".") 46 | .replace(" .", ".") 47 | .replace(" ,", ",") 48 | .replace('"', "") 49 | .replace("'", "") 50 | .replace("AI", "Ây Ai") 51 | .replace("A.I", "Ây Ai") 52 | .replace("$", "") 53 | .replace("(", "") 54 | .replace(")", "") 55 | .replace("**", "") 56 | .replace(" = ", " bằng ") 57 | .replace("#", "") 58 | .replace('\\', '') 59 | .replace('```', '') 60 | .replace('- ', '') 61 | .replace('+ ', '') 62 | .replace(":", "") 63 | .replace(",,", ",") 64 | .replace(", ,", ",") 65 | .replace(",.", ".") 66 | .replace(".,", ".") 67 | .replace("..", ".") 68 | .replace(". .", ".") 69 | ) 70 | text = re.sub(r'\n+', ' ', text) 71 | text = ' '.join([t for t in text.split() if t.strip()]) 72 | text = text.strip() 73 | return text 74 | 75 | 76 | def normalize_text(text: str) -> str: 77 | text = text.strip() 78 | text = remove_urls_and_links(text) 79 | text = remove_emojis(text) 80 | text = remove_punc(text) 81 | text = TTSnorm(text, lower=False) 82 | return text 83 | 84 | 85 | def split_text(text: str, tokenize, token_max_n=80, token_min_n=60, merge_len=20, comma_split=False): 86 | def calc_utt_length(_text: str): 87 | return len(tokenize(_text)) 88 | 89 | def should_merge(_text: str): 90 | return len(tokenize(_text)) < merge_len 91 | 92 | pounc = ['.', '?', '!', ';', ':'] 93 | if comma_split: 94 | pounc.extend([',', ',']) 95 | 96 | if text[-1] not in pounc: 97 | text += "." 
98 | 99 | st = 0 100 | utts = [] 101 | for i, c in enumerate(text): 102 | if c in pounc: 103 | if len(text[st: i]) > 0: 104 | utts.append(text[st: i] + c) 105 | if i + 1 < len(text) and text[i + 1] in ['"', '”']: 106 | tmp = utts.pop(-1) 107 | utts.append(tmp + text[i + 1]) 108 | st = i + 2 109 | else: 110 | st = i + 1 111 | 112 | final_utts = [] 113 | cur_utt = "" 114 | for utt in utts: 115 | if calc_utt_length(cur_utt + utt) > token_max_n and calc_utt_length(cur_utt) > token_min_n: 116 | final_utts.append(cur_utt) 117 | cur_utt = "" 118 | cur_utt = cur_utt + utt 119 | if len(cur_utt) > 0: 120 | if should_merge(cur_utt) and len(final_utts) != 0: 121 | final_utts[-1] = final_utts[-1] + cur_utt 122 | else: 123 | final_utts.append(cur_utt) 124 | 125 | final_utts = [utt.strip() for utt in final_utts] 126 | return final_utts 127 | 128 | 129 | def dynamic_range_compression(x, C=1, clip_val=1e-5): 130 | return np.log(np.clip(x, a_min=clip_val, a_max=None) * C) 131 | 132 | 133 | def dynamic_range_decompression(x, C=1): 134 | return np.exp(x) / C 135 | 136 | 137 | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): 138 | return torch.log(torch.clamp(x, min=clip_val) * C) 139 | 140 | 141 | def dynamic_range_decompression_torch(x, C=1): 142 | return torch.exp(x) / C 143 | 144 | 145 | def spectral_normalize_torch(magnitudes): 146 | output = dynamic_range_compression_torch(magnitudes) 147 | return output 148 | 149 | 150 | def spectral_de_normalize_torch(magnitudes): 151 | output = dynamic_range_decompression_torch(magnitudes) 152 | return output 153 | 154 | 155 | mel_basis = {} 156 | hann_window = {} 157 | 158 | 159 | def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False): 160 | if torch.min(y) < -1.0: 161 | print("min value is ", torch.min(y)) 162 | if torch.max(y) > 1.0: 163 | print("max value is ", torch.max(y)) 164 | 165 | global mel_basis, hann_window # pylint: disable=global-statement 166 | if f"{str(fmax)}_{str(y.device)}" not in mel_basis: 167 | mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) 168 | mel_basis[str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device) 169 | hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device) 170 | 171 | y = torch.nn.functional.pad( 172 | y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect" 173 | ) 174 | y = y.squeeze(1) 175 | 176 | spec = torch.view_as_real( 177 | torch.stft( 178 | y, 179 | n_fft, 180 | hop_length=hop_size, 181 | win_length=win_size, 182 | window=hann_window[str(y.device)], 183 | center=center, 184 | pad_mode="reflect", 185 | normalized=False, 186 | onesided=True, 187 | return_complex=True, 188 | ) 189 | ) 190 | 191 | spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9)) 192 | 193 | spec = torch.matmul(mel_basis[str(fmax) + "_" + str(y.device)], spec) 194 | spec = spectral_normalize_torch(spec) 195 | 196 | return spec 197 | 198 | 199 | # def tokenize(data, get_tokenizer, allowed_special): 200 | # """ Decode text to chars or BPE 201 | # Inplace operation 202 | 203 | # Args: 204 | # data: Iterable[{key, wav, txt, sample_rate}] 205 | 206 | # Returns: 207 | # Iterable[{key, wav, txt, tokens, label, sample_rate}] 208 | # """ 209 | # tokenizer = get_tokenizer() 210 | # for sample in data: 211 | # assert 'text' in sample 212 | # sample['text_token'] = tokenizer.encode(sample['text'], allowed_special=allowed_special) 213 | # sample['tts_text_token'] = tokenizer.encode(sample['tts_text'], 
allowed_special=allowed_special) 214 | -------------------------------------------------------------------------------- /viettts/utils/mask.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def subsequent_mask( 5 | size: int, 6 | device: torch.device = torch.device("cpu"), 7 | ) -> torch.Tensor: 8 | """Create mask for subsequent steps (size, size). 9 | 10 | This mask is used only in decoder which works in an auto-regressive mode. 11 | This means the current step could only do attention with its left steps. 12 | 13 | In encoder, fully attention is used when streaming is not necessary and 14 | the sequence is not long. In this case, no attention mask is needed. 15 | 16 | When streaming is need, chunk-based attention is used in encoder. See 17 | subsequent_chunk_mask for the chunk-based attention mask. 18 | 19 | Args: 20 | size (int): size of mask 21 | str device (str): "cpu" or "cuda" or torch.Tensor.device 22 | dtype (torch.device): result dtype 23 | 24 | Returns: 25 | torch.Tensor: mask 26 | 27 | Examples: 28 | >>> subsequent_mask(3) 29 | [[1, 0, 0], 30 | [1, 1, 0], 31 | [1, 1, 1]] 32 | """ 33 | arange = torch.arange(size, device=device) 34 | mask = arange.expand(size, size) 35 | arange = arange.unsqueeze(-1) 36 | mask = mask <= arange 37 | return mask 38 | 39 | 40 | def subsequent_chunk_mask( 41 | size: int, 42 | chunk_size: int, 43 | num_left_chunks: int = -1, 44 | device: torch.device = torch.device("cpu"), 45 | ) -> torch.Tensor: 46 | """Create mask for subsequent steps (size, size) with chunk size, 47 | this is for streaming encoder 48 | 49 | Args: 50 | size (int): size of mask 51 | chunk_size (int): size of chunk 52 | num_left_chunks (int): number of left chunks 53 | <0: use full chunk 54 | >=0: use num_left_chunks 55 | device (torch.device): "cpu" or "cuda" or torch.Tensor.device 56 | 57 | Returns: 58 | torch.Tensor: mask 59 | 60 | Examples: 61 | >>> subsequent_chunk_mask(4, 2) 62 | [[1, 1, 0, 0], 63 | [1, 1, 0, 0], 64 | [1, 1, 1, 1], 65 | [1, 1, 1, 1]] 66 | """ 67 | ret = torch.zeros(size, size, device=device, dtype=torch.bool) 68 | for i in range(size): 69 | if num_left_chunks < 0: 70 | start = 0 71 | else: 72 | start = max((i // chunk_size - num_left_chunks) * chunk_size, 0) 73 | ending = min((i // chunk_size + 1) * chunk_size, size) 74 | ret[i, start:ending] = True 75 | return ret 76 | 77 | 78 | def add_optional_chunk_mask(xs: torch.Tensor, 79 | masks: torch.Tensor, 80 | use_dynamic_chunk: bool, 81 | use_dynamic_left_chunk: bool, 82 | decoding_chunk_size: int, 83 | static_chunk_size: int, 84 | num_decoding_left_chunks: int, 85 | enable_full_context: bool = True): 86 | """ Apply optional mask for encoder. 87 | 88 | Args: 89 | xs (torch.Tensor): padded input, (B, L, D), L for max length 90 | mask (torch.Tensor): mask for xs, (B, 1, L) 91 | use_dynamic_chunk (bool): whether to use dynamic chunk or not 92 | use_dynamic_left_chunk (bool): whether to use dynamic left chunk for 93 | training. 94 | decoding_chunk_size (int): decoding chunk size for dynamic chunk, it's 95 | 0: default for training, use random dynamic chunk. 96 | <0: for decoding, use full chunk. 97 | >0: for decoding, use fixed chunk size as set. 98 | static_chunk_size (int): chunk size for static chunk training/decoding 99 | if it's greater than 0, if use_dynamic_chunk is true, 100 | this parameter will be ignored 101 | num_decoding_left_chunks: number of left chunks, this is for decoding, 102 | the chunk size is decoding_chunk_size. 
103 | >=0: use num_decoding_left_chunks 104 | <0: use all left chunks 105 | enable_full_context (bool): 106 | True: chunk size is either [1, 25] or full context(max_len) 107 | False: chunk size ~ U[1, 25] 108 | 109 | Returns: 110 | torch.Tensor: chunk mask of the input xs. 111 | """ 112 | # Whether to use chunk mask or not 113 | if use_dynamic_chunk: 114 | max_len = xs.size(1) 115 | if decoding_chunk_size < 0: 116 | chunk_size = max_len 117 | num_left_chunks = -1 118 | elif decoding_chunk_size > 0: 119 | chunk_size = decoding_chunk_size 120 | num_left_chunks = num_decoding_left_chunks 121 | else: 122 | # chunk size is either [1, 25] or full context(max_len). 123 | # Since we use 4 times subsampling and allow up to 1s(100 frames) 124 | # delay, the maximum frame is 100 / 4 = 25. 125 | chunk_size = torch.randint(1, max_len, (1, )).item() 126 | num_left_chunks = -1 127 | if chunk_size > max_len // 2 and enable_full_context: 128 | chunk_size = max_len 129 | else: 130 | chunk_size = chunk_size % 25 + 1 131 | if use_dynamic_left_chunk: 132 | max_left_chunks = (max_len - 1) // chunk_size 133 | num_left_chunks = torch.randint(0, max_left_chunks, 134 | (1, )).item() 135 | chunk_masks = subsequent_chunk_mask(xs.size(1), chunk_size, 136 | num_left_chunks, 137 | xs.device) # (L, L) 138 | chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L) 139 | chunk_masks = masks & chunk_masks # (B, L, L) 140 | elif static_chunk_size > 0: 141 | num_left_chunks = num_decoding_left_chunks 142 | chunk_masks = subsequent_chunk_mask(xs.size(1), static_chunk_size, 143 | num_left_chunks, 144 | xs.device) # (L, L) 145 | chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L) 146 | chunk_masks = masks & chunk_masks # (B, L, L) 147 | else: 148 | chunk_masks = masks 149 | return chunk_masks 150 | 151 | 152 | def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor: 153 | """Make mask tensor containing indices of padded part. 154 | 155 | See description of make_non_pad_mask. 156 | 157 | Args: 158 | lengths (torch.Tensor): Batch of lengths (B,). 159 | Returns: 160 | torch.Tensor: Mask tensor containing indices of padded part. 
161 | 162 | Examples: 163 | >>> lengths = [5, 3, 2] 164 | >>> make_pad_mask(lengths) 165 | masks = [[0, 0, 0, 0 ,0], 166 | [0, 0, 0, 1, 1], 167 | [0, 0, 1, 1, 1]] 168 | """ 169 | batch_size = lengths.size(0) 170 | max_len = max_len if max_len > 0 else lengths.max().item() 171 | seq_range = torch.arange(0, 172 | max_len, 173 | dtype=torch.int64, 174 | device=lengths.device) 175 | seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len) 176 | seq_length_expand = lengths.unsqueeze(-1) 177 | mask = seq_range_expand >= seq_length_expand 178 | return mask 179 | -------------------------------------------------------------------------------- /viettts/utils/vad.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | import numpy as np 3 | import torch 4 | from silero_vad import load_silero_vad, read_audio, get_speech_timestamps 5 | 6 | model = load_silero_vad() 7 | 8 | def get_speech( 9 | audio_input: Union[str, np.ndarray, torch.Tensor], 10 | return_numpy: bool=False, 11 | min_duration: float=3, 12 | max_duration: float=5 13 | ) -> Union[torch.Tensor, np.ndarray]: 14 | 15 | if isinstance(audio_input, str): 16 | audio_input = read_audio(audio_input) 17 | speech_timestamps = get_speech_timestamps(audio_input, model) 18 | speech = [audio_input[t['start']:t['end']] \ 19 | for t in speech_timestamps \ 20 | if (t['end'] - t['start']) >= 16000 * min_duration \ 21 | and (t['end'] - t['start']) <= 16000 * max_duration] 22 | if not speech: 23 | speech = audio_input[:int(max_duration*16000)].unsqueeze(0) 24 | else: 25 | speech = speech[0].unsqueeze(0) 26 | if return_numpy: 27 | speech = speech.cpu().numpy() 28 | return speech 29 | 30 | 31 | if __name__ == '__main__': 32 | print(get_speech('samples/diep-chi.wav')) -------------------------------------------------------------------------------- /web/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dangvansam/viet-tts/c29b7535d4a81aefbbef175d0db26ac9df4bc028/web/.gitkeep --------------------------------------------------------------------------------
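
The helpers in viettts/utils/mask.py above are the building blocks for padding-aware, chunk-based attention in the encoder. Below is a small sketch of how they might be combined; the batch shapes, lengths, and static_chunk_size value are illustrative assumptions, not settings taken from the repository.

import torch
from viettts.utils.mask import make_pad_mask, add_optional_chunk_mask

lengths = torch.tensor([5, 3, 2])     # valid frames per utterance (assumed)
xs = torch.zeros(3, 5, 80)            # padded batch (B, L, D); D=80 is illustrative

pad_mask = ~make_pad_mask(lengths)    # True on real frames, False on padding, shape (B, L)
masks = pad_mask.unsqueeze(1)         # (B, 1, L), the layout add_optional_chunk_mask expects

chunk_masks = add_optional_chunk_mask(
    xs, masks,
    use_dynamic_chunk=False,
    use_dynamic_left_chunk=False,
    decoding_chunk_size=0,
    static_chunk_size=2,              # 2-frame static chunks with full left context
    num_decoding_left_chunks=-1,
)
print(chunk_masks.shape)              # torch.Size([3, 5, 5]); entry (b, i, j) is True
                                      # when query frame i may attend to key frame j
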