├── .gitattributes
├── .github
│   └── workflows
│       ├── docker-image.yml
│       ├── python-package.yml
│       └── python-publish.yml
├── .gitignore
├── Dockerfile
├── LICENSE
├── README.md
├── README_VN.md
├── assets
│   ├── viet-tts-large.png
│   ├── viet-tts-medium.png
│   └── viet-tts.png
├── docker-compose.yaml
├── pyproject.toml
├── samples
│   ├── cdteam.wav
│   ├── cross_lingual_prompt.wav
│   ├── diep-chi.wav
│   ├── doremon.mp3
│   ├── jack-sparrow.mp3
│   ├── nguyen-ngoc-ngan.wav
│   ├── nsnd-le-chuc.mp3
│   ├── nu-nhe-nhang.wav
│   ├── quynh.wav
│   ├── son-tung-mtp.wav
│   ├── speechify_1.wav
│   ├── speechify_10.wav
│   ├── speechify_11.wav
│   ├── speechify_12.wav
│   ├── speechify_2.wav
│   ├── speechify_3.wav
│   ├── speechify_4.wav
│   ├── speechify_5.wav
│   ├── speechify_6.wav
│   ├── speechify_7.wav
│   ├── speechify_8.wav
│   ├── speechify_9.wav
│   └── zero_shot_prompt.wav
├── viettts
│   ├── __init__.py
│   ├── cli.py
│   ├── flow
│   │   ├── decoder.py
│   │   ├── flow.py
│   │   ├── flow_matching.py
│   │   └── length_regulator.py
│   ├── frontend.py
│   ├── hifigan
│   │   ├── f0_predictor.py
│   │   └── generator.py
│   ├── llm
│   │   └── llm.py
│   ├── model.py
│   ├── server.py
│   ├── tokenizer
│   │   ├── multilingual.tiktoken
│   │   └── tokenizer.py
│   ├── transformer
│   │   ├── __init__.py
│   │   ├── activation.py
│   │   ├── attention.py
│   │   ├── convolution.py
│   │   ├── decoder.py
│   │   ├── decoder_layer.py
│   │   ├── embedding.py
│   │   ├── encoder.py
│   │   ├── encoder_layer.py
│   │   ├── label_smoothing_loss.py
│   │   ├── positionwise_feed_forward.py
│   │   ├── subsampling.py
│   │   └── transformer.py
│   ├── tts.py
│   └── utils
│       ├── __init__.py
│       ├── class_utils.py
│       ├── common.py
│       ├── file_utils.py
│       ├── frontend_utils.py
│       ├── mask.py
│       └── vad.py
└── web
    └── .gitkeep
/.gitattributes:
--------------------------------------------------------------------------------
1 | *.7z filter=lfs diff=lfs merge=lfs -text
2 | *.arrow filter=lfs diff=lfs merge=lfs -text
3 | *.bin filter=lfs diff=lfs merge=lfs -text
4 | *.bz2 filter=lfs diff=lfs merge=lfs -text
5 | *.ckpt filter=lfs diff=lfs merge=lfs -text
6 | *.ftz filter=lfs diff=lfs merge=lfs -text
7 | *.gz filter=lfs diff=lfs merge=lfs -text
8 | *.h5 filter=lfs diff=lfs merge=lfs -text
9 | *.joblib filter=lfs diff=lfs merge=lfs -text
10 | *.lfs.* filter=lfs diff=lfs merge=lfs -text
11 | *.mlmodel filter=lfs diff=lfs merge=lfs -text
12 | *.model filter=lfs diff=lfs merge=lfs -text
13 | *.msgpack filter=lfs diff=lfs merge=lfs -text
14 | *.npy filter=lfs diff=lfs merge=lfs -text
15 | *.npz filter=lfs diff=lfs merge=lfs -text
16 | *.onnx filter=lfs diff=lfs merge=lfs -text
17 | *.ot filter=lfs diff=lfs merge=lfs -text
18 | *.parquet filter=lfs diff=lfs merge=lfs -text
19 | *.pb filter=lfs diff=lfs merge=lfs -text
20 | *.pickle filter=lfs diff=lfs merge=lfs -text
21 | *.pkl filter=lfs diff=lfs merge=lfs -text
22 | *.pt filter=lfs diff=lfs merge=lfs -text
23 | *.pth filter=lfs diff=lfs merge=lfs -text
24 | *.rar filter=lfs diff=lfs merge=lfs -text
25 | *.safetensors filter=lfs diff=lfs merge=lfs -text
26 | saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27 | *.tar.* filter=lfs diff=lfs merge=lfs -text
28 | *.tar filter=lfs diff=lfs merge=lfs -text
29 | *.tflite filter=lfs diff=lfs merge=lfs -text
30 | *.tgz filter=lfs diff=lfs merge=lfs -text
31 | *.wasm filter=lfs diff=lfs merge=lfs -text
32 | *.xz filter=lfs diff=lfs merge=lfs -text
33 | *.zip filter=lfs diff=lfs merge=lfs -text
34 | *.zst filter=lfs diff=lfs merge=lfs -text
35 | *tfevents* filter=lfs diff=lfs merge=lfs -text
36 |
--------------------------------------------------------------------------------
/.github/workflows/docker-image.yml:
--------------------------------------------------------------------------------
1 | name: Docker Image CI
2 |
3 | on:
4 | push:
5 | branches: [ "main" ]
6 | pull_request:
7 | branches: [ "main" ]
8 |
9 | jobs:
10 |
11 | build:
12 |
13 | runs-on: ubuntu-latest
14 |
15 | steps:
16 | - uses: actions/checkout@v4
17 | - name: Build the Docker image
18 | run: docker build . --file Dockerfile --tag my-image-name:$(date +%s)
19 |
--------------------------------------------------------------------------------
/.github/workflows/python-package.yml:
--------------------------------------------------------------------------------
1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions
2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python
3 |
4 | name: Python package
5 |
6 | on:
7 | push:
8 | branches: [ "main" ]
9 | pull_request:
10 | branches: [ "main" ]
11 |
12 | jobs:
13 | build:
14 |
15 | runs-on: ubuntu-latest
16 | strategy:
17 | fail-fast: false
18 | matrix:
19 | python-version: ["3.10", "3.11"]
20 |
21 | steps:
22 | - uses: actions/checkout@v4
23 | - name: Set up Python ${{ matrix.python-version }}
24 | uses: actions/setup-python@v3
25 | with:
26 | python-version: ${{ matrix.python-version }}
27 |     - name: Install package with pip
28 | run: |
29 | pip install -e . && pip cache purge
30 | - name: Test viettts CLI
31 | run: |
32 | viettts show-voices
33 |
--------------------------------------------------------------------------------
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 | # This workflow will upload a Python Package to PyPI when a release is created
2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries
3 |
4 | # This workflow uses actions that are not certified by GitHub.
5 | # They are provided by a third-party and are governed by
6 | # separate terms of service, privacy policy, and support
7 | # documentation.
8 |
9 | name: Upload Python Package
10 |
11 | on:
12 | release:
13 | types: [published]
14 |
15 | permissions:
16 | contents: read
17 |
18 | jobs:
19 | release-build:
20 | runs-on: ubuntu-latest
21 |
22 | steps:
23 | - uses: actions/checkout@v4
24 |
25 | - uses: actions/setup-python@v5
26 | with:
27 | python-version: "3.x"
28 |
29 | - name: Build release distributions
30 | run: |
31 | # NOTE: put your own distribution build steps here.
32 | python -m pip install build
33 | python -m build
34 |
35 | - name: Upload distributions
36 | uses: actions/upload-artifact@v4
37 | with:
38 | name: release-dists
39 | path: dist/
40 |
41 | pypi-publish:
42 | runs-on: ubuntu-latest
43 | needs:
44 | - release-build
45 | permissions:
46 | # IMPORTANT: this permission is mandatory for trusted publishing
47 | id-token: write
48 |
49 | # Dedicated environments with protections for publishing are strongly recommended.
50 | # For more information, see: https://docs.github.com/en/actions/deployment/targeting-different-environments/using-environments-for-deployment#deployment-protection-rules
51 | environment:
52 | name: pypi
53 | # OPTIONAL: uncomment and update to include your PyPI project URL in the deployment status:
54 | # url: https://pypi.org/p/YOURPROJECT
55 | #
56 | # ALTERNATIVE: if your GitHub Release name is the PyPI project version string
57 | # ALTERNATIVE: exactly, uncomment the following line instead:
58 | # url: https://pypi.org/project/YOURPROJECT/${{ github.event.release.name }}
59 |
60 | steps:
61 | - name: Retrieve release distributions
62 | uses: actions/download-artifact@v4
63 | with:
64 | name: release-dists
65 | path: dist/
66 |
67 | - name: Publish release distributions to PyPI
68 | uses: pypa/gh-action-pypi-publish@release/v1
69 | with:
70 | packages-dir: dist/
71 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # Visual Studio Code files
7 | .vscode
8 | .vs
9 |
10 | # PyCharm files
11 | .idea
12 |
13 | # Eclipse Project settings
14 | *.*project
15 | .settings
16 |
17 | # Sublime Text settings
18 | *.sublime-workspace
19 | *.sublime-project
20 |
21 | # Editor temporaries
22 | *.swn
23 | *.swo
24 | *.swp
25 | *.swm
26 | *~
27 |
28 | # IPython notebook checkpoints
29 | .ipynb_checkpoints
30 |
31 | # macOS dir files
32 | .DS_Store
33 |
34 | exp
35 | data*
36 | raw_wav
37 | tensorboard
38 | **/*build*
39 |
40 | # Clangd files
41 | .cache
42 | .github
43 | compile_commands.json
44 | node_modules
45 |
46 | # train/inference files
47 | *.m4a
48 | *.aac
49 | *.pt
50 | *.egg-info
51 | *dist
52 | *parcel-cache
53 | pretrained-models/*
54 | *_pb2_grpc.py
55 | *_pb2.py
56 | poetry.lock
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM pytorch/pytorch:2.0.1-cuda11.7-cudnn8-runtime
2 |
3 | ENV TZ=UTC DEBIAN_FRONTEND=noninteractive
4 |
5 | WORKDIR /app
6 |
7 | ENV POETRY_VERSION=1.8.3
8 |
9 | RUN pip install --no-cache-dir poetry==${POETRY_VERSION}
10 |
11 | ENV POETRY_CACHE_DIR=/tmp/poetry_cache
12 | ENV POETRY_NO_INTERACTION=1
13 | ENV POETRY_VIRTUALENVS_IN_PROJECT=true
14 | ENV POETRY_VIRTUALENVS_CREATE=true
15 | ENV POETRY_REQUESTS_TIMEOUT=15
16 |
17 | RUN apt-get update && apt-get install -y --no-install-recommends \
18 | build-essential \
19 | gcc g++ libc-dev libffi-dev libgmp-dev libmpfr-dev libmpc-dev \
20 | ffmpeg \
21 | sox
22 |
23 | RUN apt clean && rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/*
24 |
25 | COPY ./viettts /app/viettts
26 | COPY ./samples /app/samples
27 | COPY ./web /app/web
28 | COPY ./pyproject.toml /app/
29 | COPY ./README.md /app/
30 |
31 | RUN pip install -e . && pip cache purge
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | # VietTTS: An Open-Source Vietnamese Text to Speech
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 | **VietTTS** is an open-source toolkit providing the community with a powerful Vietnamese TTS model, capable of natural voice synthesis and robust voice cloning. Designed for effective experimentation, **VietTTS** supports research and application in Vietnamese voice technologies.
20 |
21 | ## ⭐ Key Features
22 | - **TTS**: Text-to-Speech generation with any voice via prompt audio
23 | - **OpenAI-API-compatible**: Compatible with OpenAI's Text-to-Speech API format
24 |
25 | ## 🛠️ Installation
26 |
27 | VietTTS can be installed via a Python installer (Linux only, with Windows and macOS support coming soon) or Docker.
28 |
29 | ### Python Installer (Python>=3.10)
30 | ```bash
31 | git clone https://github.com/dangvansam/viet-tts.git
32 | cd viet-tts
33 |
34 | # (Optional) Install Python environment with conda, you could also use virtualenv
35 | conda create --name viettts python=3.10
36 | conda activate viettts
37 |
38 | # Install
39 | pip install -e . && pip cache purge
40 | ```
41 |
42 | ### Docker
43 |
44 | 1. Install [Docker](https://docs.docker.com/get-docker/), [NVIDIA Driver](https://www.nvidia.com/download/index.aspx), [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html), and [CUDA](https://developer.nvidia.com/cuda-downloads).
45 |
46 | 2. Run the following commands:
47 | ```bash
48 | git clone https://github.com/dangvansam/viet-tts.git
49 | cd viet-tts
50 |
51 | # Build docker images
52 | docker compose build
53 |
54 | # Run with docker-compose - will create server at: http://localhost:8298
55 | docker compose up -d
56 |
57 | # Or run with docker run - will create server at: http://localhost:8298
58 | docker run -itd --gpus all -p 8298:8298 -v ./pretrained-models:/app/pretrained-models --name viet-tts-service viet-tts:latest viettts server --host 0.0.0.0 --port 8298
59 | ```
60 |
61 | ## 🚀 Usage
62 |
63 | ### Built-in Voices 🤠
64 | You can use the available voices below to synthesize speech.
65 |
67 |
68 | | ID | Voice | Gender | Play Audio |
69 | |-----|-----------------------|--------|--------------------------------------------------|
70 | | 1 | nsnd-le-chuc | 👨 | |
71 | | 2 | speechify_10 | 👩 | |
72 | | 3 | atuan | 👨 | |
73 | | 4 | speechify_11 | 👩 | |
74 | | 5 | cdteam | 👨 | |
75 | | 6 | speechify_12 | 👩 | |
76 | | 7 | cross_lingual_prompt | 👩 | |
77 | | 8 | speechify_2 | 👩 | |
78 | | 9 | diep-chi | 👨 | |
79 | | 10 | speechify_3 | 👩 | |
80 | | 11 | doremon | 👨 | |
81 | | 12 | speechify_4 | 👩 | |
82 | | 13 | jack-sparrow | 👨 | |
83 | | 14 | speechify_5 | 👩 | |
84 | | 15 | nguyen-ngoc-ngan | 👩 | |
85 | | 16 | speechify_6 | 👩 | |
86 | | 17 | nu-nhe-nhang | 👩 | |
87 | | 18 | speechify_7 | 👩 | |
88 | | 19 | quynh | 👩 | |
89 | | 20 | speechify_8 | 👩 | |
90 | | 21 | speechify_9 | 👩 | |
91 | | 22 | son-tung-mtp | 👨 | |
92 | | 23 | zero_shot_prompt | 👩 | |
93 | | 24 | speechify_1 | 👩 | |
94 |
95 |
96 |
97 |
98 |
99 | ### Command Line Interface (CLI)
100 | The VietTTS Command Line Interface (CLI) allows you to quickly generate speech directly from the terminal. Here's how to use it:
101 | ```bash
102 | # Usage
103 | viettts --help
104 |
105 | # Start API Server
106 | viettts server --host 0.0.0.0 --port 8298
107 |
108 | # List all built-in voices
109 | viettts show-voices
110 |
111 | # Synthesize speech from text with built-in voices
112 | viettts synthesis --text "Xin chào" --voice 0 --output test.wav
113 |
114 | # Clone voice from a local audio file
115 | viettts synthesis --text "Xin chào" --voice Download/voice.wav --output cloned.wav
116 | ```
117 |
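The CLI is a thin wrapper around the `TTS` class, so the same pipeline can be driven from Python. The sketch below mirrors the calls in `viettts/cli.py` (paths follow the repository layout; the `TTS` constructor and helpers are internal code rather than a documented API, so treat this as illustrative):

```python
from viettts.tts import TTS
from viettts.utils.file_utils import load_prompt_speech_from_file

# Same defaults as the CLI: model weights in ./pretrained-models, prompt audio in ./samples
tts = TTS(model_dir='pretrained-models')
voice = load_prompt_speech_from_file('samples/quynh.wav')  # any built-in sample or your own recording
tts.tts_to_file('Xin chào Việt Nam.', voice, 1.0, 'output.wav')  # text, voice prompt, speed, output path
```
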
118 | ### API Client
119 | #### Python (OpenAI Client)
120 | You need to set environment variables for the OpenAI Client:
121 | ```bash
122 | # Set base_url and API key as environment variables
123 | export OPENAI_BASE_URL=http://localhost:8298
124 | export OPENAI_API_KEY=viet-tts # not used in the current version
125 | ```
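If you prefer not to use environment variables, the `openai` client accepts the same settings directly in its constructor (a minimal sketch; the API key is still not checked by the current server version):

```python
from openai import OpenAI

# Equivalent to the exported variables above
client = OpenAI(base_url='http://localhost:8298', api_key='viet-tts')
```
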
126 | To create speech from input text:
127 | ```python
128 | from pathlib import Path
129 | from openai import OpenAI
130 |
131 | client = OpenAI()
132 |
133 | output_file_path = Path(__file__).parent / "speech.wav"
134 |
135 | with client.audio.speech.with_streaming_response.create(
136 | model='tts-1',
137 | voice='cdteam',
138 | input='Xin chào Việt Nam.',
139 | speed=1.0,
140 | response_format='wav'
141 | ) as response:
142 |     response.stream_to_file(output_file_path)
143 | ```
144 |
145 | #### CURL
146 | ```bash
147 | # Get all built-in voices
148 | curl --location http://0.0.0.0:8298/v1/voices
149 |
150 | # OpenAI-compatible format (built-in voices)
151 | curl http://localhost:8298/v1/audio/speech \
152 | -H "Authorization: Bearer viet-tts" \
153 | -H "Content-Type: application/json" \
154 | -d '{
155 | "model": "tts-1",
156 | "input": "Xin chào Việt Nam.",
157 | "voice": "son-tung-mtp"
158 | }' \
159 | --output speech.wav
160 |
161 | # API with voice from local file
162 | curl --location http://0.0.0.0:8298/v1/tts \
163 | --form 'text="xin chào"' \
164 | --form 'audio_file=@"/home/viettts/Downloads/voice.mp4"' \
165 | --output speech.wav
166 | ```
167 |
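The same multipart `/v1/tts` endpoint can be called from Python with `requests` (a sketch assuming the server is running locally on port 8298; the form field names are taken from the curl example above, and `requests` is not a declared dependency of this project):

```python
import requests

# Clone the voice from a local prompt file and save the synthesized speech
with open('samples/quynh.wav', 'rb') as prompt_audio:
    response = requests.post(
        'http://localhost:8298/v1/tts',
        data={'text': 'xin chào'},
        files={'audio_file': prompt_audio},
    )
response.raise_for_status()

with open('speech.wav', 'wb') as f:
    f.write(response.content)
```
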
168 | #### Node
169 | ```js
170 | import fs from "fs";
171 | import path from "path";
172 | import OpenAI from "openai";
173 |
174 | const openai = new OpenAI();
175 |
176 | const speechFile = path.resolve("./speech.wav");
177 |
178 | async function main() {
179 | const mp3 = await openai.audio.speech.create({
180 | model: "tts-1",
181 | voice: "1",
182 | input: "Xin chào Việt Nam.",
183 | });
184 | console.log(speechFile);
185 | const buffer = Buffer.from(await mp3.arrayBuffer());
186 | await fs.promises.writeFile(speechFile, buffer);
187 | }
188 | main();
189 | ```
190 |
191 | ## 🙏 Acknowledgement
192 | - 💡 Borrowed code from [Cosyvoice](https://github.com/FunAudioLLM/CosyVoice)
193 | - 🎙️ VAD model from [silero-vad](https://github.com/snakers4/silero-vad)
194 | - 📝 Text normalization with [Vinorm](https://github.com/v-nhandt21/Vinorm)
195 |
196 | ## 📜 License
197 | The **VietTTS** source code is released under the **Apache 2.0 License**. Pre-trained models and audio samples are licensed under the **CC BY-NC License**, based on an in-the-wild dataset. We apologize for any inconvenience this may cause.
198 |
199 | ## ⚠️ Disclaimer
200 | The content provided above is for academic purposes only and is intended to demonstrate technical capabilities. Some examples are sourced from the internet. If any content infringes on your rights, please contact us to request its removal.
201 |
202 | ## 💬 Contact
203 | - Facebook: https://fb.com/sam.rngd
204 | - GitHub: https://github.com/dangvansam
205 | - Email: dangvansam98@gmail.com
--------------------------------------------------------------------------------
/README_VN.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # VietTTS: Công cụ chuyển văn bản thành giọng nói tiếng Việt mã nguồn mở
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 | **VietTTS** là một bộ công cụ mã nguồn mở cung cấp mô hình TTS tiếng Việt mạnh mẽ, cho phép tổng hợp giọng nói tự nhiên và tạo giọng nói mới. **VietTTS** hỗ trợ nghiên cứu và ứng dụng trong công nghệ giọng nói tiếng Việt.
18 |
19 | ## ⭐ Tính năng nổi bật
20 | - **TTS**: Tổng hợp giọng nói từ văn bản với bất kỳ giọng nào qua audio mẫu
21 | - **OpenAI-API-compatible**: Tương thích với API Text to Speech OpenAI
22 |
23 | ## 🛠️ Cài đặt
24 | VietTTS có thể được cài đặt qua trình cài đặt Python (chỉ hỗ trợ Linux, Windows và macOS sẽ có trong tương lai) hoặc Docker.
25 |
26 | ### Trình cài đặt Python (Python>=3.10)
27 |
28 | ```bash
29 | git clone https://github.com/dangvansam/viet-tts.git
30 | cd viet-tts
31 |
32 | # (Tùy chọn) Tạo môi trường Python với conda hoặc dùng virtualenv
33 | conda create --name viettts python=3.10
34 | conda activate viettts
35 |
36 | # Cài đặt
37 | pip install -e . && pip cache purge
38 | ```
39 |
40 | ### Docker
41 | 1. Cài đặt [Docker](https://docs.docker.com/get-docker/), [NVIDIA Driver](https://www.nvidia.com/download/index.aspx), [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html), và [CUDA](https://developer.nvidia.com/cuda-downloads).
42 |
43 | 2. Chạy các lệnh sau:
44 | ```bash
45 | git clone https://github.com/dangvansam/viet-tts.git
46 | cd viet-tts
47 |
48 | # Xây dựng hình ảnh docker
49 | docker compose build
50 |
51 | # Chạy bằng docker-compose - tạo server tại: http://localhost:8298
52 | docker compose up -d
53 |
54 | # Chạy bằng docker run - tạo server tại: http://localhost:8298
55 | docker run -itd --gpus all -p 8298:8298 -v ./pretrained-models:/app/pretrained-models --name viet-tts-service viet-tts:latest viettts server --host 0.0.0.0 --port 8298
56 | ```
57 |
58 | ## 🚀 Sử dụng
59 |
60 | ### Giọng nói tích hợp 🤠
61 | Bạn có thể sử dụng các giọng nói có sẵn dưới đây để tổng hợp giọng nói.
62 |
64 |
65 | | ID | Giọng | Giới tính | Phát âm thanh |
66 | |-----|--------------------------|-----------|-------------------------------------------------|
67 | | 1 | nsnd-le-chuc | 👨 | |
68 | | 2 | speechify_10 | 👩 | |
69 | | 3 | atuan | 👨 | |
70 | | 4 | speechify_11 | 👩 | |
71 | | 5 | cdteam | 👨 | |
72 | | 6 | speechify_12 | 👩 | |
73 | | 7 | cross_lingual_prompt | 👩 | |
74 | | 8 | speechify_2 | 👩 | |
75 | | 9 | diep-chi | 👨 | |
76 | | 10 | speechify_3 | 👩 | |
77 | | 11 | doremon | 👨 | |
78 | | 12 | speechify_4 | 👩 | |
79 | | 13 | jack-sparrow | 👨 | |
80 | | 14 | speechify_5 | 👩 | |
81 | | 15 | nguyen-ngoc-ngan | 👩 | |
82 | | 16 | speechify_6 | 👩 | |
83 | | 17 | nu-nhe-nhang | 👩 | |
84 | | 18 | speechify_7 | 👩 | |
85 | | 19 | quynh | 👩 | |
86 | | 20 | speechify_8 | 👩 | |
87 | | 21 | speechify_9 | 👩 | |
88 | | 22 | son-tung-mtp | 👨 | |
89 | | 23 | zero_shot_prompt | 👩 | |
90 | | 24 | speechify_1 | 👩 | |
91 |
92 |
93 |
94 |
95 |
96 |
97 |
98 | ### Thực thi với lệnh (CLI)
99 |
100 | Giao diện dòng lệnh VietTTS cho phép bạn tạo giọng nói từ terminal. Cách sử dụng:
101 |
102 | ```bash
103 | # Hướng dẫn sử dụng
104 | viettts --help
105 |
106 | # Khởi động API Server
107 | viettts server --host 0.0.0.0 --port 8298
108 |
109 | # Xem tất cả các giọng nói có sẵn
110 | viettts show-voices
111 |
112 | # Tổng hợp giọng nói từ văn bản với giọng có sẵn
113 | viettts synthesis --text "Xin chào" --voice 0 --output test.wav
114 |
115 | # Sao chép giọng từ audio file bất kì
116 | viettts synthesis --text "Xin chào" --voice Download/voice.wav --output cloned.wav
117 | ```
118 |
119 | ### API Client
120 | #### Python (OpenAI Client)
121 | Thiết lập biến môi trường cho OpenAI Client:
122 |
123 | ```bash
124 | # Thiết lập base_url và API key như biến môi trường
125 | export OPENAI_BASE_URL=http://localhost:8298
126 | export OPENAI_API_KEY=viet-tts # không dùng trong phiên bản hiện tại
127 | ```
128 |
129 | Để tạo giọng nói từ văn bản đầu vào:
130 |
131 | ```python
132 | from pathlib import Path
133 | from openai import OpenAI
134 |
135 |
136 |
137 | client = OpenAI()
138 | output_file_path = Path(__file__).parent / "speech.wav"
139 |
140 | with client.audio.speech.with_streaming_response.create(
141 | model='tts-1',
142 | voice='cdteam',
143 | input='Xin chào Việt Nam.',
144 | speed=1.0,
145 | response_format='wav'
146 | ) as response:
147 |     response.stream_to_file(output_file_path)
148 | ```
149 |
150 | #### CURL
151 | ```bash
152 | # Lấy danh sách giọng có sẵn
153 | curl --location http://0.0.0.0:8298/v1/voices
154 |
155 | # OpenAI API format
156 | curl http://localhost:8298/v1/audio/speech \
157 | -H "Authorization: Bearer viet-tts" \
158 | -H "Content-Type: application/json" \
159 | -d '{
160 | "model": "tts-1",
161 | "input": "Xin chào Việt Nam.",
162 | "voice": "son-tung-mtp"
163 | }' \
164 | --output speech.wav
165 |
166 | # API với giọng từ file local
167 | curl --location http://0.0.0.0:8298/v1/tts \
168 | --form 'text="xin chào"' \
169 | --form 'audio_file=@"/home/viettts/Downloads/voice.mp4"' \
170 | --output speech.wav
171 | ```
172 |
173 | #### Node
174 | ```js
175 | import fs from "fs";
176 | import path from "path";
177 | import OpenAI from "openai";
178 |
179 | const openai = new OpenAI();
180 | const speechFile = path.resolve("./speech.wav");
181 |
182 | async function main() {
183 | const mp3 = await openai.audio.speech.create({
184 | model: "tts-1",
185 | voice: "1",
186 | input: "Xin chào Việt Nam.",
187 | });
188 | console.log(speechFile);
189 | const buffer = Buffer.from(await mp3.arrayBuffer());
190 | await fs.promises.writeFile(speechFile, buffer);
191 | }
192 | main();
193 | ```
194 |
195 | ## 🙏 Mã liên quan
196 | - 💡 Sử dụng mã từ [Cosyvoice](https://github.com/FunAudioLLM/CosyVoice)
197 | - 🎙️ Mô hình VAD từ [silero-vad](https://github.com/snakers4/silero-vad)
198 | - 📝 Chuẩn hóa văn bản với [Vinorm](https://github.com/v-nhandt21/Vinorm)
199 |
200 | ## 📜 Giấy phép
201 | Mã nguồn của **VietTTS** được cấp phép theo **Apache 2.0 License**. Mô hình và mẫu âm thanh huấn luyện được cấp phép theo **CC BY-NC License**, dựa trên tập dữ liệu từ internet. Xin lỗi nếu điều này gây bất tiện.
202 |
203 | ## ⚠️ Tuyên bố miễn trừ trách nhiệm
204 | Nội dung trên chỉ phục vụ mục đích học thuật và nhằm trình bày khả năng kỹ thuật. Một số ví dụ lấy từ internet. Nếu nội dung vi phạm quyền của bạn, vui lòng liên hệ để được gỡ bỏ.
205 |
206 | ## 💬 Liên hệ
207 | - Facebook: https://fb.com/sam.rngd
208 | - GitHub: https://github.com/dangvansam
209 | - Email: dangvansam98@gmail.com
--------------------------------------------------------------------------------
/assets/viet-tts-large.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dangvansam/viet-tts/c29b7535d4a81aefbbef175d0db26ac9df4bc028/assets/viet-tts-large.png
--------------------------------------------------------------------------------
/assets/viet-tts-medium.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dangvansam/viet-tts/c29b7535d4a81aefbbef175d0db26ac9df4bc028/assets/viet-tts-medium.png
--------------------------------------------------------------------------------
/assets/viet-tts.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dangvansam/viet-tts/c29b7535d4a81aefbbef175d0db26ac9df4bc028/assets/viet-tts.png
--------------------------------------------------------------------------------
/docker-compose.yaml:
--------------------------------------------------------------------------------
1 | services:
2 | api:
3 | image: viet-tts:latest
4 | build: .
5 | restart: always
6 | container_name: viet-tts-service
7 | ports:
8 | - 8298:8298
9 | deploy:
10 | resources:
11 | reservations:
12 | devices:
13 | - driver: nvidia
14 | count: 1
15 | capabilities: [gpu]
16 | volumes:
17 | - ./pretrained-models:/app/pretrained-models
18 | command: viettts server --host 0.0.0.0 --port 8298
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.poetry]
2 | name = "viettts"
3 | version = "0.1.0"
4 | description = "VietTTS: An Open-Source Vietnamese Text to Speech"
5 | authors = ["dangvansam <dangvansam98@gmail.com>"]
6 | readme = "README.md"
7 |
8 | [tool.poetry.dependencies]
9 | python = "^3.10"
10 | conformer = "0.3.2"
11 | diffusers = "0.27.2"
12 | gradio = "4.32.2"
13 | hydra-core = "1.3.2"
14 | hyperpyyaml = "1.2.2"
15 | librosa = "0.10.2"
16 | omegaconf = "2.3.0"
17 | onnx = "1.16.0"
18 | onnxruntime-gpu = "1.16.0"
19 | protobuf = "4.25"
20 | pydantic = "2.7.0"
21 | soundfile = "0.12.1"
22 | torch = "2.0.1"
23 | torchaudio = "2.0.2"
24 | uvicorn = "0.30.0"
25 | wget = "3.2"
26 | fastapi = "0.111.0"
27 | fastapi-cli = "0.0.4"
28 | loguru = "0.7.2"
29 | vinorm = "^2.0.7"
30 | huggingface-hub = "0.24.7"
31 | click = "^8.1.7"
32 | gunicorn = "^23.0.0"
33 | silero-vad = "^5.1.2"
34 | tiktoken = "^0.8.0"
35 | openai-whisper = "^20240930"
36 |
37 | [tool.poetry.scripts]
38 | viettts = "viettts.cli:cli"
39 |
40 | [build-system]
41 | requires = ["poetry-core"]
42 | build-backend = "poetry.core.masonry.api"
43 |
44 | [tool.setuptools]
45 | packages = ["viettts"]
--------------------------------------------------------------------------------
/samples/cdteam.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dangvansam/viet-tts/c29b7535d4a81aefbbef175d0db26ac9df4bc028/samples/cdteam.wav
--------------------------------------------------------------------------------
/samples/cross_lingual_prompt.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dangvansam/viet-tts/c29b7535d4a81aefbbef175d0db26ac9df4bc028/samples/cross_lingual_prompt.wav
--------------------------------------------------------------------------------
/samples/diep-chi.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dangvansam/viet-tts/c29b7535d4a81aefbbef175d0db26ac9df4bc028/samples/diep-chi.wav
--------------------------------------------------------------------------------
/samples/doremon.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dangvansam/viet-tts/c29b7535d4a81aefbbef175d0db26ac9df4bc028/samples/doremon.mp3
--------------------------------------------------------------------------------
/samples/jack-sparrow.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dangvansam/viet-tts/c29b7535d4a81aefbbef175d0db26ac9df4bc028/samples/jack-sparrow.mp3
--------------------------------------------------------------------------------
/samples/nguyen-ngoc-ngan.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dangvansam/viet-tts/c29b7535d4a81aefbbef175d0db26ac9df4bc028/samples/nguyen-ngoc-ngan.wav
--------------------------------------------------------------------------------
/samples/nsnd-le-chuc.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dangvansam/viet-tts/c29b7535d4a81aefbbef175d0db26ac9df4bc028/samples/nsnd-le-chuc.mp3
--------------------------------------------------------------------------------
/samples/nu-nhe-nhang.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dangvansam/viet-tts/c29b7535d4a81aefbbef175d0db26ac9df4bc028/samples/nu-nhe-nhang.wav
--------------------------------------------------------------------------------
/samples/quynh.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dangvansam/viet-tts/c29b7535d4a81aefbbef175d0db26ac9df4bc028/samples/quynh.wav
--------------------------------------------------------------------------------
/samples/son-tung-mtp.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dangvansam/viet-tts/c29b7535d4a81aefbbef175d0db26ac9df4bc028/samples/son-tung-mtp.wav
--------------------------------------------------------------------------------
/samples/speechify_1.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dangvansam/viet-tts/c29b7535d4a81aefbbef175d0db26ac9df4bc028/samples/speechify_1.wav
--------------------------------------------------------------------------------
/samples/speechify_10.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dangvansam/viet-tts/c29b7535d4a81aefbbef175d0db26ac9df4bc028/samples/speechify_10.wav
--------------------------------------------------------------------------------
/samples/speechify_11.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dangvansam/viet-tts/c29b7535d4a81aefbbef175d0db26ac9df4bc028/samples/speechify_11.wav
--------------------------------------------------------------------------------
/samples/speechify_12.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dangvansam/viet-tts/c29b7535d4a81aefbbef175d0db26ac9df4bc028/samples/speechify_12.wav
--------------------------------------------------------------------------------
/samples/speechify_2.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dangvansam/viet-tts/c29b7535d4a81aefbbef175d0db26ac9df4bc028/samples/speechify_2.wav
--------------------------------------------------------------------------------
/samples/speechify_3.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dangvansam/viet-tts/c29b7535d4a81aefbbef175d0db26ac9df4bc028/samples/speechify_3.wav
--------------------------------------------------------------------------------
/samples/speechify_4.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dangvansam/viet-tts/c29b7535d4a81aefbbef175d0db26ac9df4bc028/samples/speechify_4.wav
--------------------------------------------------------------------------------
/samples/speechify_5.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dangvansam/viet-tts/c29b7535d4a81aefbbef175d0db26ac9df4bc028/samples/speechify_5.wav
--------------------------------------------------------------------------------
/samples/speechify_6.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dangvansam/viet-tts/c29b7535d4a81aefbbef175d0db26ac9df4bc028/samples/speechify_6.wav
--------------------------------------------------------------------------------
/samples/speechify_7.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dangvansam/viet-tts/c29b7535d4a81aefbbef175d0db26ac9df4bc028/samples/speechify_7.wav
--------------------------------------------------------------------------------
/samples/speechify_8.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dangvansam/viet-tts/c29b7535d4a81aefbbef175d0db26ac9df4bc028/samples/speechify_8.wav
--------------------------------------------------------------------------------
/samples/speechify_9.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dangvansam/viet-tts/c29b7535d4a81aefbbef175d0db26ac9df4bc028/samples/speechify_9.wav
--------------------------------------------------------------------------------
/samples/zero_shot_prompt.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dangvansam/viet-tts/c29b7535d4a81aefbbef175d0db26ac9df4bc028/samples/zero_shot_prompt.wav
--------------------------------------------------------------------------------
/viettts/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dangvansam/viet-tts/c29b7535d4a81aefbbef175d0db26ac9df4bc028/viettts/__init__.py
--------------------------------------------------------------------------------
/viettts/cli.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import time
4 | import click
5 | import subprocess
6 | from loguru import logger
7 | from rich.table import Table
8 | from rich.console import Console
9 | from viettts.tts import TTS
10 | from viettts.utils.file_utils import load_prompt_speech_from_file, load_voices
11 |
12 |
13 | AUDIO_DIR = 'samples'
14 | MODEL_DIR = 'pretrained-models'
15 |
16 | @click.command('server')
17 | @click.option('-h', '--host', type=str, default='0.0.0.0', help="The host address to bind the server to. Default is '0.0.0.0'.")
18 | @click.option('-p', '--port', type=int, default=8298, help="The port number to bind the server to. Default is 8298.")
19 | @click.option('-w', '--workers', type=int, default=1, help="The number of worker processes to handle requests. Default is 1.")
20 | def start_server(host: str, port: int, workers: int):
21 | """Start API server (OpenAI TTS API compatible).
22 |
23 | Usage: viettts server --host 0.0.0.0 --port 8298 -w 4
24 | """
25 | logger.info("Starting server")
26 | cmd = f'gunicorn viettts.server:app \
27 | -k uvicorn.workers.UvicornWorker \
28 | --bind {host}:{port} \
29 | --workers {workers} \
30 | --max-requests 1000 \
31 | --max-requests-jitter 50 \
32 | --timeout 300 \
33 | --keep-alive 75 \
34 | --graceful-timeout 60'
35 |
36 | subprocess.call(cmd, shell=True, stdout=sys.stdout)
37 |
38 |
39 | @click.command('synthesis')
40 | @click.option('-t', "--text", type=str, required=True, help="The input text to synthesize into speech.")
41 | @click.option('-v', "--voice", type=str, default='1', help="The voice ID or file path to clone the voice from. Default is '1'.")
42 | @click.option('-s', "--speed", type=float, default=1, help="The speed multiplier for the speech. Default is 1 (normal speed).")
43 | @click.option('-o', "--output", type=str, default='output.wav', help="The file path to save the synthesized audio. Default is 'output.wav'.")
44 | def synthesis(text: str, voice: str, speed: float, output: str):
45 |     """Synthesize audio from text and save it to a file.
46 |
47 |     Usage: viettts synthesis --text 'Xin chào VietTTS' --voice nu-nhe-nhang --speed 1.2 --output test_nu-nhe-nhang.wav
48 | """
49 | logger.info("Starting synthesis")
50 | st = time.perf_counter()
51 | if not text:
52 |         logger.error('text must not be empty')
53 | return
54 |
55 | if speed > 2 or speed < 0.5:
56 |         logger.error('speed must be in range 0.5-2.0')
57 | return
58 |
59 | if not os.path.exists(voice):
60 | voice_map = load_voices(AUDIO_DIR)
61 | if voice.isdigit():
62 | voice = list(voice_map.values())[int(voice)]
63 | else:
64 | voice = voice_map.get(voice)
65 |
66 |         if not voice or not os.path.exists(voice):
67 |             logger.error('voice is not available. Use --voice or run `viettts show-voices` to get available voices.')
68 | return
69 |
70 | logger.info('Loading model')
71 | tts = TTS(model_dir=MODEL_DIR)
72 |
73 | logger.info('Loading voice')
74 | voice = load_prompt_speech_from_file(voice)
75 |
76 | logger.info('Processing')
77 | tts.tts_to_file(text, voice, speed, output)
78 |
79 | et = time.perf_counter()
80 | logger.success(f"Saved to: {output} [time cost={et-st:.2f}s]")
81 |
82 |
83 | @click.command('show-voices')
84 | def show_voice():
85 | """Print all available voices.
86 |
87 | Usage: viettts show-voices
88 | """
89 | voice_map = load_voices(AUDIO_DIR)
90 | console = Console()
91 | table = Table(show_header=True, header_style="green", show_lines=False)
92 | table.add_column("Voice ID", width=10)
93 | table.add_column("Voice Name", width=30)
94 | table.add_column("File", justify="left")
95 |
96 | for i, (voice_name, voice_path) in enumerate(voice_map.items()):
97 | table.add_row(str(i+1), voice_name, voice_path)
98 |
99 | console.print(table)
100 |
101 |
102 | @click.group()
103 | def cli():
104 | """
105 | VietTTS CLI v0.1.0
106 |
107 | Vietnamese Text To Speech and Voice Clone
108 |     License: Apache 2.0 - Author: dangvansam
109 | """
110 | pass
111 |
112 | cli.add_command(start_server)
113 | cli.add_command(synthesis)
114 | cli.add_command(show_voice)
--------------------------------------------------------------------------------
/viettts/flow/flow.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import random
3 | from typing import Dict, Optional
4 | import torch
5 | import torch.nn as nn
6 | from torch.nn import functional as F
7 | from omegaconf import DictConfig
8 | from viettts.utils.mask import make_pad_mask
9 |
10 |
11 | class MaskedDiffWithXvec(torch.nn.Module):
12 | def __init__(self,
13 | input_size: int = 512,
14 | output_size: int = 80,
15 | spk_embed_dim: int = 192,
16 | output_type: str = "mel",
17 | vocab_size: int = 4096,
18 | input_frame_rate: int = 50,
19 | only_mask_loss: bool = True,
20 | encoder: torch.nn.Module = None,
21 | length_regulator: torch.nn.Module = None,
22 | decoder: torch.nn.Module = None,
23 | decoder_conf: Dict = {
24 | 'in_channels': 240,
25 | 'out_channel': 80,
26 | 'spk_emb_dim': 80,
27 | 'n_spks': 1,
28 | 'cfm_params': DictConfig({
29 | 'sigma_min': 1e-06,
30 | 'solver': 'euler',
31 | 't_scheduler': 'cosine',
32 | 'training_cfg_rate': 0.2,
33 | 'inference_cfg_rate': 0.7,
34 | 'reg_loss_type': 'l1'
35 | }),
36 | 'decoder_params': {
37 | 'channels': [256, 256],
38 | 'dropout': 0.0,
39 | 'attention_head_dim': 64,
40 | 'n_blocks': 4,
41 | 'num_mid_blocks': 12,
42 | 'num_heads': 8,
43 | 'act_fn': 'gelu'
44 | }
45 | },
46 | mel_feat_conf: Dict = {
47 | 'n_fft': 1024,
48 | 'num_mels': 80,
49 | 'sampling_rate': 22050,
50 | 'hop_size': 256,
51 | 'win_size': 1024,
52 | 'fmin': 0,
53 | 'fmax': 8000
54 | }
55 | ):
56 | super().__init__()
57 | self.input_size = input_size
58 | self.output_size = output_size
59 | self.decoder_conf = decoder_conf
60 | self.mel_feat_conf = mel_feat_conf
61 | self.vocab_size = vocab_size
62 | self.output_type = output_type
63 | self.input_frame_rate = input_frame_rate
64 | logging.info(f"input frame rate={self.input_frame_rate}")
65 | self.input_embedding = nn.Embedding(vocab_size, input_size)
66 | self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
67 | self.encoder = encoder
68 | self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
69 | self.decoder = decoder
70 | self.length_regulator = length_regulator
71 | self.only_mask_loss = only_mask_loss
72 |
73 | def forward(
74 | self,
75 | batch: dict,
76 | device: torch.device,
77 | ) -> Dict[str, Optional[torch.Tensor]]:
78 | token = batch['speech_token'].to(device)
79 | token_len = batch['speech_token_len'].to(device)
80 | feat = batch['speech_feat'].to(device)
81 | feat_len = batch['speech_feat_len'].to(device)
82 | embedding = batch['embedding'].to(device)
83 |
84 | # xvec projection
85 | embedding = F.normalize(embedding, dim=1)
86 | embedding = self.spk_embed_affine_layer(embedding)
87 |
88 | # concat text and prompt_text
89 | mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device)
90 | token = self.input_embedding(torch.clamp(token, min=0)) * mask
91 |
92 | # text encode
93 | h, h_lengths = self.encoder(token, token_len)
94 | h = self.encoder_proj(h)
95 | h, h_lengths = self.length_regulator(h, feat_len)
96 |
97 | # get conditions
98 | conds = torch.zeros(feat.shape, device=token.device)
99 | for i, j in enumerate(feat_len):
100 | if random.random() < 0.5:
101 | continue
102 | index = random.randint(0, int(0.3 * j))
103 | conds[i, :index] = feat[i, :index]
104 | conds = conds.transpose(1, 2)
105 |
106 | mask = (~make_pad_mask(feat_len)).to(h)
107 | feat = F.interpolate(feat.unsqueeze(dim=1), size=h.shape[1:], mode="nearest").squeeze(dim=1)
108 | loss, _ = self.decoder.compute_loss(
109 | feat.transpose(1, 2).contiguous(),
110 | mask.unsqueeze(1),
111 | h.transpose(1, 2).contiguous(),
112 | embedding,
113 | cond=conds
114 | )
115 | return {'loss': loss}
116 |
117 | @torch.inference_mode()
118 | def inference(self,
119 | token,
120 | token_len,
121 | prompt_token,
122 | prompt_token_len,
123 | prompt_feat,
124 | prompt_feat_len,
125 | embedding):
126 | assert token.shape[0] == 1
127 | # xvec projection
128 | embedding = F.normalize(embedding, dim=1)
129 | embedding = self.spk_embed_affine_layer(embedding)
130 |
131 | # concat text and prompt_text
132 | token_len1, token_len2 = prompt_token.shape[1], token.shape[1]
133 | token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
134 | mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
135 | token = self.input_embedding(torch.clamp(token, min=0)) * mask
136 |
137 | # text encode
138 | h, h_lengths = self.encoder(token, token_len)
139 | h = self.encoder_proj(h)
140 | mel_len1, mel_len2 = prompt_feat.shape[1], int(token_len2 / self.input_frame_rate * 22050 / 256)
141 | h, h_lengths = self.length_regulator.inference(h[:, :token_len1], h[:, token_len1:], mel_len1, mel_len2, self.input_frame_rate)
142 |
143 | # get conditions
144 | conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device)
145 | conds[:, :mel_len1] = prompt_feat
146 | conds = conds.transpose(1, 2)
147 |
148 | mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
149 | feat = self.decoder(
150 | mu=h.transpose(1, 2).contiguous(),
151 | mask=mask.unsqueeze(1),
152 | spks=embedding,
153 | cond=conds,
154 | n_timesteps=10
155 | )
156 | feat = feat[:, :, mel_len1:]
157 | assert feat.shape[2] == mel_len2
158 | return feat
159 |
--------------------------------------------------------------------------------
/viettts/flow/flow_matching.py:
--------------------------------------------------------------------------------
1 | from abc import ABC
2 |
3 | import torch
4 | import torch.nn.functional as F
5 |
6 | from viettts.flow.decoder import Decoder
7 |
8 |
9 | class BASECFM(torch.nn.Module, ABC):
10 | def __init__(
11 | self,
12 | n_feats,
13 | cfm_params,
14 | n_spks=1,
15 | spk_emb_dim=128,
16 | ):
17 | super().__init__()
18 | self.n_feats = n_feats
19 | self.n_spks = n_spks
20 | self.spk_emb_dim = spk_emb_dim
21 | self.solver = cfm_params.solver
22 | if hasattr(cfm_params, "sigma_min"):
23 | self.sigma_min = cfm_params.sigma_min
24 | else:
25 | self.sigma_min = 1e-4
26 |
27 | self.estimator = None
28 |
29 | @torch.inference_mode()
30 | def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
31 | """Forward diffusion
32 |
33 | Args:
34 | mu (torch.Tensor): output of encoder
35 | shape: (batch_size, n_feats, mel_timesteps)
36 | mask (torch.Tensor): output_mask
37 | shape: (batch_size, 1, mel_timesteps)
38 | n_timesteps (int): number of diffusion steps
39 | temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
40 | spks (torch.Tensor, optional): speaker ids. Defaults to None.
41 | shape: (batch_size, spk_emb_dim)
42 | cond: Not used but kept for future purposes
43 |
44 | Returns:
45 | sample: generated mel-spectrogram
46 | shape: (batch_size, n_feats, mel_timesteps)
47 | """
48 | z = torch.randn_like(mu) * temperature
49 | t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
50 | return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond)
51 |
52 | def solve_euler(self, x, t_span, mu, mask, spks, cond):
53 | """
54 |         Fixed-step Euler solver for ODEs.
55 | Args:
56 | x (torch.Tensor): random noise
57 | t_span (torch.Tensor): n_timesteps interpolated
58 | shape: (n_timesteps + 1,)
59 | mu (torch.Tensor): output of encoder
60 | shape: (batch_size, n_feats, mel_timesteps)
61 | mask (torch.Tensor): output_mask
62 | shape: (batch_size, 1, mel_timesteps)
63 | spks (torch.Tensor, optional): speaker ids. Defaults to None.
64 | shape: (batch_size, spk_emb_dim)
65 | cond: Not used but kept for future purposes
66 | """
67 | t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
68 |
69 |         # Keep intermediate states so they can be inspected while debugging;
70 |         # a return_all_steps flag could expose them in the future.
71 | sol = []
72 |
73 | for step in range(1, len(t_span)):
74 | dphi_dt = self.estimator(x, mask, mu, t, spks, cond)
75 |
76 | x = x + dt * dphi_dt
77 | t = t + dt
78 | sol.append(x)
79 | if step < len(t_span) - 1:
80 | dt = t_span[step + 1] - t
81 |
82 | return sol[-1]
83 |
84 | def compute_loss(self, x1, mask, mu, spks=None, cond=None):
85 | """Computes diffusion loss
86 |
87 | Args:
88 | x1 (torch.Tensor): Target
89 | shape: (batch_size, n_feats, mel_timesteps)
90 | mask (torch.Tensor): target mask
91 | shape: (batch_size, 1, mel_timesteps)
92 | mu (torch.Tensor): output of encoder
93 | shape: (batch_size, n_feats, mel_timesteps)
94 | spks (torch.Tensor, optional): speaker embedding. Defaults to None.
95 | shape: (batch_size, spk_emb_dim)
96 |
97 | Returns:
98 | loss: conditional flow matching loss
99 | y: conditional flow
100 | shape: (batch_size, n_feats, mel_timesteps)
101 | """
102 | b, _, t = mu.shape
103 |
104 | # random timestep
105 | t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
106 | # sample noise p(x_0)
107 | z = torch.randn_like(x1)
108 |
109 | y = (1 - (1 - self.sigma_min) * t) * z + t * x1
110 | u = x1 - (1 - self.sigma_min) * z
111 |
112 | loss = F.mse_loss(self.estimator(y, mask, mu, t.squeeze(), spks), u, reduction="sum") / (
113 | torch.sum(mask) * u.shape[1]
114 | )
115 | return loss, y
116 |
117 |
118 | class CFM(BASECFM):
119 | def __init__(self, in_channels, out_channel, cfm_params, decoder_params, n_spks=1, spk_emb_dim=64):
120 | super().__init__(
121 | n_feats=in_channels,
122 | cfm_params=cfm_params,
123 | n_spks=n_spks,
124 | spk_emb_dim=spk_emb_dim,
125 | )
126 |
127 | in_channels = in_channels + (spk_emb_dim if n_spks > 1 else 0)
128 | # Just change the architecture of the estimator here
129 | self.estimator = Decoder(in_channels=in_channels, out_channels=out_channel, **decoder_params)
130 |
131 |
132 | class ConditionalCFM(BASECFM):
133 | def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
134 | super().__init__(
135 | n_feats=in_channels,
136 | cfm_params=cfm_params,
137 | n_spks=n_spks,
138 | spk_emb_dim=spk_emb_dim,
139 | )
140 | self.t_scheduler = cfm_params.t_scheduler
141 | self.training_cfg_rate = cfm_params.training_cfg_rate
142 | self.inference_cfg_rate = cfm_params.inference_cfg_rate
143 | in_channels = in_channels + (spk_emb_dim if n_spks > 0 else 0)
144 | # Just change the architecture of the estimator here
145 | self.estimator = estimator
146 |
147 | @torch.inference_mode()
148 | def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
149 | """Forward diffusion
150 |
151 | Args:
152 | mu (torch.Tensor): output of encoder
153 | shape: (batch_size, n_feats, mel_timesteps)
154 | mask (torch.Tensor): output_mask
155 | shape: (batch_size, 1, mel_timesteps)
156 | n_timesteps (int): number of diffusion steps
157 | temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
158 | spks (torch.Tensor, optional): speaker ids. Defaults to None.
159 | shape: (batch_size, spk_emb_dim)
160 | cond: Not used but kept for future purposes
161 |
162 | Returns:
163 | sample: generated mel-spectrogram
164 | shape: (batch_size, n_feats, mel_timesteps)
165 | """
166 | z = torch.randn_like(mu) * temperature
167 | t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
168 | if self.t_scheduler == 'cosine':
169 | t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
170 | return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond)
171 |
172 | def solve_euler(self, x, t_span, mu, mask, spks, cond):
173 | """
174 |         Fixed-step Euler solver for the ODE.
175 | Args:
176 | x (torch.Tensor): random noise
177 | t_span (torch.Tensor): n_timesteps interpolated
178 | shape: (n_timesteps + 1,)
179 | mu (torch.Tensor): output of encoder
180 | shape: (batch_size, n_feats, mel_timesteps)
181 | mask (torch.Tensor): output_mask
182 | shape: (batch_size, 1, mel_timesteps)
183 | spks (torch.Tensor, optional): speaker ids. Defaults to None.
184 | shape: (batch_size, spk_emb_dim)
185 | cond: Not used but kept for future purposes
186 | """
187 | t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
188 | t = t.unsqueeze(dim=0)
189 |
190 |         # Keep intermediate states so they can be inspected while debugging;
191 |         # a return_all_steps flag could expose them in the future.
192 | sol = []
193 |
194 | for step in range(1, len(t_span)):
195 | dphi_dt = self.forward_estimator(x, mask, mu, t, spks, cond)
196 | # Classifier-Free Guidance inference introduced in VoiceBox
197 | if self.inference_cfg_rate > 0:
198 | cfg_dphi_dt = self.forward_estimator(
199 | x, mask,
200 | torch.zeros_like(mu), t,
201 | torch.zeros_like(spks) if spks is not None else None,
202 | torch.zeros_like(cond)
203 | )
204 | dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt -
205 | self.inference_cfg_rate * cfg_dphi_dt)
206 | x = x + dt * dphi_dt
207 | t = t + dt
208 | sol.append(x)
209 | if step < len(t_span) - 1:
210 | dt = t_span[step + 1] - t
211 |
212 | return sol[-1]
213 |
214 | def forward_estimator(self, x, mask, mu, t, spks, cond):
215 | if isinstance(self.estimator, torch.nn.Module):
216 | return self.estimator.forward(x, mask, mu, t, spks, cond)
217 | else:
218 | ort_inputs = {
219 | 'x': x.cpu().numpy(),
220 | 'mask': mask.cpu().numpy(),
221 | 'mu': mu.cpu().numpy(),
222 | 't': t.cpu().numpy(),
223 | 'spks': spks.cpu().numpy(),
224 | 'cond': cond.cpu().numpy()
225 | }
226 | output = self.estimator.run(None, ort_inputs)[0]
227 | return torch.tensor(output, dtype=x.dtype, device=x.device)
228 |
229 | def compute_loss(self, x1, mask, mu, spks=None, cond=None):
230 | """Computes diffusion loss
231 |
232 | Args:
233 | x1 (torch.Tensor): Target
234 | shape: (batch_size, n_feats, mel_timesteps)
235 | mask (torch.Tensor): target mask
236 | shape: (batch_size, 1, mel_timesteps)
237 | mu (torch.Tensor): output of encoder
238 | shape: (batch_size, n_feats, mel_timesteps)
239 | spks (torch.Tensor, optional): speaker embedding. Defaults to None.
240 | shape: (batch_size, spk_emb_dim)
241 |
242 | Returns:
243 | loss: conditional flow matching loss
244 | y: conditional flow
245 | shape: (batch_size, n_feats, mel_timesteps)
246 | """
247 | b, _, t = mu.shape
248 |
249 | # random timestep
250 | t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
251 | if self.t_scheduler == 'cosine':
252 | t = 1 - torch.cos(t * 0.5 * torch.pi)
253 | # sample noise p(x_0)
254 | z = torch.randn_like(x1)
255 |
256 | y = (1 - (1 - self.sigma_min) * t) * z + t * x1
257 | u = x1 - (1 - self.sigma_min) * z
258 |
259 | # during training, we randomly drop condition to trade off mode coverage and sample fidelity
260 | if self.training_cfg_rate > 0:
261 | cfg_mask = torch.rand(b, device=x1.device) > self.training_cfg_rate
262 | mu = mu * cfg_mask.view(-1, 1, 1)
263 | spks = spks * cfg_mask.view(-1, 1)
264 | cond = cond * cfg_mask.view(-1, 1, 1)
265 |
266 | pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond)
267 | loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1])
268 | return loss, y
269 |
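For reference, the sketch below spells out the conditional flow-matching targets that `compute_loss` regresses the estimator onto: the linear interpolant `y` and the constant velocity field `u`. The zero-valued `pred` is only a stand-in so the loss is computable end to end; in the real model it is the output of the wired-in estimator.

```python
import torch
import torch.nn.functional as F

sigma_min = 1e-4
b, n_feats, frames = 2, 80, 100

x1 = torch.randn(b, n_feats, frames)   # target mel-spectrogram (the "data" sample)
z = torch.randn_like(x1)               # noise sampled from p(x_0)
t = torch.rand(b, 1, 1)                # random timestep in [0, 1]

# Linear interpolant between noise and data, and the velocity field the
# estimator is regressed onto (same formulas as compute_loss above).
y = (1 - (1 - sigma_min) * t) * z + t * x1
u = x1 - (1 - sigma_min) * z

# Zero-valued stand-in prediction, only to make the loss computable here.
pred = torch.zeros_like(u)
mask = torch.ones(b, 1, frames)
loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1])
print(loss.item())
```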
--------------------------------------------------------------------------------
/viettts/flow/length_regulator.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 | import torch.nn as nn
3 | import torch
4 | from torch.nn import functional as F
5 | from viettts.utils.mask import make_pad_mask
6 |
7 |
8 | class InterpolateRegulator(nn.Module):
9 | def __init__(
10 | self,
11 | channels: int,
12 | sampling_ratios: Tuple,
13 | out_channels: int = None,
14 | groups: int = 1,
15 | ):
16 | super().__init__()
17 | self.sampling_ratios = sampling_ratios
18 | out_channels = out_channels or channels
19 | model = nn.ModuleList([])
20 | if len(sampling_ratios) > 0:
21 | for _ in sampling_ratios:
22 | module = nn.Conv1d(channels, channels, 3, 1, 1)
23 | norm = nn.GroupNorm(groups, channels)
24 | act = nn.Mish()
25 | model.extend([module, norm, act])
26 | model.append(
27 | nn.Conv1d(channels, out_channels, 1, 1)
28 | )
29 | self.model = nn.Sequential(*model)
30 |
31 | def forward(self, x, ylens=None):
32 | # x in (B, T, D)
33 | mask = (~make_pad_mask(ylens)).to(x).unsqueeze(-1)
34 | x = F.interpolate(x.transpose(1, 2).contiguous(), size=ylens.max(), mode='linear')
35 | out = self.model(x).transpose(1, 2).contiguous()
36 | olens = ylens
37 | return out * mask, olens
38 |
39 | def inference(self, x1, x2, mel_len1, mel_len2, input_frame_rate=50):
40 |         # in inference mode, interpolate the prompt token and the generated token (head/mid/tail) separately, so we get a clear separation point in the mel
41 | # x in (B, T, D)
42 | if x2.shape[1] > 40:
43 | x2_head = F.interpolate(x2[:, :20].transpose(1, 2).contiguous(), size=int(20 / input_frame_rate * 22050 / 256), mode='linear')
44 | x2_mid = F.interpolate(x2[:, 20:-20].transpose(1, 2).contiguous(), size=mel_len2 - int(20 / input_frame_rate * 22050 / 256) * 2,
45 | mode='linear')
46 | x2_tail = F.interpolate(x2[:, -20:].transpose(1, 2).contiguous(), size=int(20 / input_frame_rate * 22050 / 256), mode='linear')
47 | x2 = torch.concat([x2_head, x2_mid, x2_tail], dim=2)
48 | else:
49 | x2 = F.interpolate(x2.transpose(1, 2).contiguous(), size=mel_len2, mode='linear')
50 | if x1.shape[1] != 0:
51 | x1 = F.interpolate(x1.transpose(1, 2).contiguous(), size=mel_len1, mode='linear')
52 | x = torch.concat([x1, x2], dim=2)
53 | else:
54 | x = x2
55 | out = self.model(x).transpose(1, 2).contiguous()
56 | return out, mel_len1 + mel_len2
57 |
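A minimal sketch of the core step `InterpolateRegulator` performs: resampling a (B, T, D) encoder sequence from token rate to a target number of mel frames with `F.interpolate` (the shapes here are illustrative only).

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 40, 512)   # (B, T, D): 40 encoder frames, 512-dim (illustrative)
mel_len = 172                 # desired number of mel frames

# F.interpolate works on (B, C, T), so transpose, resample, transpose back.
x_mel = F.interpolate(x.transpose(1, 2).contiguous(), size=mel_len, mode='linear')
x_mel = x_mel.transpose(1, 2).contiguous()
print(x_mel.shape)            # torch.Size([1, 172, 512])
```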
--------------------------------------------------------------------------------
/viettts/frontend.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torchaudio
4 | import whisper
5 | import onnxruntime
6 | import numpy as np
7 | import torchaudio.compliance.kaldi as kaldi
8 | from typing import Callable, List, Union
9 | from functools import partial
10 | from loguru import logger
11 |
12 | from viettts.utils.frontend_utils import split_text, normalize_text, mel_spectrogram
13 | from viettts.tokenizer.tokenizer import get_tokenizer
14 |
15 | class TTSFrontEnd:
16 | def __init__(
17 | self,
18 | speech_embedding_model: str,
19 | speech_tokenizer_model: str,
20 | ):
21 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
22 | self.tokenizer = get_tokenizer()
23 | option = onnxruntime.SessionOptions()
24 | option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
25 | option.intra_op_num_threads = 1
26 | self.speech_embedding_session = onnxruntime.InferenceSession(
27 | speech_embedding_model,
28 | sess_options=option,
29 | providers=["CPUExecutionProvider"]
30 | )
31 | self.speech_tokenizer_session = onnxruntime.InferenceSession(
32 | speech_tokenizer_model,
33 | sess_options=option,
34 | providers=["CUDAExecutionProvider" if torch.cuda.is_available() else "CPUExecutionProvider"]
35 | )
36 | self.spk2info = {}
37 |
38 | def _extract_text_token(self, text: str):
39 | text_token = self.tokenizer.encode(text, allowed_special='all')
40 | text_token = torch.tensor([text_token], dtype=torch.int32).to(self.device)
41 | text_token_len = torch.tensor([text_token.shape[1]], dtype=torch.int32).to(self.device)
42 | return text_token, text_token_len
43 |
44 | def _extract_speech_token(self, speech: torch.Tensor):
45 | if speech.shape[1] / 16000 > 30:
46 | speech = speech[:, :int(16000 * 30)]
47 | feat = whisper.log_mel_spectrogram(speech, n_mels=128)
48 | speech_token = self.speech_tokenizer_session.run(
49 | None,
50 | {self.speech_tokenizer_session.get_inputs()[0].name: feat.detach().cpu().numpy(),
51 | self.speech_tokenizer_session.get_inputs()[1].name: np.array([feat.shape[2]], dtype=np.int32)}
52 | )[0].flatten().tolist()
53 | speech_token = torch.tensor([speech_token], dtype=torch.int32).to(self.device)
54 | speech_token_len = torch.tensor([speech_token.shape[1]], dtype=torch.int32).to(self.device)
55 | return speech_token, speech_token_len
56 |
57 | def _extract_spk_embedding(self, speech: torch.Tensor):
58 | feat = kaldi.fbank(
59 | waveform=speech,
60 | num_mel_bins=80,
61 | dither=0,
62 | sample_frequency=16000
63 | )
64 | feat = feat - feat.mean(dim=0, keepdim=True)
65 | embedding = self.speech_embedding_session.run(
66 | None,
67 | {self.speech_embedding_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()}
68 | )[0].flatten().tolist()
69 | embedding = torch.tensor([embedding]).to(self.device)
70 | return embedding
71 |
72 | def _extract_speech_feat(self, speech: torch.Tensor):
73 | speech_feat = mel_spectrogram(
74 | y=speech,
75 | n_fft=1024,
76 | num_mels=80,
77 | sampling_rate=22050,
78 | hop_size=256,
79 | win_size=1024,
80 | fmin=0,
81 | fmax=8000,
82 | center=False
83 | ).squeeze(dim=0).transpose(0, 1).to(self.device)
84 | speech_feat = speech_feat.unsqueeze(dim=0)
85 | speech_feat_len = torch.tensor([speech_feat.shape[1]], dtype=torch.int32).to(self.device)
86 | return speech_feat, speech_feat_len
87 |
88 | def preprocess_text(self, text, split=True) -> Union[str, List[str]]:
89 | text = normalize_text(text)
90 | if split:
91 | text = list(split_text(
92 | text=text,
93 | tokenize=partial(self.tokenizer.encode, allowed_special='all'),
94 | token_max_n=30,
95 | token_min_n=10,
96 | merge_len=5,
97 | comma_split=False
98 | ))
99 | return text
100 |
101 | def frontend_tts(
102 | self,
103 | text: str,
104 | prompt_speech_16k: Union[np.ndarray, torch.Tensor]
105 | ) -> dict:
106 | if isinstance(prompt_speech_16k, np.ndarray):
107 | prompt_speech_16k = torch.from_numpy(prompt_speech_16k)
108 |
109 | text_token, text_token_len = self._extract_text_token(text)
110 | speech_token, speech_token_len = self._extract_speech_token(prompt_speech_16k)
111 | prompt_speech_22050 = torchaudio.transforms.Resample(orig_freq=16000, new_freq=22050)(prompt_speech_16k)
112 | speech_feat, speech_feat_len = self._extract_speech_feat(prompt_speech_22050)
113 | embedding = self._extract_spk_embedding(prompt_speech_16k)
114 |
115 | model_input = {
116 | 'text': text_token,
117 | 'text_len': text_token_len,
118 | 'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
119 | 'prompt_speech_feat': speech_feat,
120 | 'prompt_speech_feat_len': speech_feat_len,
121 | 'llm_embedding': embedding,
122 | 'flow_embedding': embedding
123 | }
124 | return model_input
125 |
126 |
127 | def frontend_vc(
128 | self,
129 | source_speech_16k: Union[np.ndarray, torch.Tensor],
130 | prompt_speech_16k: Union[np.ndarray, torch.Tensor]
131 | ) -> dict:
132 | if isinstance(source_speech_16k, np.ndarray):
133 | source_speech_16k = torch.from_numpy(source_speech_16k)
134 | if isinstance(prompt_speech_16k, np.ndarray):
135 | prompt_speech_16k = torch.from_numpy(prompt_speech_16k)
136 |
137 | prompt_speech_token, prompt_speech_token_len = self._extract_speech_token(prompt_speech_16k)
138 | prompt_speech_22050 = torchaudio.transforms.Resample(orig_freq=16000, new_freq=22050)(prompt_speech_16k)
139 | prompt_speech_feat, prompt_speech_feat_len = self._extract_speech_feat(prompt_speech_22050)
140 | embedding = self._extract_spk_embedding(prompt_speech_16k)
141 | source_speech_token, source_speech_token_len = self._extract_speech_token(source_speech_16k)
142 | model_input = {
143 | 'source_speech_token': source_speech_token,
144 | 'source_speech_token_len': source_speech_token_len,
145 | 'flow_prompt_speech_token': prompt_speech_token,
146 | 'flow_prompt_speech_token_len': prompt_speech_token_len,
147 | 'prompt_speech_feat': prompt_speech_feat,
148 | 'prompt_speech_feat_len': prompt_speech_feat_len,
149 | 'flow_embedding': embedding
150 | }
151 | return model_input
152 |
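A hypothetical usage sketch of `TTSFrontEnd.frontend_tts`. The ONNX model paths and the sample file are placeholders; the frontend expects 16 kHz mono prompt audio, which is why the clip is downmixed and resampled first.

```python
import torchaudio
from viettts.frontend import TTSFrontEnd

# Placeholder model paths: the frontend needs the pretrained ONNX speaker-embedding
# and speech-tokenizer models to be available on disk.
frontend = TTSFrontEnd(
    speech_embedding_model='pretrained-models/speech_embedding.onnx',
    speech_tokenizer_model='pretrained-models/speech_tokenizer.onnx',
)

# Load a prompt clip (placeholder path), downmix to mono and resample to 16 kHz.
wav, sr = torchaudio.load('samples/quynh.wav')
wav = wav.mean(dim=0, keepdim=True)
wav_16k = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000)(wav)

model_input = frontend.frontend_tts("xin chào Việt Nam", wav_16k)
print(model_input.keys())
```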
--------------------------------------------------------------------------------
/viettts/hifigan/f0_predictor.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn.utils import weight_norm
4 |
5 |
6 | class ConvRNNF0Predictor(nn.Module):
7 | def __init__(self,
8 | num_class: int = 1,
9 | in_channels: int = 80,
10 | cond_channels: int = 512
11 | ):
12 | super().__init__()
13 |
14 | self.num_class = num_class
15 | self.condnet = nn.Sequential(
16 | weight_norm(
17 | nn.Conv1d(in_channels, cond_channels, kernel_size=3, padding=1)
18 | ),
19 | nn.ELU(),
20 | weight_norm(
21 | nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
22 | ),
23 | nn.ELU(),
24 | weight_norm(
25 | nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
26 | ),
27 | nn.ELU(),
28 | weight_norm(
29 | nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
30 | ),
31 | nn.ELU(),
32 | weight_norm(
33 | nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
34 | ),
35 | nn.ELU(),
36 | )
37 | self.classifier = nn.Linear(in_features=cond_channels, out_features=self.num_class)
38 |
39 | def forward(self, x: torch.Tensor) -> torch.Tensor:
40 | x = self.condnet(x)
41 | x = x.transpose(1, 2)
42 | return torch.abs(self.classifier(x).squeeze(-1))
43 |
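A minimal usage sketch: `ConvRNNF0Predictor` maps an 80-bin mel spectrogram of shape (B, 80, T) to a non-negative per-frame F0 track of shape (B, T).

```python
import torch
from viettts.hifigan.f0_predictor import ConvRNNF0Predictor

predictor = ConvRNNF0Predictor(in_channels=80, cond_channels=512)
mel = torch.randn(2, 80, 120)   # batch of 2 mel spectrograms, 120 frames each
f0 = predictor(mel)
print(f0.shape)                 # torch.Size([2, 120])
```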
--------------------------------------------------------------------------------
/viettts/llm/llm.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Optional, Callable, List, Generator
2 | import torch
3 | from torch import nn
4 | import torch.nn.functional as F
5 | from torch.nn.utils.rnn import pad_sequence, unpad_sequence
6 | from viettts.utils.common import IGNORE_ID
7 | from viettts.transformer.label_smoothing_loss import LabelSmoothingLoss
8 | from viettts.utils.common import th_accuracy
9 |
10 |
11 | class TransformerLM(torch.nn.Module):
12 | def __init__(
13 | self,
14 | text_encoder_input_size: int,
15 | llm_input_size: int,
16 | llm_output_size: int,
17 | text_token_size: int,
18 | speech_token_size: int,
19 | text_encoder: torch.nn.Module,
20 | llm: torch.nn.Module,
21 | sampling: Callable,
22 | length_normalized_loss: bool = True,
23 | lsm_weight: float = 0.0,
24 | spk_embed_dim: int = 192,
25 | ):
26 | super().__init__()
27 | self.llm_input_size = llm_input_size
28 | self.speech_token_size = speech_token_size
29 | # 1. build text token inputs related modules
30 | self.text_embedding = torch.nn.Embedding(text_token_size, text_encoder_input_size)
31 | self.text_encoder = text_encoder
32 | self.text_encoder_affine_layer = nn.Linear(
33 | self.text_encoder.output_size(),
34 | llm_input_size
35 | )
36 |
37 | # 2. build speech token language model related modules
38 | self.sos_eos = 0
39 | self.task_id = 1
40 | self.llm_embedding = torch.nn.Embedding(2, llm_input_size)
41 | self.llm = llm
42 | self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 1)
43 | self.criterion_ce = LabelSmoothingLoss(
44 | size=speech_token_size + 1,
45 | padding_idx=IGNORE_ID,
46 | smoothing=lsm_weight,
47 | normalize_length=length_normalized_loss,
48 | )
49 |
50 | # 3. [Optional] build speech token related modules
51 | self.speech_embedding = torch.nn.Embedding(speech_token_size, llm_input_size)
52 | self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, llm_input_size)
53 |
54 | # 4. sampling method
55 | self.sampling = sampling
56 |
57 | def encode(
58 | self,
59 | text: torch.Tensor,
60 | text_lengths: torch.Tensor,
61 | ):
62 | encoder_out, encoder_mask = self.text_encoder(text, text_lengths, decoding_chunk_size=1, num_decoding_left_chunks=-1)
63 | encoder_out_lens = encoder_mask.squeeze(1).sum(1)
64 | encoder_out = self.text_encoder_affine_layer(encoder_out)
65 | return encoder_out, encoder_out_lens
66 |
67 | def pad_unpad_sequence(self, sos_eos_emb, embedding, text_token, text_token_len, task_id_emb, speech_token, speech_token_len):
68 | text_token = unpad_sequence(text_token, text_token_len.cpu(), batch_first=True)
69 | speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
70 | lm_input = [torch.concat([sos_eos_emb.squeeze(dim=0), embedding[i], text_token[i], task_id_emb.squeeze(dim=0), speech_token[i]], dim=0)
71 | for i in range(len(text_token))]
72 | lm_input_len = torch.tensor([i.size(0) for i in lm_input], dtype=torch.int32)
73 | lm_input = pad_sequence(lm_input, batch_first=True, padding_value=IGNORE_ID)
74 | return lm_input, lm_input_len
75 |
76 | def forward(
77 | self,
78 | batch: dict,
79 | device: torch.device,
80 | ) -> Dict[str, Optional[torch.Tensor]]:
81 | """
82 | Args:
83 | text: (B, L, D)
84 | text_lengths: (B,)
85 | audio: (B, T, N) or (B, T)
86 | audio_lengths: (B,)
87 | """
88 | text_token = batch['text_token'].to(device)
89 | text_token_len = batch['text_token_len'].to(device)
90 | speech_token = batch['speech_token'].to(device)
91 | speech_token_len = batch['speech_token_len'].to(device)
92 | embedding = batch['embedding'].to(device)
93 |
94 | # 1. prepare llm_target
95 | lm_target = [torch.tensor([IGNORE_ID] * (2 + text_token_len[i]) + speech_token[i, :speech_token_len[i]].tolist() +
96 | [self.speech_token_size]) for i in range(text_token.size(0))]
97 | lm_target = pad_sequence(lm_target, batch_first=True, padding_value=IGNORE_ID).to(device)
98 |
99 | # 1. encode text_token
100 | text_token = self.text_embedding(text_token)
101 | text_token, text_token_len = self.encode(text_token, text_token_len)
102 |
103 | # 2. embedding projection
104 | embedding = F.normalize(embedding, dim=1)
105 | embedding = self.spk_embed_affine_layer(embedding)
106 | embedding = embedding.unsqueeze(1)
107 |
108 | # 3. eos and task_id
109 | sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
110 | task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
111 |
112 | # 4. encode speech_token
113 | speech_token = self.speech_embedding(speech_token)
114 |
115 | # 5. unpad and pad
116 | lm_input, lm_input_len = self.pad_unpad_sequence(sos_eos_emb, embedding, text_token, text_token_len,
117 | task_id_emb, speech_token, speech_token_len)
118 |
119 | # 6. run lm forward
120 | lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
121 | logits = self.llm_decoder(lm_output)
122 | loss = self.criterion_ce(logits, lm_target)
123 | acc = th_accuracy(logits.view(-1, self.speech_token_size + 1), lm_target, ignore_label=IGNORE_ID)
124 | return {'loss': loss, 'acc': acc}
125 |
126 | def sampling_ids(
127 | self,
128 | weighted_scores: torch.Tensor,
129 | decoded_tokens: List,
130 | sampling: int,
131 | ignore_eos: bool = True,
132 | ):
133 | while True:
134 | top_ids = self.sampling(weighted_scores, decoded_tokens, sampling)
135 | if (not ignore_eos) or (self.speech_token_size not in top_ids):
136 | break
137 | return top_ids
138 |
139 | @torch.inference_mode()
140 | def inference(
141 | self,
142 | text: torch.Tensor,
143 | text_len: torch.Tensor,
144 | prompt_text: torch.Tensor,
145 | prompt_text_len: torch.Tensor,
146 | prompt_speech_token: torch.Tensor,
147 | prompt_speech_token_len: torch.Tensor,
148 | embedding: torch.Tensor,
149 | sampling: int = 25,
150 | max_token_text_ratio: float = 20,
151 | min_token_text_ratio: float = 2,
152 | ) -> Generator[torch.Tensor, None, None]:
153 | device = text.device
154 | text = torch.concat([prompt_text, text], dim=1)
155 | text_len += prompt_text_len
156 | text = self.text_embedding(text)
157 |
158 | # 1. encode text
159 | text, text_len = self.encode(text, text_len)
160 |
161 | # 2. encode embedding
162 | if embedding.shape[0] != 0:
163 | embedding = F.normalize(embedding, dim=1)
164 | embedding = self.spk_embed_affine_layer(embedding)
165 | embedding = embedding.unsqueeze(dim=1)
166 | else:
167 | embedding = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
168 |
169 | # 3. concat llm_input
170 | sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
171 | task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
172 | if prompt_speech_token_len != 0:
173 | prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
174 | else:
175 | prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
176 | lm_input = torch.concat([sos_eos_emb, embedding, text, task_id_emb, prompt_speech_token_emb], dim=1)
177 |
178 | # 4. cal min/max_length
179 | min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
180 | max_len = int((text_len - prompt_text_len) * max_token_text_ratio)
181 |
182 | # 5. step by step decode
183 | out_tokens = []
184 | offset = 0
185 | att_cache, cnn_cache = torch.zeros((0, 0, 0, 0), device=lm_input.device), torch.zeros((0, 0, 0, 0), device=lm_input.device)
186 | for i in range(max_len):
187 | y_pred, att_cache, cnn_cache = self.llm.forward_chunk(lm_input, offset=offset, required_cache_size=-1,
188 | att_cache=att_cache, cnn_cache=cnn_cache,
189 | att_mask=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]),
190 | device=lm_input.device)).to(torch.bool))
191 | logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
192 | top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item()
193 | if top_ids == self.speech_token_size:
194 | break
195 | # in stream mode, yield token one by one
196 | yield top_ids
197 | out_tokens.append(top_ids)
198 | offset += lm_input.size(1)
199 | lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
200 |
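`sampling_ids` delegates token selection to the `sampling` callable injected at construction time. The toy greedy function below only illustrates the calling convention that callable must follow (score vector, decoded history, top-k budget); the real project wires in its own sampling strategy, and the vocabulary size used here is an arbitrary stand-in.

```python
import torch

def greedy_sampling(weighted_scores: torch.Tensor,
                    decoded_tokens: list,
                    sampling: int) -> torch.Tensor:
    # Toy strategy: ignore the history and the top-k budget, take the arg-max token.
    return weighted_scores.argmax(dim=-1, keepdim=True)

# Stand-in vocabulary size (speech_token_size + 1 entries in the real model).
logp = torch.log_softmax(torch.randn(4097), dim=-1)
top_ids = greedy_sampling(logp, decoded_tokens=[], sampling=25)
print(top_ids.item())
```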
--------------------------------------------------------------------------------
/viettts/model.py:
--------------------------------------------------------------------------------
1 | from loguru import logger
2 | import torch
3 | import numpy as np
4 | import threading
5 | import time
6 | from torch.nn import functional as F
7 | from contextlib import nullcontext
8 | import uuid
9 | from viettts.utils.common import fade_in_out_audio
10 |
11 | class TTSModel:
12 | def __init__(
13 | self,
14 | llm: torch.nn.Module,
15 | flow: torch.nn.Module,
16 | hift: torch.nn.Module
17 | ):
18 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
19 | self.llm = llm
20 | self.flow = flow
21 | self.hift = hift
22 | self.token_min_hop_len = 2 * self.flow.input_frame_rate
23 | self.token_max_hop_len = 4 * self.flow.input_frame_rate
24 | self.token_overlap_len = 20
25 | # mel fade in out
26 | self.mel_overlap_len = int(self.token_overlap_len / self.flow.input_frame_rate * 22050 / 256)
27 | self.mel_window = np.hamming(2 * self.mel_overlap_len)
28 | # hift cache
29 | self.mel_cache_len = 20
30 | self.source_cache_len = int(self.mel_cache_len * 256)
31 | # speech fade in out
32 | self.speech_window = np.hamming(2 * self.source_cache_len)
33 | # rtf and decoding related
34 | self.stream_scale_factor = 1
35 |         assert self.stream_scale_factor >= 1, 'stream_scale_factor should be at least 1, change it according to your actual rtf'
36 | self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
37 | self.lock = threading.Lock()
38 | # dict used to store session related variable
39 | self.tts_speech_token_dict = {}
40 | self.llm_end_dict = {}
41 | self.mel_overlap_dict = {}
42 | self.hift_cache_dict = {}
43 |
44 | def load(self, llm_model, flow_model, hift_model):
45 | self.llm.load_state_dict(torch.load(llm_model, map_location=self.device))
46 | self.llm.to(self.device).eval()
47 | self.llm.half()
48 | self.flow.load_state_dict(torch.load(flow_model, map_location=self.device))
49 | self.flow.to(self.device).eval()
50 | self.hift.load_state_dict(torch.load(hift_model, map_location=self.device))
51 | self.hift.to(self.device).eval()
52 |
53 | def load_jit(self, llm_text_encoder_model, llm_llm_model, flow_encoder_model):
54 | llm_text_encoder = torch.jit.load(llm_text_encoder_model, map_location=self.device)
55 | self.llm.text_encoder = llm_text_encoder
56 | llm_llm = torch.jit.load(llm_llm_model, map_location=self.device)
57 | self.llm.llm = llm_llm
58 | flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
59 | self.flow.encoder = flow_encoder
60 |
61 | def load_onnx(self, flow_decoder_estimator_model):
62 | import onnxruntime
63 | option = onnxruntime.SessionOptions()
64 | option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
65 | option.intra_op_num_threads = 1
66 | providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
67 | del self.flow.decoder.estimator
68 | self.flow.decoder.estimator = onnxruntime.InferenceSession(flow_decoder_estimator_model, sess_options=option, providers=providers)
69 |
70 | def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
71 | with self.llm_context:
72 | for i in self.llm.inference(
73 | text=text.to(self.device),
74 | text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
75 | prompt_text=prompt_text.to(self.device),
76 | prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
77 | prompt_speech_token=llm_prompt_speech_token.to(self.device),
78 | prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
79 | embedding=llm_embedding.to(self.device).half()
80 | ):
81 | self.tts_speech_token_dict[uuid].append(i)
82 | self.llm_end_dict[uuid] = True
83 |
84 | def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False, speed=1.0):
85 | tts_mel = self.flow.inference(
86 | token=token.to(self.device),
87 | token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
88 | prompt_token=prompt_token.to(self.device),
89 | prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
90 | prompt_feat=prompt_feat.to(self.device),
91 | prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
92 | embedding=embedding.to(self.device)
93 | )
94 |
95 | if self.hift_cache_dict[uuid] is not None:
96 | hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
97 | tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
98 | else:
99 | hift_cache_source = torch.zeros(1, 1, 0)
100 |
101 | if finalize is False:
102 | self.mel_overlap_dict[uuid] = tts_mel[:, :, -self.mel_overlap_len:]
103 | tts_mel = tts_mel[:, :, :-self.mel_overlap_len]
104 | tts_speech, tts_source = self.hift.inference(mel=tts_mel, cache_source=hift_cache_source)
105 | self.hift_cache_dict[uuid] = {
106 | 'mel': tts_mel[:, :, -self.mel_cache_len:],
107 | 'source': tts_source[:, :, -self.source_cache_len:],
108 | 'speech': tts_speech[:, -self.source_cache_len:]
109 | }
110 | tts_speech = tts_speech[:, :-self.source_cache_len]
111 | else:
112 | if speed != 1.0:
113 |                 assert self.hift_cache_dict[uuid] is None, 'speed change only supports non-stream inference mode'
114 | tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
115 | tts_speech, tts_source = self.hift.inference(mel=tts_mel, cache_source=hift_cache_source)
116 |
117 | tts_speech = fade_in_out_audio(tts_speech)
118 | return tts_speech
119 |
120 | def tts(
121 | self,
122 | text: str,
123 | flow_embedding: torch.Tensor,
124 | llm_embedding: torch.Tensor=torch.zeros(0, 192),
125 | prompt_text: torch.Tensor=torch.zeros(1, 0, dtype=torch.int32),
126 | llm_prompt_speech_token: torch.Tensor=torch.zeros(1, 0, dtype=torch.int32),
127 | flow_prompt_speech_token: torch.Tensor=torch.zeros(1, 0, dtype=torch.int32),
128 | prompt_speech_feat: torch.Tensor=torch.zeros(1, 0, 80),
129 | stream: bool=False,
130 | speed: float=1.0,
131 | **kwargs
132 | ):
133 | # this_uuid is used to track variables related to this inference thread
134 | this_uuid = str(uuid.uuid1())
135 | with self.lock:
136 | self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
137 | self.mel_overlap_dict[this_uuid], self.hift_cache_dict[this_uuid] = None, None
138 |
139 | p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
140 | p.start()
141 |
142 | if stream:
143 | token_hop_len = self.token_min_hop_len
144 | while True:
145 | time.sleep(0.01)
146 | if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
147 | this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len]).unsqueeze(dim=0)
148 | this_tts_speech = self.token2wav(
149 | token=this_tts_speech_token,
150 | prompt_token=flow_prompt_speech_token,
151 | prompt_feat=prompt_speech_feat,
152 | embedding=flow_embedding,
153 | uuid=this_uuid,
154 | finalize=False
155 | )
156 | yield {'tts_speech': this_tts_speech.cpu()}
157 | with self.lock:
158 | self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:]
159 | # increase token_hop_len for better speech quality
160 | token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
161 | if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len:
162 | break
163 | p.join()
164 | this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
165 | this_tts_speech = self.token2wav(
166 | token=this_tts_speech_token,
167 | prompt_token=flow_prompt_speech_token,
168 | prompt_feat=prompt_speech_feat,
169 | embedding=flow_embedding,
170 | uuid=this_uuid,
171 | finalize=True
172 | )
173 | yield {'tts_speech': this_tts_speech.cpu()}
174 | else:
175 | p.join()
176 | this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
177 | this_tts_speech = self.token2wav(
178 | token=this_tts_speech_token,
179 | prompt_token=flow_prompt_speech_token,
180 | prompt_feat=prompt_speech_feat,
181 | embedding=flow_embedding,
182 | uuid=this_uuid,
183 | finalize=True,
184 | speed=speed
185 | )
186 | yield {'tts_speech': this_tts_speech.cpu()}
187 |
188 | with self.lock:
189 | self.tts_speech_token_dict.pop(this_uuid)
190 | self.llm_end_dict.pop(this_uuid)
191 | self.mel_overlap_dict.pop(this_uuid)
192 | self.hift_cache_dict.pop(this_uuid)
193 |
194 | def vc(
195 | self,
196 | source_speech_token: torch.Tensor,
197 | flow_prompt_speech_token: torch.Tensor,
198 | prompt_speech_feat: torch.Tensor,
199 | flow_embedding: torch.Tensor,
200 | stream: bool=False,
201 | speed: float=1.0,
202 | **kwargs
203 | ):
204 | this_uuid = str(uuid.uuid1())
205 | with self.lock:
206 | self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = source_speech_token.flatten().tolist(), True
207 | self.mel_overlap_dict[this_uuid], self.hift_cache_dict[this_uuid] = None, None
208 |
209 | if stream:
210 | token_hop_len = self.token_min_hop_len
211 | while True:
212 | if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
213 | this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len]) \
214 | .unsqueeze(dim=0)
215 | this_tts_speech = self.token2wav(
216 | token=this_tts_speech_token,
217 | prompt_token=flow_prompt_speech_token,
218 | prompt_feat=prompt_speech_feat,
219 | embedding=flow_embedding,
220 | uuid=this_uuid,
221 | finalize=False
222 | )
223 | yield {'tts_speech': this_tts_speech.cpu()}
224 | with self.lock:
225 | self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:]
226 | # increase token_hop_len for better speech quality
227 | token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
228 | if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len:
229 | break
230 |
231 |             # deal with the remaining tokens; make sure the remaining token length equals token_hop_len when cache_speech is not None
232 |             this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
233 | this_tts_speech = self.token2wav(
234 | token=this_tts_speech_token,
235 | prompt_token=flow_prompt_speech_token,
236 | prompt_feat=prompt_speech_feat,
237 | embedding=flow_embedding,
238 | uuid=this_uuid,
239 | finalize=True
240 | )
241 | yield {'tts_speech': this_tts_speech.cpu()}
242 | else:
243 | # deal with all tokens
244 | this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
245 | this_tts_speech = self.token2wav(
246 | token=this_tts_speech_token,
247 | prompt_token=flow_prompt_speech_token,
248 | prompt_feat=prompt_speech_feat,
249 | embedding=flow_embedding,
250 | uuid=this_uuid,
251 | finalize=True,
252 | speed=speed
253 | )
254 | yield {'tts_speech': this_tts_speech.cpu()}
255 |
256 | with self.lock:
257 | self.tts_speech_token_dict.pop(this_uuid)
258 | self.llm_end_dict.pop(this_uuid)
259 | self.mel_overlap_dict.pop(this_uuid)
260 | self.hift_cache_dict.pop(this_uuid)
261 |
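The streaming cache sizes set up in `TTSModel.__init__` follow the same tokens-to-samples arithmetic as the flow module. A small sketch, assuming `input_frame_rate = 50` (the real value is read from the flow model):

```python
input_frame_rate = 50      # assumed; taken from the flow model at runtime
token_overlap_len = 20
mel_cache_len = 20

mel_overlap_len = int(token_overlap_len / input_frame_rate * 22050 / 256)  # 34 mel frames
source_cache_len = int(mel_cache_len * 256)                                # 5120 samples
print(mel_overlap_len, source_cache_len)
```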
--------------------------------------------------------------------------------
/viettts/server.py:
--------------------------------------------------------------------------------
1 | import io
2 | import os
3 | import queue
4 | import random
5 | import subprocess
6 | import threading
7 | import wave
8 |
9 | import tempfile
10 | import shutil
11 | import requests
12 | import numpy as np
13 | from loguru import logger
14 | from datetime import datetime
15 | from typing import Any, List, Optional
16 | from pydantic import BaseModel
17 | from anyio import CapacityLimiter
18 | from anyio.lowlevel import RunVar
19 | from fastapi import FastAPI, UploadFile, Form, File, HTTPException
20 | from fastapi.responses import StreamingResponse, JSONResponse, PlainTextResponse, FileResponse
21 | from fastapi.middleware.cors import CORSMiddleware
22 |
23 | from viettts.tts import TTS
24 | from viettts.utils.file_utils import load_prompt_speech_from_file, load_voices
25 |
26 |
27 | VOICE_DIR = 'samples'
28 | VOICE_MAP = load_voices(VOICE_DIR)
29 |
30 | global tts_obj
31 | tts_obj = None
32 |
33 |
34 | app = FastAPI(
35 | title="VietTTS API",
36 | description="""
37 | VietTTS API (https://github.com/dangvansam/viet-tts)
38 | Vietnamese Text To Speech and Voice Clone
39 | License: Apache 2.0 - Author:
40 | """
41 | )
42 | app.add_middleware(
43 | CORSMiddleware,
44 | allow_origins=["*"],
45 | allow_credentials=True,
46 | allow_methods=["*"],
47 | allow_headers=["*"])
48 |
49 |
50 | def generate_data(model_output):
51 |     # yield the WAV header once, then stream raw PCM chunks as they are produced
52 |     yield wav_chunk_header()
53 |     for i in model_output:
54 |         tts_audio = (i['tts_speech'].numpy() * (2 ** 15)).astype(np.int16)
55 |         tts_audio = tts_audio.tobytes()
56 |         yield tts_audio
57 |
58 |
59 | class OpenAITTSRequest(BaseModel):
60 | input: str
61 | model: str = "tts-1"
62 | voice: str = random.choice(list(VOICE_MAP))
63 | response_format: str = "wav"
64 | speed: float = 1.0
65 |
66 | class TTSRequest(BaseModel):
67 | text: str
68 | voice: str = random.choice(list(VOICE_MAP))
69 | speed: float = 1.0
70 |
71 | def wav_chunk_header(sample_rate=22050, bit_depth=16, channels=1):
72 | buffer = io.BytesIO()
73 | with wave.open(buffer, "wb") as wav_file:
74 | wav_file.setnchannels(channels)
75 | wav_file.setsampwidth(bit_depth // 8)
76 | wav_file.setframerate(sample_rate)
77 |
78 | wav_header_bytes = buffer.getvalue()
79 | buffer.close()
80 | return wav_header_bytes
81 |
82 |
83 | @app.get("/", response_class=PlainTextResponse)
84 | async def root():
85 | return 'VietTTS API'
86 |
87 | @app.get("/health", response_class=PlainTextResponse)
88 | async def health():
89 | return 'VietTTS API is running...'
90 |
91 | @app.get("/voices")
92 | @app.get("/v1/voices")
93 | async def show_voices():
94 | return list(VOICE_MAP.keys())
95 |
96 | @app.post("/audio/speech")
97 | @app.post("/v1/audio/speech")
98 | async def openai_api_tts(tts_request: OpenAITTSRequest):
99 | logger.info(f"Received TTS request: {tts_request.dict()}")
100 |
101 | if tts_request.voice.isdigit():
102 | voice_file = list(VOICE_MAP.values())[int(tts_request.voice)]
103 | else:
104 | voice_file = VOICE_MAP.get(tts_request.voice)
105 |
106 | if not voice_file:
107 | logger.error(f"Voice {tts_request.voice} not found")
108 | return PlainTextResponse(content="Voice not found", status_code=404)
109 |
110 | prompt_speech_16k = load_prompt_speech_from_file(
111 | filepath=voice_file,
112 | min_duration=3,
113 | max_duration=5
114 | )
115 | # prompt_speech_16k = fade_in_out_audio(prompt_speech_16k)
116 |
117 | def build_ffmpeg_args(response_format, input_format, sample_rate=24000):
118 | if input_format == 'WAV':
119 | ffmpeg_args = ["ffmpeg", "-loglevel", "error", "-f", "WAV", "-i", "-"]
120 | else:
121 | ffmpeg_args = ["ffmpeg", "-loglevel", "error", "-f", input_format, "-ar", sample_rate, "-ac", "1", "-i", "-"]
122 | if response_format == "mp3":
123 | ffmpeg_args.extend(["-f", "mp3", "-c:a", "libmp3lame", "-ab", "64k"])
124 | elif response_format == "opus":
125 | ffmpeg_args.extend(["-f", "ogg", "-c:a", "libopus"])
126 | elif response_format == "aac":
127 | ffmpeg_args.extend(["-f", "adts", "-c:a", "aac", "-ab", "64k"])
128 | elif response_format == "flac":
129 | ffmpeg_args.extend(["-f", "flac", "-c:a", "flac"])
130 | elif response_format == "wav":
131 | ffmpeg_args.extend(["-f", "wav", "-c:a", "pcm_s16le"])
132 | elif response_format == "pcm": # even though pcm is technically 'raw', we still use ffmpeg to adjust the speed
133 | ffmpeg_args.extend(["-f", "s16le", "-c:a", "pcm_s16le"])
134 | return ffmpeg_args
135 |
136 | def exception_check(exq: queue.Queue):
137 | try:
138 | e = exq.get_nowait()
139 | except queue.Empty:
140 | return
141 | raise e
142 |
143 | if tts_request.response_format == "mp3":
144 | media_type = "audio/mpeg"
145 | elif tts_request.response_format == "opus":
146 | media_type = "audio/ogg;codec=opus" # codecs?
147 | elif tts_request.response_format == "aac":
148 | media_type = "audio/aac"
149 | elif tts_request.response_format == "flac":
150 | media_type = "audio/x-flac"
151 | elif tts_request.response_format == "wav":
152 | media_type = "audio/wav"
153 | elif tts_request.response_format == "pcm":
154 | media_type = "audio/pcm;rate=24000"
155 | else:
156 |         raise HTTPException(status_code=400, detail=f"Invalid response_format: '{tts_request.response_format}'")
157 |
158 | ffmpeg_args = None
159 | ffmpeg_args = build_ffmpeg_args(tts_request.response_format, input_format="f32le", sample_rate="24000")
160 | ffmpeg_args.extend(["-"])
161 | ffmpeg_proc = subprocess.Popen(ffmpeg_args, stdin=subprocess.PIPE, stdout=subprocess.PIPE)
162 |
163 | in_q = queue.Queue()
164 | ex_q = queue.Queue()
165 |
166 | def generator():
167 | # text -> in_q
168 | try:
169 | model_output = tts_obj.inference_tts(
170 | tts_text=tts_request.input,
171 | prompt_speech_16k=prompt_speech_16k,
172 | speed=tts_request.speed,
173 | stream=False
174 | )
175 | for chunk in model_output:
176 | exception_check(ex_q)
177 | chunk = chunk['tts_speech'].numpy().tobytes()
178 | in_q.put(chunk)
179 | except BrokenPipeError as e:
180 | logger.info("Client disconnected - 'Broken pipe'")
181 | except Exception as e:
182 | logger.error(f"Exception: {repr(e)}")
183 | raise e
184 | finally:
185 | in_q.put(None) # sentinel
186 |
187 | def out_writer():
188 | try:
189 | while True:
190 | chunk = in_q.get()
191 | if chunk is None: # sentinel
192 | break
193 | ffmpeg_proc.stdin.write(chunk) # BrokenPipeError from here on client disconnect
194 | except Exception as e: # BrokenPipeError
195 | ex_q.put(e) # we need to get this exception into the generation loop
196 | ffmpeg_proc.kill()
197 | return
198 | finally:
199 | ffmpeg_proc.stdin.close()
200 |
201 | generator_worker = threading.Thread(target=generator, daemon=True)
202 | generator_worker.start()
203 | out_writer_worker = threading.Thread(target=out_writer, daemon=True)
204 | out_writer_worker.start()
205 |
206 | async def cleanup():
207 | try:
208 | ffmpeg_proc.kill()
209 | # del generator_worker
210 | # del out_writer_worker
211 | except Exception as e:
212 | logger.error(f"Exception: {repr(e)}")
213 |
214 | return StreamingResponse(
215 | content=ffmpeg_proc.stdout,
216 | media_type=media_type,
217 | background=cleanup
218 | )
219 |
220 | @app.post("/tts")
221 | @app.post("/v1/tts")
222 | async def tts(
223 | text: str = Form(...),
224 | voice: str = Form("0"),
225 | speed: float = Form(1.0),
226 | audio_url: str = Form(None),
227 | audio_file: UploadFile = File(None)
228 | ):
229 | logger.info(f"Received TTS request: text={text}, voice={voice}, speed={speed}, audio_url={audio_url}")
230 | voice_file = None
231 |
232 | # Case 1: Uploaded audio file
233 | if audio_file:
234 | temp_audio_file = tempfile.NamedTemporaryFile(
235 | delete=False,
236 | suffix=f'.{audio_file.filename.split(".")[-1]}'
237 | )
238 | try:
239 | with open(temp_audio_file.name, "wb") as temp_file:
240 | shutil.copyfileobj(audio_file.file, temp_file)
241 | voice_file = temp_audio_file.name
242 | logger.info(f"Using uploaded audio file as voice: {voice_file}")
243 | finally:
244 | audio_file.file.close()
245 |
246 | # Case 2: Audio URL
247 | elif audio_url:
248 | temp_audio_file = tempfile.NamedTemporaryFile(
249 | delete=False,
250 | suffix=f'.{audio_url.lower().split(".")[-1]}'
251 | )
252 | try:
253 | response = requests.get(audio_url, stream=True)
254 | if response.status_code != 200:
255 | raise HTTPException(status_code=400, detail="Failed to fetch audio from URL")
256 | with open(temp_audio_file.name, "wb") as temp_file:
257 | shutil.copyfileobj(response.raw, temp_file)
258 | voice_file = temp_audio_file.name
259 | logger.info(f"Using audio URL as voice: {voice_file}")
260 | finally:
261 | response.close()
262 |
263 | # Case 3: Predefined voice
264 | elif voice:
265 | if voice.isdigit():
266 | voice_file = list(VOICE_MAP.values())[int(voice)]
267 | else:
268 | voice_file = VOICE_MAP.get(voice)
269 |
270 | if not voice_file:
271 | logger.error(f"Voice {voice} not found")
272 | raise HTTPException(status_code=404, detail="Voice not found")
273 |
274 | else:
275 | voice_file = random.choice(list(VOICE_MAP.values()))
276 |
277 | # Error if no voice file is available
278 | if not voice_file or not os.path.exists(voice_file):
279 | raise HTTPException(status_code=400, detail="No valid voice file provided")
280 |
281 | prompt_speech_16k = load_prompt_speech_from_file(
282 | filepath=voice_file,
283 | min_duration=3,
284 | max_duration=5
285 | )
286 |
287 | temp_output_file = tempfile.NamedTemporaryFile(
288 | delete=False,
289 | suffix=f"_{datetime.now().strftime('%Y%m%d_%H%M%S')}.mp3"
290 | )
291 |
292 | try:
293 | model_output = tts_obj.inference_tts(
294 | tts_text=text,
295 | prompt_speech_16k=prompt_speech_16k,
296 | speed=speed,
297 | stream=False
298 | )
299 |
300 | raw_audio = b''.join(chunk['tts_speech'].numpy().tobytes() for chunk in model_output)
301 | ffmpeg_args = [
302 | "ffmpeg", "-loglevel", "error", "-y", "-f", "f32le", "-ar", "24000", "-ac", "1",
303 | "-i", "-", "-f", "mp3", "-c:a", "libmp3lame", "-ab", "64k", temp_output_file.name
304 | ]
305 | ffmpeg_proc = subprocess.run(
306 | ffmpeg_args,
307 | input=raw_audio,
308 | stdout=subprocess.PIPE,
309 | stderr=subprocess.PIPE
310 | )
311 |
312 | if ffmpeg_proc.returncode != 0:
313 | logger.error(f"FFmpeg error: {ffmpeg_proc.stderr.decode()}")
314 | raise HTTPException(status_code=500, detail="Error during audio processing")
315 |
316 | if not os.path.exists(temp_output_file.name):
317 | logger.error(f"FFmpeg did not create the output file: {temp_output_file.name}")
318 | raise HTTPException(status_code=500, detail="FFmpeg failed to produce the output file")
319 |
320 | return FileResponse(
321 | path=temp_output_file.name,
322 | media_type="audio/mpeg",
323 | filename=temp_output_file.name.split("/")[-1]
324 | )
325 |
326 | finally:
327 | if audio_file or audio_url:
328 | if os.path.exists(temp_audio_file.name):
329 | os.unlink(temp_audio_file.name)
330 |
331 |
332 | @app.on_event("startup")
333 | async def startup():
334 | global tts_obj
335 | RunVar("_default_thread_limiter").set(CapacityLimiter(os.cpu_count()))
336 | tts_obj = TTS('./pretrained-models')
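A hypothetical client-side call against the OpenAI-compatible `/v1/audio/speech` route. The host and port are placeholders (the listen address is configured outside this file); voices are resolved from `VOICE_MAP` either by name (listable via `GET /v1/voices`) or by numeric index.

```python
import requests

# Placeholder host/port; adjust to wherever the server is running.
resp = requests.post(
    'http://localhost:8298/v1/audio/speech',
    json={
        'input': 'xin chào Việt Nam',
        'voice': '0',              # index into VOICE_MAP, or a name from GET /v1/voices
        'speed': 1.0,
        'response_format': 'wav',
    },
)
with open('output.wav', 'wb') as f:
    f.write(resp.content)
```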
--------------------------------------------------------------------------------
/viettts/tokenizer/tokenizer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import base64
3 | import tiktoken
4 | from functools import lru_cache
5 | from whisper.tokenizer import Tokenizer
6 |
7 | LANGUAGES = {
8 | "en": "english",
9 | "zh": "chinese",
10 | "de": "german",
11 | "es": "spanish",
12 | "ru": "russian",
13 | "ko": "korean",
14 | "fr": "french",
15 | "ja": "japanese",
16 | "pt": "portuguese",
17 | "tr": "turkish",
18 | "pl": "polish",
19 | "ca": "catalan",
20 | "nl": "dutch",
21 | "ar": "arabic",
22 | "sv": "swedish",
23 | "it": "italian",
24 | "id": "indonesian",
25 | "hi": "hindi",
26 | "fi": "finnish",
27 | "vi": "vietnamese",
28 | "he": "hebrew",
29 | "uk": "ukrainian",
30 | "el": "greek",
31 | "ms": "malay",
32 | "cs": "czech",
33 | "ro": "romanian",
34 | "da": "danish",
35 | "hu": "hungarian",
36 | "ta": "tamil",
37 | "no": "norwegian",
38 | "th": "thai",
39 | "ur": "urdu",
40 | "hr": "croatian",
41 | "bg": "bulgarian",
42 | "lt": "lithuanian",
43 | "la": "latin",
44 | "mi": "maori",
45 | "ml": "malayalam",
46 | "cy": "welsh",
47 | "sk": "slovak",
48 | "te": "telugu",
49 | "fa": "persian",
50 | "lv": "latvian",
51 | "bn": "bengali",
52 | "sr": "serbian",
53 | "az": "azerbaijani",
54 | "sl": "slovenian",
55 | "kn": "kannada",
56 | "et": "estonian",
57 | "mk": "macedonian",
58 | "br": "breton",
59 | "eu": "basque",
60 | "is": "icelandic",
61 | "hy": "armenian",
62 | "ne": "nepali",
63 | "mn": "mongolian",
64 | "bs": "bosnian",
65 | "kk": "kazakh",
66 | "sq": "albanian",
67 | "sw": "swahili",
68 | "gl": "galician",
69 | "mr": "marathi",
70 | "pa": "punjabi",
71 | "si": "sinhala",
72 | "km": "khmer",
73 | "sn": "shona",
74 | "yo": "yoruba",
75 | "so": "somali",
76 | "af": "afrikaans",
77 | "oc": "occitan",
78 | "ka": "georgian",
79 | "be": "belarusian",
80 | "tg": "tajik",
81 | "sd": "sindhi",
82 | "gu": "gujarati",
83 | "am": "amharic",
84 | "yi": "yiddish",
85 | "lo": "lao",
86 | "uz": "uzbek",
87 | "fo": "faroese",
88 | "ht": "haitian creole",
89 | "ps": "pashto",
90 | "tk": "turkmen",
91 | "nn": "nynorsk",
92 | "mt": "maltese",
93 | "sa": "sanskrit",
94 | "lb": "luxembourgish",
95 | "my": "myanmar",
96 | "bo": "tibetan",
97 | "tl": "tagalog",
98 | "mg": "malagasy",
99 | "as": "assamese",
100 | "tt": "tatar",
101 | "haw": "hawaiian",
102 | "ln": "lingala",
103 | "ha": "hausa",
104 | "ba": "bashkir",
105 | "jw": "javanese",
106 | "su": "sundanese",
107 | "yue": "cantonese",
108 | "minnan": "minnan",
109 | "wuyu": "wuyu",
110 | "dialect": "dialect",
111 | "zh/en": "zh/en",
112 | "en/zh": "en/zh",
113 | }
114 |
115 | TO_LANGUAGE_CODE = {
116 | **{language: code for code, language in LANGUAGES.items()},
117 | "burmese": "my",
118 | "valencian": "ca",
119 | "flemish": "nl",
120 | "haitian": "ht",
121 | "letzeburgesch": "lb",
122 | "pushto": "ps",
123 | "panjabi": "pa",
124 | "moldavian": "ro",
125 | "moldovan": "ro",
126 | "sinhalese": "si",
127 | "castilian": "es",
128 | "mandarin": "zh",
129 | }
130 |
131 | AUDIO_EVENT = {
132 | "ASR": "ASR",
133 | "AED": "AED",
134 | "SER": "SER",
135 | "Speech": "Speech",
136 | "/Speech": "/Speech",
137 | "BGM": "BGM",
138 | "/BGM": "/BGM",
139 | "Laughter": "Laughter",
140 | "/Laughter": "/Laughter",
141 | "Applause": "Applause",
142 | "/Applause": "/Applause",
143 | }
144 |
145 | EMOTION = {
146 | "HAPPY": "HAPPY",
147 | "SAD": "SAD",
148 | "ANGRY": "ANGRY",
149 | "NEUTRAL": "NEUTRAL",
150 | }
151 |
152 | TTS_Vocal_Token = {
153 | "TTS/B": "TTS/B",
154 | "TTS/O": "TTS/O",
155 | "TTS/Q": "TTS/Q",
156 | "TTS/A": "TTS/A",
157 | "TTS/CO": "TTS/CO",
158 | "TTS/CL": "TTS/CL",
159 | "TTS/H": "TTS/H",
160 | **{f"TTS/SP{i:02d}": f"TTS/SP{i:02d}" for i in range(1, 14)}
161 | }
162 |
163 |
164 | @lru_cache(maxsize=None)
165 | def get_encoding(name: str="gpt2", num_languages: int=99):
166 | vocab_path = os.path.join(os.path.dirname(__file__), f"{name}.tiktoken")
167 | lines = open(vocab_path).readlines()
168 | ranks = {
169 | base64.b64decode(token): int(rank)
170 | for token, rank in (line.split() for line in lines if line)
171 | }
172 | n_vocab = len(ranks)
173 | special_tokens = {}
174 |
175 | specials = [
176 | "<|endoftext|>",
177 | "<|startoftranscript|>",
178 | *[f"<|{lang}|>" for lang in list(LANGUAGES.keys())[:num_languages]],
179 | *[f"<|{audio_event}|>" for audio_event in list(AUDIO_EVENT.keys())],
180 | *[f"<|{emotion}|>" for emotion in list(EMOTION.keys())],
181 | "<|translate|>",
182 | "<|transcribe|>",
183 | "<|startoflm|>",
184 | "<|startofprev|>",
185 | "<|nospeech|>",
186 | "<|notimestamps|>",
187 | *[f"<|SPECIAL_TOKEN_{i}|>" for i in range(1, 31)], # register special tokens for ASR
188 | *[f"<|{tts}|>" for tts in list(TTS_Vocal_Token.keys())], # register special tokens for TTS
189 | *[f"<|{i * 0.02:.2f}|>" for i in range(1501)],
190 | ]
191 |
192 | for token in specials:
193 | special_tokens[token] = n_vocab
194 | n_vocab += 1
195 |
196 | return tiktoken.Encoding(
197 | name=os.path.basename(vocab_path),
198 | explicit_n_vocab=n_vocab,
199 | pat_str=r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
200 | mergeable_ranks=ranks,
201 | special_tokens=special_tokens,
202 | )
203 |
204 |
205 | @lru_cache(maxsize=None)
206 | def get_tokenizer() -> Tokenizer:
207 | encoding_name = "multilingual"
208 | num_languages = 100
209 | encoding = get_encoding(name=encoding_name, num_languages=num_languages)
210 | return Tokenizer(
211 | encoding=encoding,
212 | num_languages=num_languages,
213 | language='en',
214 | task='transcribe'
215 | )
216 |
217 | if __name__ == "__main__":
218 | tokenizer = get_tokenizer()
219 | print(tokenizer.decode(tokenizer.encode("xin chào Việt Nam, tôi là nam, 1234 1 2?")))
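A short sketch of how the frontend uses this tokenizer (see `TTSFrontEnd._extract_text_token`): special tokens are allowed during encoding, and decoding round-trips the text.

```python
from viettts.tokenizer.tokenizer import get_tokenizer

tokenizer = get_tokenizer()
ids = tokenizer.encode("xin chào Việt Nam", allowed_special='all')
print(len(ids), tokenizer.decode(ids))
```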
--------------------------------------------------------------------------------
/viettts/transformer/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dangvansam/viet-tts/c29b7535d4a81aefbbef175d0db26ac9df4bc028/viettts/transformer/__init__.py
--------------------------------------------------------------------------------
/viettts/transformer/activation.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn, sin, pow
3 | from torch.nn import Parameter
4 |
5 |
6 | class Swish(torch.nn.Module):
7 |     """Construct a Swish object."""
8 |
9 | def forward(self, x: torch.Tensor) -> torch.Tensor:
10 | """Return Swish activation function."""
11 | return x * torch.sigmoid(x)
12 |
13 |
14 | # Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
15 | # LICENSE is in incl_licenses directory.
16 | class Snake(nn.Module):
17 | '''
18 | Implementation of a sine-based periodic activation function
19 | Shape:
20 | - Input: (B, C, T)
21 | - Output: (B, C, T), same shape as the input
22 | Parameters:
23 | - alpha - trainable parameter
24 | References:
25 | - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
26 | https://arxiv.org/abs/2006.08195
27 | Examples:
28 |     >>> a1 = Snake(256)
29 | >>> x = torch.randn(256)
30 | >>> x = a1(x)
31 | '''
32 | def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
33 | '''
34 | Initialization.
35 | INPUT:
36 | - in_features: shape of the input
37 | - alpha: trainable parameter
38 | alpha is initialized to 1 by default, higher values = higher-frequency.
39 | alpha will be trained along with the rest of your model.
40 | '''
41 | super(Snake, self).__init__()
42 | self.in_features = in_features
43 |
44 | # initialize alpha
45 | self.alpha_logscale = alpha_logscale
46 | if self.alpha_logscale: # log scale alphas initialized to zeros
47 | self.alpha = Parameter(torch.zeros(in_features) * alpha)
48 | else: # linear scale alphas initialized to ones
49 | self.alpha = Parameter(torch.ones(in_features) * alpha)
50 |
51 | self.alpha.requires_grad = alpha_trainable
52 |
53 | self.no_div_by_zero = 0.000000001
54 |
55 | def forward(self, x):
56 | '''
57 | Forward pass of the function.
58 | Applies the function to the input elementwise.
59 | Snake ∶= x + 1/a * sin^2 (xa)
60 | '''
61 | alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
62 | if self.alpha_logscale:
63 | alpha = torch.exp(alpha)
64 | x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
65 |
66 | return x
67 |
--------------------------------------------------------------------------------
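The Snake activation above computes x + (1/alpha) * sin^2(alpha * x) with one trainable alpha per channel, so it preserves the (B, C, T) shape. A minimal usage sketch, assuming the viettts package is importable; the tensor sizes are arbitrary:

import torch
from viettts.transformer.activation import Snake, Swish

x = torch.randn(2, 80, 100)      # (batch, channels, time)
snake = Snake(in_features=80)    # one trainable alpha per channel
y = snake(x)                     # y = x + (1/alpha) * sin(alpha * x) ** 2
print(y.shape)                   # torch.Size([2, 80, 100])

swish = Swish()
print(swish(torch.zeros(3)))     # tensor([0., 0., 0.]), since 0 * sigmoid(0) = 0
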
/viettts/transformer/convolution.py:
--------------------------------------------------------------------------------
1 | # Modified from ESPnet(https://github.com/espnet/espnet)
2 | """ConvolutionModule definition."""
3 |
4 | from typing import Tuple
5 |
6 | import torch
7 | from torch import nn
8 |
9 |
10 | class ConvolutionModule(nn.Module):
11 | """ConvolutionModule in Conformer model."""
12 |
13 | def __init__(self,
14 | channels: int,
15 | kernel_size: int = 15,
16 | activation: nn.Module = nn.ReLU(),
17 | norm: str = "batch_norm",
18 | causal: bool = False,
19 | bias: bool = True):
20 | """Construct a ConvolutionModule object.
21 | Args:
22 | channels (int): The number of channels of conv layers.
23 | kernel_size (int): Kernel size of conv layers.
24 | causal (bool): Whether to use causal convolution or not.
25 | """
26 | super().__init__()
27 |
28 | self.pointwise_conv1 = nn.Conv1d(
29 | channels,
30 | 2 * channels,
31 | kernel_size=1,
32 | stride=1,
33 | padding=0,
34 | bias=bias,
35 | )
36 | # self.lorder is used to distinguish if it's a causal convolution,
37 | # if self.lorder > 0: it's a causal convolution, the input will be
38 | # padded with self.lorder frames on the left in forward.
39 | # else: it's a symmetrical convolution
40 | if causal:
41 | padding = 0
42 | self.lorder = kernel_size - 1
43 | else:
44 | # kernel_size should be an odd number for non-causal convolution
45 | assert (kernel_size - 1) % 2 == 0
46 | padding = (kernel_size - 1) // 2
47 | self.lorder = 0
48 | self.depthwise_conv = nn.Conv1d(
49 | channels,
50 | channels,
51 | kernel_size,
52 | stride=1,
53 | padding=padding,
54 | groups=channels,
55 | bias=bias,
56 | )
57 |
58 | assert norm in ['batch_norm', 'layer_norm']
59 | if norm == "batch_norm":
60 | self.use_layer_norm = False
61 | self.norm = nn.BatchNorm1d(channels)
62 | else:
63 | self.use_layer_norm = True
64 | self.norm = nn.LayerNorm(channels)
65 |
66 | self.pointwise_conv2 = nn.Conv1d(
67 | channels,
68 | channels,
69 | kernel_size=1,
70 | stride=1,
71 | padding=0,
72 | bias=bias,
73 | )
74 | self.activation = activation
75 |
76 | def forward(
77 | self,
78 | x: torch.Tensor,
79 | mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
80 | cache: torch.Tensor = torch.zeros((0, 0, 0)),
81 | ) -> Tuple[torch.Tensor, torch.Tensor]:
82 | """Compute convolution module.
83 | Args:
84 | x (torch.Tensor): Input tensor (#batch, time, channels).
85 | mask_pad (torch.Tensor): used for batch padding (#batch, 1, time),
86 | (0, 0, 0) means fake mask.
87 | cache (torch.Tensor): left context cache, it is only
88 | used in causal convolution (#batch, channels, cache_t),
89 | (0, 0, 0) means fake cache.
90 | Returns:
91 | torch.Tensor: Output tensor (#batch, time, channels).
92 | """
93 | # exchange the temporal dimension and the feature dimension
94 | x = x.transpose(1, 2) # (#batch, channels, time)
95 |
96 | # mask batch padding
97 | if mask_pad.size(2) > 0: # time > 0
98 | x.masked_fill_(~mask_pad, 0.0)
99 |
100 | if self.lorder > 0:
101 | if cache.size(2) == 0: # cache_t == 0
102 | x = nn.functional.pad(x, (self.lorder, 0), 'constant', 0.0)
103 | else:
104 | assert cache.size(0) == x.size(0) # equal batch
105 | assert cache.size(1) == x.size(1) # equal channel
106 | x = torch.cat((cache, x), dim=2)
107 | assert (x.size(2) > self.lorder)
108 | new_cache = x[:, :, -self.lorder:]
109 | else:
110 | # It's better we just return None if no cache is required,
111 | # However, for JIT export, here we just fake one tensor instead of
112 | # None.
113 | new_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
114 |
115 | # GLU mechanism
116 | x = self.pointwise_conv1(x) # (batch, 2*channel, dim)
117 | x = nn.functional.glu(x, dim=1) # (batch, channel, dim)
118 |
119 | # 1D Depthwise Conv
120 | x = self.depthwise_conv(x)
121 | if self.use_layer_norm:
122 | x = x.transpose(1, 2)
123 | x = self.activation(self.norm(x))
124 | if self.use_layer_norm:
125 | x = x.transpose(1, 2)
126 | x = self.pointwise_conv2(x)
127 | # mask batch padding
128 | if mask_pad.size(2) > 0: # time > 0
129 | x.masked_fill_(~mask_pad, 0.0)
130 |
131 | return x.transpose(1, 2), new_cache
132 |
--------------------------------------------------------------------------------
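The ConvolutionModule above expects (B, T, C) input plus a boolean (B, 1, T) padding mask and returns the output together with a left-context cache when causal=True. A minimal shape walk-through, assuming the viettts package is importable:

import torch
from viettts.transformer.convolution import ConvolutionModule

conv = ConvolutionModule(channels=64, kernel_size=15, causal=True)
x = torch.randn(2, 50, 64)                          # (batch, time, channels)
mask_pad = torch.ones(2, 1, 50, dtype=torch.bool)   # no padded frames here
y, new_cache = conv(x, mask_pad)
print(y.shape)          # torch.Size([2, 50, 64])
print(new_cache.shape)  # torch.Size([2, 64, 14]): lorder = kernel_size - 1 left frames
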
/viettts/transformer/decoder_layer.py:
--------------------------------------------------------------------------------
1 | """Decoder self-attention layer definition."""
2 | from typing import Optional, Tuple
3 |
4 | import torch
5 | from torch import nn
6 |
7 |
8 | class DecoderLayer(nn.Module):
9 | """Single decoder layer module.
10 |
11 | Args:
12 | size (int): Input dimension.
13 | self_attn (torch.nn.Module): Self-attention module instance.
14 | `MultiHeadedAttention` instance can be used as the argument.
15 | src_attn (torch.nn.Module): Inter-attention module instance.
16 | `MultiHeadedAttention` instance can be used as the argument.
17 | If `None` is passed, inter-attention is not used, as in
18 | CIF, GPT, and other decoder-only models.
19 | feed_forward (torch.nn.Module): Feed-forward module instance.
20 | `PositionwiseFeedForward` instance can be used as the argument.
21 | dropout_rate (float): Dropout rate.
22 | normalize_before (bool):
23 | True: use layer_norm before each sub-block.
24 | False: use layer_norm after each sub-block.
25 | """
26 |
27 | def __init__(
28 | self,
29 | size: int,
30 | self_attn: nn.Module,
31 | src_attn: Optional[nn.Module],
32 | feed_forward: nn.Module,
33 | dropout_rate: float,
34 | normalize_before: bool = True,
35 | ):
36 | """Construct a DecoderLayer object."""
37 | super().__init__()
38 | self.size = size
39 | self.self_attn = self_attn
40 | self.src_attn = src_attn
41 | self.feed_forward = feed_forward
42 | self.norm1 = nn.LayerNorm(size, eps=1e-5)
43 | self.norm2 = nn.LayerNorm(size, eps=1e-5)
44 | self.norm3 = nn.LayerNorm(size, eps=1e-5)
45 | self.dropout = nn.Dropout(dropout_rate)
46 | self.normalize_before = normalize_before
47 |
48 | def forward(
49 | self,
50 | tgt: torch.Tensor,
51 | tgt_mask: torch.Tensor,
52 | memory: torch.Tensor,
53 | memory_mask: torch.Tensor,
54 | cache: Optional[torch.Tensor] = None
55 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
56 | """Compute decoded features.
57 |
58 | Args:
59 | tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
60 | tgt_mask (torch.Tensor): Mask for input tensor
61 | (#batch, maxlen_out).
62 | memory (torch.Tensor): Encoded memory
63 | (#batch, maxlen_in, size).
64 | memory_mask (torch.Tensor): Encoded memory mask
65 | (#batch, maxlen_in).
66 | cache (torch.Tensor): cached tensors.
67 | (#batch, maxlen_out - 1, size).
68 |
69 | Returns:
70 | torch.Tensor: Output tensor (#batch, maxlen_out, size).
71 | torch.Tensor: Mask for output tensor (#batch, maxlen_out).
72 | torch.Tensor: Encoded memory (#batch, maxlen_in, size).
73 | torch.Tensor: Encoded memory mask (#batch, maxlen_in).
74 |
75 | """
76 | residual = tgt
77 | if self.normalize_before:
78 | tgt = self.norm1(tgt)
79 |
80 | if cache is None:
81 | tgt_q = tgt
82 | tgt_q_mask = tgt_mask
83 | else:
84 | # compute only the last frame query keeping dim: max_time_out -> 1
85 | assert cache.shape == (
86 | tgt.shape[0],
87 | tgt.shape[1] - 1,
88 | self.size,
89 | ), f"{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}"
90 | tgt_q = tgt[:, -1:, :]
91 | residual = residual[:, -1:, :]
92 | tgt_q_mask = tgt_mask[:, -1:, :]
93 |
94 | x = residual + self.dropout(
95 | self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)[0])
96 | if not self.normalize_before:
97 | x = self.norm1(x)
98 |
99 | if self.src_attn is not None:
100 | residual = x
101 | if self.normalize_before:
102 | x = self.norm2(x)
103 | x = residual + self.dropout(
104 | self.src_attn(x, memory, memory, memory_mask)[0])
105 | if not self.normalize_before:
106 | x = self.norm2(x)
107 |
108 | residual = x
109 | if self.normalize_before:
110 | x = self.norm3(x)
111 | x = residual + self.dropout(self.feed_forward(x))
112 | if not self.normalize_before:
113 | x = self.norm3(x)
114 |
115 | if cache is not None:
116 | x = torch.cat([cache, x], dim=1)
117 |
118 | return x, tgt_mask, memory, memory_mask
119 |
--------------------------------------------------------------------------------
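DecoderLayer only assumes its attention modules take (query, key, value, mask) and return a tuple whose first element is the attended output, which makes the pre-norm residual flow easy to trace with a toy stand-in. A minimal sketch; StubAttention below is hypothetical and used purely for illustration, while the real model wires in the repo's MultiHeadedAttention from viettts/transformer/attention.py:

import torch
from torch import nn
from viettts.transformer.decoder_layer import DecoderLayer
from viettts.transformer.positionwise_feed_forward import PositionwiseFeedForward

class StubAttention(nn.Module):
    """Toy (q, k, v, mask) -> (output, None) module, for shape tracing only."""
    def __init__(self, size):
        super().__init__()
        self.proj = nn.Linear(size, size)

    def forward(self, q, k, v, mask):
        return self.proj(q), None

size = 32
layer = DecoderLayer(
    size=size,
    self_attn=StubAttention(size),
    src_attn=StubAttention(size),
    feed_forward=PositionwiseFeedForward(size, 4 * size, 0.1),
    dropout_rate=0.1,
)
tgt = torch.randn(2, 7, size)
memory = torch.randn(2, 11, size)
tgt_mask = torch.ones(2, 7, 7, dtype=torch.bool)
memory_mask = torch.ones(2, 1, 11, dtype=torch.bool)
x, *_ = layer(tgt, tgt_mask, memory, memory_mask)
print(x.shape)  # torch.Size([2, 7, 32])
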
/viettts/transformer/embedding.py:
--------------------------------------------------------------------------------
1 | # Modified from ESPnet(https://github.com/espnet/espnet)
2 | """Positional Encoding Module."""
3 |
4 | import math
5 | from typing import Tuple, Union
6 |
7 | import torch
8 | import torch.nn.functional as F
9 | import numpy as np
10 |
11 |
12 | class PositionalEncoding(torch.nn.Module):
13 | """Positional encoding.
14 |
15 | :param int d_model: embedding dim
16 | :param float dropout_rate: dropout rate
17 | :param int max_len: maximum input length
18 |
19 | PE(pos, 2i) = sin(pos/(10000^(2i/dmodel)))
20 | PE(pos, 2i+1) = cos(pos/(10000^(2i/dmodel)))
21 | """
22 |
23 | def __init__(self,
24 | d_model: int,
25 | dropout_rate: float,
26 | max_len: int = 5000,
27 | reverse: bool = False):
28 | """Construct a PositionalEncoding object."""
29 | super().__init__()
30 | self.d_model = d_model
31 | self.xscale = math.sqrt(self.d_model)
32 | self.dropout = torch.nn.Dropout(p=dropout_rate)
33 | self.max_len = max_len
34 |
35 | self.pe = torch.zeros(self.max_len, self.d_model)
36 | position = torch.arange(0, self.max_len,
37 | dtype=torch.float32).unsqueeze(1)
38 | div_term = torch.exp(
39 | torch.arange(0, self.d_model, 2, dtype=torch.float32) *
40 | -(math.log(10000.0) / self.d_model))
41 | self.pe[:, 0::2] = torch.sin(position * div_term)
42 | self.pe[:, 1::2] = torch.cos(position * div_term)
43 | self.pe = self.pe.unsqueeze(0)
44 |
45 | def forward(self,
46 | x: torch.Tensor,
47 | offset: Union[int, torch.Tensor] = 0) \
48 | -> Tuple[torch.Tensor, torch.Tensor]:
49 | """Add positional encoding.
50 |
51 | Args:
52 | x (torch.Tensor): Input. Its shape is (batch, time, ...)
53 | offset (int, torch.tensor): position offset
54 |
55 | Returns:
56 | torch.Tensor: Encoded tensor. Its shape is (batch, time, ...)
57 | torch.Tensor: for compatibility to RelPositionalEncoding
58 | """
59 |
60 | self.pe = self.pe.to(x.device)
61 | pos_emb = self.position_encoding(offset, x.size(1), False)
62 | x = x * self.xscale + pos_emb
63 | return self.dropout(x), self.dropout(pos_emb)
64 |
65 | def position_encoding(self,
66 | offset: Union[int, torch.Tensor],
67 | size: int,
68 | apply_dropout: bool = True) -> torch.Tensor:
69 | """ For getting encoding in a streaming fashion
70 |
71 | Attention!!!!!
72 | We apply dropout only once at the whole utterance level in the
73 | non-streaming case, but this function is called several times with
74 | increasing input size in a streaming scenario, so the dropout will
75 | be applied several times.
76 |
77 | Args:
78 | offset (int or torch.tensor): start offset
79 | size (int): required size of position encoding
80 |
81 | Returns:
82 | torch.Tensor: Corresponding encoding
83 | """
84 | # How to subscript a Union type:
85 | # https://github.com/pytorch/pytorch/issues/69434
86 | if isinstance(offset, int):
87 | assert offset + size <= self.max_len
88 | pos_emb = self.pe[:, offset:offset + size]
89 | elif isinstance(offset, torch.Tensor) and offset.dim() == 0: # scalar
90 | assert offset + size <= self.max_len
91 | pos_emb = self.pe[:, offset:offset + size]
92 | else: # for batched streaming decoding on GPU
93 | assert torch.max(offset) + size <= self.max_len
94 | index = offset.unsqueeze(1) + \
95 | torch.arange(0, size).to(offset.device) # B X T
96 | flag = index > 0
97 | # remove negative offset
98 | index = index * flag
99 | pos_emb = F.embedding(index, self.pe[0]) # B X T X d_model
100 |
101 | if apply_dropout:
102 | pos_emb = self.dropout(pos_emb)
103 | return pos_emb
104 |
105 |
106 | class RelPositionalEncoding(PositionalEncoding):
107 | """Relative positional encoding module.
108 | See : Appendix B in https://arxiv.org/abs/1901.02860
109 | Args:
110 | d_model (int): Embedding dimension.
111 | dropout_rate (float): Dropout rate.
112 | max_len (int): Maximum input length.
113 | """
114 |
115 | def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000):
116 | """Initialize class."""
117 | super().__init__(d_model, dropout_rate, max_len, reverse=True)
118 |
119 | def forward(self,
120 | x: torch.Tensor,
121 | offset: Union[int, torch.Tensor] = 0) \
122 | -> Tuple[torch.Tensor, torch.Tensor]:
123 | """Compute positional encoding.
124 | Args:
125 | x (torch.Tensor): Input tensor (batch, time, `*`).
126 | Returns:
127 | torch.Tensor: Encoded tensor (batch, time, `*`).
128 | torch.Tensor: Positional embedding tensor (1, time, `*`).
129 | """
130 | self.pe = self.pe.to(x.device)
131 | x = x * self.xscale
132 | pos_emb = self.position_encoding(offset, x.size(1), False)
133 | return self.dropout(x), self.dropout(pos_emb)
134 |
135 |
136 | class WhisperPositionalEncoding(PositionalEncoding):
137 | """ Sinusoids position encoding used in openai-whisper.encoder
138 | """
139 |
140 | def __init__(self, d_model: int, dropout_rate: float, max_len: int = 1500):
141 | super().__init__(d_model, dropout_rate, max_len)
142 | self.xscale = 1.0
143 | log_timescale_increment = np.log(10000) / (d_model // 2 - 1)
144 | inv_timescales = torch.exp(-log_timescale_increment *
145 | torch.arange(d_model // 2))
146 | scaled_time = torch.arange(max_len)[:, np.newaxis] * \
147 | inv_timescales[np.newaxis, :]
148 | pe = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
149 | delattr(self, "pe")
150 | self.register_buffer("pe", pe.unsqueeze(0))
151 |
152 |
153 | class LearnablePositionalEncoding(PositionalEncoding):
154 | """ Learnable position encoding used in openai-whisper.decoder
155 | """
156 |
157 | def __init__(self, d_model: int, dropout_rate: float, max_len: int = 448):
158 | super().__init__(d_model, dropout_rate, max_len)
159 | # NOTE(xcsong): overwrite self.pe & self.xscale
160 | self.pe = torch.nn.Parameter(torch.empty(1, max_len, d_model))
161 | self.xscale = 1.0
162 |
163 |
164 | class NoPositionalEncoding(torch.nn.Module):
165 | """ No position encoding
166 | """
167 |
168 | def __init__(self, d_model: int, dropout_rate: float):
169 | super().__init__()
170 | self.d_model = d_model
171 | self.dropout = torch.nn.Dropout(p=dropout_rate)
172 |
173 | def forward(self,
174 | x: torch.Tensor,
175 | offset: Union[int, torch.Tensor] = 0) \
176 | -> Tuple[torch.Tensor, torch.Tensor]:
177 | """ Just return zero vector for interface compatibility
178 | """
179 | pos_emb = torch.zeros(1, x.size(1), self.d_model).to(x.device)
180 | return self.dropout(x), pos_emb
181 |
182 | def position_encoding(self, offset: Union[int, torch.Tensor],
183 | size: int) -> torch.Tensor:
184 | return torch.zeros(1, size, self.d_model)
185 |
186 |
187 | class EspnetRelPositionalEncoding(torch.nn.Module):
188 | """Relative positional encoding module (new implementation).
189 |
190 | Details can be found in https://github.com/espnet/espnet/pull/2816.
191 |
192 | See : Appendix B in https://arxiv.org/abs/1901.02860
193 |
194 | Args:
195 | d_model (int): Embedding dimension.
196 | dropout_rate (float): Dropout rate.
197 | max_len (int): Maximum input length.
198 |
199 | """
200 |
201 | def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000):
202 | """Construct an PositionalEncoding object."""
203 | super(EspnetRelPositionalEncoding, self).__init__()
204 | self.d_model = d_model
205 | self.xscale = math.sqrt(self.d_model)
206 | self.dropout = torch.nn.Dropout(p=dropout_rate)
207 | self.pe = None
208 | self.extend_pe(torch.tensor(0.0).expand(1, max_len))
209 |
210 | def extend_pe(self, x: torch.Tensor):
211 | """Reset the positional encodings."""
212 | if self.pe is not None:
213 | # self.pe contains both positive and negative parts
214 | # the length of self.pe is 2 * input_len - 1
215 | if self.pe.size(1) >= x.size(1) * 2 - 1:
216 | if self.pe.dtype != x.dtype or self.pe.device != x.device:
217 | self.pe = self.pe.to(dtype=x.dtype, device=x.device)
218 | return
219 | # Suppose `i` means the position of the query vector and `j` means the
220 | # position of the key vector. We use positive relative positions when keys
221 | # are to the left (i>j) and negative relative positions otherwise (i<j).
222 | pe_positive = torch.zeros(x.size(1), self.d_model)
223 | pe_negative = torch.zeros(x.size(1), self.d_model)
224 | position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
225 | div_term = torch.exp(
226 | torch.arange(0, self.d_model, 2, dtype=torch.float32) *
227 | -(math.log(10000.0) / self.d_model))
228 | pe_positive[:, 0::2] = torch.sin(position * div_term)
229 | pe_positive[:, 1::2] = torch.cos(position * div_term)
230 | pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
231 | pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
232 |
233 | # Reverse the order of positive indices and concat both positive and
234 | # negative indices. This is used to support the shifting trick
235 | # as in https://arxiv.org/abs/1901.02860
236 | pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
237 | pe_negative = pe_negative[:, 1:].unsqueeze(0)
238 | pe = torch.cat([pe_positive, pe_negative], dim=1)
239 | self.pe = pe.to(device=x.device, dtype=x.dtype)
240 |
241 | def forward(self,
242 | x: torch.Tensor,
243 | offset: Union[int, torch.Tensor] = 0) -> Tuple[torch.Tensor, torch.Tensor]:
244 | """Add positional encoding.
245 |
246 | Args:
247 | x (torch.Tensor): Input tensor (batch, time, `*`).
248 |
249 | Returns:
250 | torch.Tensor: Encoded tensor (batch, time, `*`).
251 |
252 | """
253 | self.extend_pe(x)
254 | x = x * self.xscale
255 | pos_emb = self.position_encoding(size=x.size(1), offset=offset)
256 | return self.dropout(x), self.dropout(pos_emb)
257 |
258 | def position_encoding(self,
259 | offset: Union[int, torch.Tensor],
260 | size: int) -> torch.Tensor:
261 | """ For getting encoding in a streaming fashion
262 |
263 | Attention!!!!!
264 | We apply dropout only once at the whole utterance level in the
265 | non-streaming case, but this function is called several times with
266 | increasing input size in a streaming scenario, so the dropout will
267 | be applied several times.
268 |
269 | Args:
270 | offset (int or torch.tensor): start offset
271 | size (int): required size of position encoding
272 |
273 | Returns:
274 | torch.Tensor: Corresponding encoding
275 | """
276 | pos_emb = self.pe[
277 | :,
278 | self.pe.size(1) // 2 - size + 1: self.pe.size(1) // 2 + size,
279 | ]
280 | return pos_emb
281 |
--------------------------------------------------------------------------------
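PositionalEncoding scales the input by sqrt(d_model) and adds a slice of a precomputed sinusoid table; the same table can be re-sliced with an offset for streaming, and EspnetRelPositionalEncoding instead returns a relative table covering positions -(T-1)..(T-1). A minimal sketch, assuming the viettts package is importable (dropout_rate=0.0 keeps the outputs deterministic):

import torch
from viettts.transformer.embedding import PositionalEncoding, EspnetRelPositionalEncoding

pe = PositionalEncoding(d_model=64, dropout_rate=0.0, max_len=1000)
x = torch.randn(2, 20, 64)
y, pos_emb = pe(x)                       # (2, 20, 64) and (1, 20, 64)
chunk = pe.position_encoding(offset=20, size=10, apply_dropout=False)
print(chunk.shape)                       # torch.Size([1, 10, 64]), positions 20..29

rel = EspnetRelPositionalEncoding(d_model=64, dropout_rate=0.0)
y_rel, rel_pos = rel(x)
print(rel_pos.shape)                     # torch.Size([1, 39, 64]) = (1, 2*T - 1, d_model)
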
/viettts/transformer/encoder_layer.py:
--------------------------------------------------------------------------------
1 | # Modified from ESPnet(https://github.com/espnet/espnet)
2 | """Encoder self-attention layer definition."""
3 |
4 | from typing import Optional, Tuple
5 |
6 | import torch
7 | from torch import nn
8 |
9 |
10 | class TransformerEncoderLayer(nn.Module):
11 | """Encoder layer module.
12 |
13 | Args:
14 | size (int): Input dimension.
15 | self_attn (torch.nn.Module): Self-attention module instance.
16 | `MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
17 | instance can be used as the argument.
18 | feed_forward (torch.nn.Module): Feed-forward module instance.
19 | `PositionwiseFeedForward`, instance can be used as the argument.
20 | dropout_rate (float): Dropout rate.
21 | normalize_before (bool):
22 | True: use layer_norm before each sub-block.
23 | False: use layer_norm after each sub-block.
24 | """
25 |
26 | def __init__(
27 | self,
28 | size: int,
29 | self_attn: torch.nn.Module,
30 | feed_forward: torch.nn.Module,
31 | dropout_rate: float,
32 | normalize_before: bool = True,
33 | ):
34 | """Construct an EncoderLayer object."""
35 | super().__init__()
36 | self.self_attn = self_attn
37 | self.feed_forward = feed_forward
38 | self.norm1 = nn.LayerNorm(size, eps=1e-5)
39 | self.norm2 = nn.LayerNorm(size, eps=1e-5)
40 | self.dropout = nn.Dropout(dropout_rate)
41 | self.size = size
42 | self.normalize_before = normalize_before
43 |
44 | def forward(
45 | self,
46 | x: torch.Tensor,
47 | mask: torch.Tensor,
48 | pos_emb: torch.Tensor,
49 | mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
50 | att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
51 | cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
52 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
53 | """Compute encoded features.
54 |
55 | Args:
56 | x (torch.Tensor): (#batch, time, size)
57 | mask (torch.Tensor): Mask tensor for the input (#batch, time,time),
58 | (0, 0, 0) means fake mask.
59 | pos_emb (torch.Tensor): just for interface compatibility
60 | to ConformerEncoderLayer
61 | mask_pad (torch.Tensor): not used in the transformer layer,
62 | just kept for a unified API with the conformer.
63 | att_cache (torch.Tensor): Cache tensor of the KEY & VALUE
64 | (#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
65 | cnn_cache (torch.Tensor): Convolution cache in conformer layer
66 | (#batch=1, size, cache_t2), not used here, it's for interface
67 | compatibility to ConformerEncoderLayer.
68 | Returns:
69 | torch.Tensor: Output tensor (#batch, time, size).
70 | torch.Tensor: Mask tensor (#batch, time, time).
71 | torch.Tensor: att_cache tensor,
72 | (#batch=1, head, cache_t1 + time, d_k * 2).
73 | torch.Tensor: cnn_cache tensor (#batch=1, size, cache_t2).
74 |
75 | """
76 | residual = x
77 | if self.normalize_before:
78 | x = self.norm1(x)
79 | x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb=pos_emb, cache=att_cache)
80 | x = residual + self.dropout(x_att)
81 | if not self.normalize_before:
82 | x = self.norm1(x)
83 |
84 | residual = x
85 | if self.normalize_before:
86 | x = self.norm2(x)
87 | x = residual + self.dropout(self.feed_forward(x))
88 | if not self.normalize_before:
89 | x = self.norm2(x)
90 |
91 | fake_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
92 | return x, mask, new_att_cache, fake_cnn_cache
93 |
94 |
95 | class ConformerEncoderLayer(nn.Module):
96 | """Encoder layer module.
97 | Args:
98 | size (int): Input dimension.
99 | self_attn (torch.nn.Module): Self-attention module instance.
100 | `MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
101 | instance can be used as the argument.
102 | feed_forward (torch.nn.Module): Feed-forward module instance.
103 | `PositionwiseFeedForward` instance can be used as the argument.
104 | feed_forward_macaron (torch.nn.Module): Additional feed-forward module
105 | instance.
106 | `PositionwiseFeedForward` instance can be used as the argument.
107 | conv_module (torch.nn.Module): Convolution module instance.
108 | `ConvolutionModule` instance can be used as the argument.
109 | dropout_rate (float): Dropout rate.
110 | normalize_before (bool):
111 | True: use layer_norm before each sub-block.
112 | False: use layer_norm after each sub-block.
113 | """
114 |
115 | def __init__(
116 | self,
117 | size: int,
118 | self_attn: torch.nn.Module,
119 | feed_forward: Optional[nn.Module] = None,
120 | feed_forward_macaron: Optional[nn.Module] = None,
121 | conv_module: Optional[nn.Module] = None,
122 | dropout_rate: float = 0.1,
123 | normalize_before: bool = True,
124 | ):
125 | """Construct an EncoderLayer object."""
126 | super().__init__()
127 | self.self_attn = self_attn
128 | self.feed_forward = feed_forward
129 | self.feed_forward_macaron = feed_forward_macaron
130 | self.conv_module = conv_module
131 | self.norm_ff = nn.LayerNorm(size, eps=1e-5) # for the FFN module
132 | self.norm_mha = nn.LayerNorm(size, eps=1e-5) # for the MHA module
133 | if feed_forward_macaron is not None:
134 | self.norm_ff_macaron = nn.LayerNorm(size, eps=1e-5)
135 | self.ff_scale = 0.5
136 | else:
137 | self.ff_scale = 1.0
138 | if self.conv_module is not None:
139 | self.norm_conv = nn.LayerNorm(size, eps=1e-5) # for the CNN module
140 | self.norm_final = nn.LayerNorm(
141 | size, eps=1e-5) # for the final output of the block
142 | self.dropout = nn.Dropout(dropout_rate)
143 | self.size = size
144 | self.normalize_before = normalize_before
145 |
146 | def forward(
147 | self,
148 | x: torch.Tensor,
149 | mask: torch.Tensor,
150 | pos_emb: torch.Tensor,
151 | mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
152 | att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
153 | cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
154 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
155 | """Compute encoded features.
156 |
157 | Args:
158 | x (torch.Tensor): (#batch, time, size)
159 | mask (torch.Tensor): Mask tensor for the input (#batch, time,time),
160 | (0, 0, 0) means fake mask.
161 | pos_emb (torch.Tensor): positional encoding, must not be None
162 | for ConformerEncoderLayer.
163 | mask_pad (torch.Tensor): batch padding mask used for conv module.
164 | (#batch, 1,time), (0, 0, 0) means fake mask.
165 | att_cache (torch.Tensor): Cache tensor of the KEY & VALUE
166 | (#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
167 | cnn_cache (torch.Tensor): Convolution cache in conformer layer
168 | (#batch=1, size, cache_t2)
169 | Returns:
170 | torch.Tensor: Output tensor (#batch, time, size).
171 | torch.Tensor: Mask tensor (#batch, time, time).
172 | torch.Tensor: att_cache tensor,
173 | (#batch=1, head, cache_t1 + time, d_k * 2).
174 | torch.Tensor: cnn_cache tensor (#batch, size, cache_t2).
175 | """
176 |
177 | # whether to use macaron style
178 | if self.feed_forward_macaron is not None:
179 | residual = x
180 | if self.normalize_before:
181 | x = self.norm_ff_macaron(x)
182 | x = residual + self.ff_scale * self.dropout(
183 | self.feed_forward_macaron(x))
184 | if not self.normalize_before:
185 | x = self.norm_ff_macaron(x)
186 |
187 | # multi-headed self-attention module
188 | residual = x
189 | if self.normalize_before:
190 | x = self.norm_mha(x)
191 | x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb,
192 | att_cache)
193 | x = residual + self.dropout(x_att)
194 | if not self.normalize_before:
195 | x = self.norm_mha(x)
196 |
197 | # convolution module
198 | # Fake new cnn cache here, and then change it in conv_module
199 | new_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
200 | if self.conv_module is not None:
201 | residual = x
202 | if self.normalize_before:
203 | x = self.norm_conv(x)
204 | x, new_cnn_cache = self.conv_module(x, mask_pad, cnn_cache)
205 | x = residual + self.dropout(x)
206 |
207 | if not self.normalize_before:
208 | x = self.norm_conv(x)
209 |
210 | # feed forward module
211 | residual = x
212 | if self.normalize_before:
213 | x = self.norm_ff(x)
214 |
215 | x = residual + self.ff_scale * self.dropout(self.feed_forward(x))
216 | if not self.normalize_before:
217 | x = self.norm_ff(x)
218 |
219 | if self.conv_module is not None:
220 | x = self.norm_final(x)
221 |
222 | return x, mask, new_att_cache, new_cnn_cache
223 |
--------------------------------------------------------------------------------
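ConformerEncoderLayer is a wiring of macaron feed-forward (scaled by 0.5), self-attention, the convolution module, a second feed-forward, and a final norm, each with pre-norm residuals. A minimal sketch of that wiring; StubSelfAttention below is a hypothetical stand-in for the repo's relative-position attention and is used only to trace shapes:

import torch
from torch import nn
from viettts.transformer.encoder_layer import ConformerEncoderLayer
from viettts.transformer.convolution import ConvolutionModule
from viettts.transformer.positionwise_feed_forward import PositionwiseFeedForward

class StubSelfAttention(nn.Module):
    """Toy attention matching the (q, k, v, mask, pos_emb, cache) call."""
    def __init__(self, size):
        super().__init__()
        self.proj = nn.Linear(size, size)

    def forward(self, q, k, v, mask, pos_emb=None, cache=None):
        return self.proj(q), torch.zeros(0, 0, 0, 0)

size = 64
layer = ConformerEncoderLayer(
    size=size,
    self_attn=StubSelfAttention(size),
    feed_forward=PositionwiseFeedForward(size, 4 * size, 0.1),
    feed_forward_macaron=PositionwiseFeedForward(size, 4 * size, 0.1),
    conv_module=ConvolutionModule(channels=size, kernel_size=15),
    dropout_rate=0.1,
)
x = torch.randn(2, 30, size)
mask = torch.ones(2, 30, 30, dtype=torch.bool)
pos_emb = torch.zeros(1, 30, size)
mask_pad = torch.ones(2, 1, 30, dtype=torch.bool)
y, *_ = layer(x, mask, pos_emb, mask_pad)
print(y.shape)  # torch.Size([2, 30, 64])
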
/viettts/transformer/label_smoothing_loss.py:
--------------------------------------------------------------------------------
1 | """Label smoothing module."""
2 |
3 | import torch
4 | from torch import nn
5 |
6 |
7 | class LabelSmoothingLoss(nn.Module):
8 | """Label-smoothing loss.
9 |
10 | In a standard CE loss, the label's data distribution is:
11 | [0,1,2] ->
12 | [
13 | [1.0, 0.0, 0.0],
14 | [0.0, 1.0, 0.0],
15 | [0.0, 0.0, 1.0],
16 | ]
17 |
18 | In the smoothed version of the CE loss, some probability mass
19 | is taken from the true label prob (1.0) and divided
20 | among the other labels.
21 |
22 | e.g.
23 | smoothing=0.1
24 | [0,1,2] ->
25 | [
26 | [0.9, 0.05, 0.05],
27 | [0.05, 0.9, 0.05],
28 | [0.05, 0.05, 0.9],
29 | ]
30 |
31 | Args:
32 | size (int): the number of classes
33 | padding_idx (int): padding class id which will be ignored for loss
34 | smoothing (float): smoothing rate (0.0 means the conventional CE)
35 | normalize_length (bool):
36 | normalize loss by sequence length if True
37 | normalize loss by batch size if False
38 | """
39 |
40 | def __init__(self,
41 | size: int,
42 | padding_idx: int,
43 | smoothing: float,
44 | normalize_length: bool = False):
45 | """Construct a LabelSmoothingLoss object."""
46 | super(LabelSmoothingLoss, self).__init__()
47 | self.criterion = nn.KLDivLoss(reduction="none")
48 | self.padding_idx = padding_idx
49 | self.confidence = 1.0 - smoothing
50 | self.smoothing = smoothing
51 | self.size = size
52 | self.normalize_length = normalize_length
53 |
54 | def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
55 | """Compute loss between x and target.
56 |
57 | The model output and label tensors are flattened to
58 | (batch*seqlen, class) shape and a mask is applied to the
59 | padding positions, which should not contribute to the loss.
60 |
61 | Args:
62 | x (torch.Tensor): prediction (batch, seqlen, class)
63 | target (torch.Tensor):
64 | target signal masked with self.padding_id (batch, seqlen)
65 | Returns:
66 | loss (torch.Tensor) : The KL loss, scalar float value
67 | """
68 | assert x.size(2) == self.size
69 | batch_size = x.size(0)
70 | x = x.view(-1, self.size)
71 | target = target.view(-1)
72 | # use zeros_like instead of torch.no_grad() for true_dist,
73 | # since no_grad() can not be exported by JIT
74 | true_dist = torch.zeros_like(x)
75 | true_dist.fill_(self.smoothing / (self.size - 1))
76 | ignore = target == self.padding_idx # (B,)
77 | total = len(target) - ignore.sum().item()
78 | target = target.masked_fill(ignore, 0) # avoid -1 index
79 | true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
80 | kl = self.criterion(torch.log_softmax(x, dim=1), true_dist)
81 | denom = total if self.normalize_length else batch_size
82 | return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom
83 |
--------------------------------------------------------------------------------
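With smoothing=0.1 and 5 classes, each target row becomes 0.9 on the true class and 0.1/4 on every other class, and positions equal to padding_idx are excluded from the loss. A minimal sketch, assuming the viettts package is importable:

import torch
from viettts.transformer.label_smoothing_loss import LabelSmoothingLoss

criterion = LabelSmoothingLoss(size=5, padding_idx=-1, smoothing=0.1)
logits = torch.randn(2, 4, 5)           # (batch, seqlen, class)
target = torch.tensor([[0, 1, 2, -1],   # -1 marks padding and is ignored
                       [3, 4, -1, -1]])
loss = criterion(logits, target)
print(loss.item())                      # scalar KL-based loss, normalized by batch size
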
/viettts/transformer/positionwise_feed_forward.py:
--------------------------------------------------------------------------------
1 | """Positionwise feed forward layer definition."""
2 |
3 | import torch
4 |
5 |
6 | class PositionwiseFeedForward(torch.nn.Module):
7 | """Positionwise feed forward layer.
8 |
9 | The feed-forward layer is applied at each position of the sequence.
10 | The output dim is the same as the input dim.
11 |
12 | Args:
13 | idim (int): Input dimension.
14 | hidden_units (int): The number of hidden units.
15 | dropout_rate (float): Dropout rate.
16 | activation (torch.nn.Module): Activation function
17 | """
18 |
19 | def __init__(
20 | self,
21 | idim: int,
22 | hidden_units: int,
23 | dropout_rate: float,
24 | activation: torch.nn.Module = torch.nn.ReLU(),
25 | ):
26 | """Construct a PositionwiseFeedForward object."""
27 | super(PositionwiseFeedForward, self).__init__()
28 | self.w_1 = torch.nn.Linear(idim, hidden_units)
29 | self.activation = activation
30 | self.dropout = torch.nn.Dropout(dropout_rate)
31 | self.w_2 = torch.nn.Linear(hidden_units, idim)
32 |
33 | def forward(self, xs: torch.Tensor) -> torch.Tensor:
34 | """Forward function.
35 |
36 | Args:
37 | xs: input tensor (B, L, D)
38 | Returns:
39 | output tensor, (B, L, D)
40 | """
41 | return self.w_2(self.dropout(self.activation(self.w_1(xs))))
42 |
43 |
44 | class MoEFFNLayer(torch.nn.Module):
45 | """
46 | Mixture of experts with a positionwise feed-forward layer
47 | See also figure 1 in https://arxiv.org/pdf/2305.15663.pdf
48 | The output dim is the same as the input dim.
49 |
50 | Modified from https://github.com/Lightning-AI/lit-gpt/pull/823
51 | https://github.com/mistralai/mistral-src/blob/b46d6/moe_one_file_ref.py#L203-L219
52 | Args:
53 | n_expert: number of experts.
54 | n_expert_per_token: the actual number of experts used for each frame.
55 | idim (int): Input dimension.
56 | hidden_units (int): The number of hidden units.
57 | dropout_rate (float): Dropout rate.
58 | activation (torch.nn.Module): Activation function
59 | """
60 |
61 | def __init__(
62 | self,
63 | n_expert: int,
64 | n_expert_per_token: int,
65 | idim: int,
66 | hidden_units: int,
67 | dropout_rate: float,
68 | activation: torch.nn.Module = torch.nn.ReLU(),
69 | ):
70 | super(MoEFFNLayer, self).__init__()
71 | self.gate = torch.nn.Linear(idim, n_expert, bias=False)
72 | self.experts = torch.nn.ModuleList(
73 | PositionwiseFeedForward(idim, hidden_units, dropout_rate,
74 | activation) for _ in range(n_expert))
75 | self.n_expert_per_token = n_expert_per_token
76 |
77 | def forward(self, xs: torch.Tensor) -> torch.Tensor:
78 | """Forward function.
79 | Args:
80 | xs: input tensor (B, L, D)
81 | Returns:
82 | output tensor, (B, L, D)
83 |
84 | """
85 | B, L, D = xs.size(
86 | ) # batch size, sequence length, embedding dimension (idim)
87 | xs = xs.view(-1, D) # (B*L, D)
88 | router = self.gate(xs) # (B*L, n_expert)
89 | logits, indices = torch.topk(
90 | router, self.n_expert_per_token
91 | ) # logits: (B*L, n_expert_per_token), indices: (B*L, n_expert_per_token)
92 | weights = torch.nn.functional.softmax(
93 | logits, dim=1,
94 | dtype=torch.float).to(dtype=xs.dtype) # (B*L, n_expert_per_token)
95 | output = torch.zeros_like(xs) # (B*L, D)
96 | for i, expert in enumerate(self.experts):
97 | mask = indices == i
98 | batch_idx, ith_expert = torch.where(mask)
99 | output[batch_idx] += weights[batch_idx, ith_expert, None] * expert(
100 | xs[batch_idx])
101 | return output.view(B, L, D)
102 |
--------------------------------------------------------------------------------
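MoEFFNLayer scores every frame with a linear gate, keeps the top n_expert_per_token experts, and sums their outputs weighted by a softmax over the kept gate logits, so its output shape matches the plain feed-forward layer. A minimal sketch, assuming the viettts package is importable:

import torch
from viettts.transformer.positionwise_feed_forward import PositionwiseFeedForward, MoEFFNLayer

ffn = PositionwiseFeedForward(idim=64, hidden_units=256, dropout_rate=0.1)
moe = MoEFFNLayer(n_expert=4, n_expert_per_token=2, idim=64,
                  hidden_units=256, dropout_rate=0.1)
x = torch.randn(2, 10, 64)
print(ffn(x).shape)  # torch.Size([2, 10, 64])
print(moe(x).shape)  # torch.Size([2, 10, 64]): each frame uses only 2 of the 4 experts
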
/viettts/transformer/subsampling.py:
--------------------------------------------------------------------------------
1 | # Modified from ESPnet(https://github.com/espnet/espnet)
2 | """Subsampling layer definition."""
3 |
4 | from typing import Tuple, Union
5 |
6 | import torch
7 |
8 |
9 | class BaseSubsampling(torch.nn.Module):
10 |
11 | def __init__(self):
12 | super().__init__()
13 | self.right_context = 0
14 | self.subsampling_rate = 1
15 |
16 | def position_encoding(self, offset: Union[int, torch.Tensor],
17 | size: int) -> torch.Tensor:
18 | return self.pos_enc.position_encoding(offset, size)
19 |
20 |
21 | class EmbedinigNoSubsampling(BaseSubsampling):
22 | """Embedding input without subsampling
23 | """
24 |
25 | def __init__(self, idim: int, odim: int, dropout_rate: float,
26 | pos_enc_class: torch.nn.Module):
27 | super().__init__()
28 | self.embed = torch.nn.Embedding(idim, odim)
29 | self.pos_enc = pos_enc_class
30 |
31 | def forward(
32 | self,
33 | x: torch.Tensor,
34 | x_mask: torch.Tensor,
35 | offset: Union[int, torch.Tensor] = 0
36 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
37 | """Input x.
38 |
39 | Args:
40 | x (torch.Tensor): Input tensor (#batch, time, idim).
41 | x_mask (torch.Tensor): Input mask (#batch, 1, time).
42 |
43 | Returns:
44 | torch.Tensor: linear input tensor (#batch, time', odim),
45 | where time' = time .
46 | torch.Tensor: linear input mask (#batch, 1, time'),
47 | where time' = time .
48 |
49 | """
50 | x = self.embed(x)
51 | x, pos_emb = self.pos_enc(x, offset)
52 | return x, pos_emb, x_mask
53 |
54 |
55 | class LinearNoSubsampling(BaseSubsampling):
56 | """Linear transform the input without subsampling
57 |
58 | Args:
59 | idim (int): Input dimension.
60 | odim (int): Output dimension.
61 | dropout_rate (float): Dropout rate.
62 |
63 | """
64 |
65 | def __init__(self, idim: int, odim: int, dropout_rate: float,
66 | pos_enc_class: torch.nn.Module):
67 | """Construct a linear object."""
68 | super().__init__()
69 | self.out = torch.nn.Sequential(
70 | torch.nn.Linear(idim, odim),
71 | torch.nn.LayerNorm(odim, eps=1e-5),
72 | torch.nn.Dropout(dropout_rate),
73 | )
74 | self.pos_enc = pos_enc_class
75 | self.right_context = 0
76 | self.subsampling_rate = 1
77 |
78 | def forward(
79 | self,
80 | x: torch.Tensor,
81 | x_mask: torch.Tensor,
82 | offset: Union[int, torch.Tensor] = 0
83 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
84 | """Input x.
85 |
86 | Args:
87 | x (torch.Tensor): Input tensor (#batch, time, idim).
88 | x_mask (torch.Tensor): Input mask (#batch, 1, time).
89 |
90 | Returns:
91 | torch.Tensor: linear input tensor (#batch, time', odim),
92 | where time' = time .
93 | torch.Tensor: linear input mask (#batch, 1, time'),
94 | where time' = time .
95 |
96 | """
97 | x = self.out(x)
98 | x, pos_emb = self.pos_enc(x, offset)
99 | return x, pos_emb, x_mask
100 |
101 |
102 | class Conv1dSubsampling2(BaseSubsampling):
103 | """Convolutional 1D subsampling (to 1/2 length).
104 | It is designed for Whisper, ref:
105 | https://github.com/openai/whisper/blob/main/whisper/model.py
106 |
107 | Args:
108 | idim (int): Input dimension.
109 | odim (int): Output dimension.
110 | dropout_rate (float): Dropout rate.
111 |
112 | """
113 |
114 | def __init__(self, idim: int, odim: int, dropout_rate: float,
115 | pos_enc_class: torch.nn.Module):
116 | """Construct a Conv1dSubsampling2 object."""
117 | super().__init__()
118 | self.conv = torch.nn.Sequential(
119 | torch.nn.Conv1d(idim, odim, kernel_size=3, padding=1),
120 | torch.nn.GELU(),
121 | torch.nn.Conv1d(odim, odim, kernel_size=3, stride=2, padding=1),
122 | torch.nn.GELU(),
123 | )
124 | self.pos_enc = pos_enc_class
125 | # The right context for every conv layer is computed by:
126 | # (kernel_size - 1) * frame_rate_of_this_layer
127 | self.subsampling_rate = 2
128 | # 4 = (3 - 1) * 1 + (3 - 1) * 1
129 | self.right_context = 4
130 |
131 | def forward(
132 | self,
133 | x: torch.Tensor,
134 | x_mask: torch.Tensor,
135 | offset: Union[int, torch.Tensor] = 0
136 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
137 | """Subsample x.
138 |
139 | Args:
140 | x (torch.Tensor): Input tensor (#batch, time, idim).
141 | x_mask (torch.Tensor): Input mask (#batch, 1, time).
142 |
143 | Returns:
144 | torch.Tensor: Subsampled tensor (#batch, time', odim),
145 | where time' = time // 2.
146 | torch.Tensor: Subsampled mask (#batch, 1, time'),
147 | where time' = time // 2.
148 | torch.Tensor: positional encoding
149 |
150 | """
151 | time = x.size(1)
152 | x = x.transpose(1, 2) # (b, f, t)
153 | x = self.conv(x)
154 | x = x.transpose(1, 2) # (b, t, f)
155 | x, pos_emb = self.pos_enc(x, offset)
156 | return x, pos_emb, x_mask[:, :, (time + 1) % 2::2]
157 |
158 |
159 | class Conv2dSubsampling4(BaseSubsampling):
160 | """Convolutional 2D subsampling (to 1/4 length).
161 |
162 | Args:
163 | idim (int): Input dimension.
164 | odim (int): Output dimension.
165 | dropout_rate (float): Dropout rate.
166 |
167 | """
168 |
169 | def __init__(self, idim: int, odim: int, dropout_rate: float,
170 | pos_enc_class: torch.nn.Module):
171 | """Construct a Conv2dSubsampling4 object."""
172 | super().__init__()
173 | self.conv = torch.nn.Sequential(
174 | torch.nn.Conv2d(1, odim, 3, 2),
175 | torch.nn.ReLU(),
176 | torch.nn.Conv2d(odim, odim, 3, 2),
177 | torch.nn.ReLU(),
178 | )
179 | self.out = torch.nn.Sequential(
180 | torch.nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim))
181 | self.pos_enc = pos_enc_class
182 | # The right context for every conv layer is computed by:
183 | # (kernel_size - 1) * frame_rate_of_this_layer
184 | self.subsampling_rate = 4
185 | # 6 = (3 - 1) * 1 + (3 - 1) * 2
186 | self.right_context = 6
187 |
188 | def forward(
189 | self,
190 | x: torch.Tensor,
191 | x_mask: torch.Tensor,
192 | offset: Union[int, torch.Tensor] = 0
193 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
194 | """Subsample x.
195 |
196 | Args:
197 | x (torch.Tensor): Input tensor (#batch, time, idim).
198 | x_mask (torch.Tensor): Input mask (#batch, 1, time).
199 |
200 | Returns:
201 | torch.Tensor: Subsampled tensor (#batch, time', odim),
202 | where time' = time // 4.
203 | torch.Tensor: Subsampled mask (#batch, 1, time'),
204 | where time' = time // 4.
205 | torch.Tensor: positional encoding
206 |
207 | """
208 | x = x.unsqueeze(1) # (b, c=1, t, f)
209 | x = self.conv(x)
210 | b, c, t, f = x.size()
211 | x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
212 | x, pos_emb = self.pos_enc(x, offset)
213 | return x, pos_emb, x_mask[:, :, 2::2][:, :, 2::2]
214 |
215 |
216 | class Conv2dSubsampling6(BaseSubsampling):
217 | """Convolutional 2D subsampling (to 1/6 length).
218 | Args:
219 | idim (int): Input dimension.
220 | odim (int): Output dimension.
221 | dropout_rate (float): Dropout rate.
222 | pos_enc (torch.nn.Module): Custom position encoding layer.
223 | """
224 |
225 | def __init__(self, idim: int, odim: int, dropout_rate: float,
226 | pos_enc_class: torch.nn.Module):
227 | """Construct a Conv2dSubsampling6 object."""
228 | super().__init__()
229 | self.conv = torch.nn.Sequential(
230 | torch.nn.Conv2d(1, odim, 3, 2),
231 | torch.nn.ReLU(),
232 | torch.nn.Conv2d(odim, odim, 5, 3),
233 | torch.nn.ReLU(),
234 | )
235 | self.linear = torch.nn.Linear(odim * (((idim - 1) // 2 - 2) // 3),
236 | odim)
237 | self.pos_enc = pos_enc_class
238 | # 10 = (3 - 1) * 1 + (5 - 1) * 2
239 | self.subsampling_rate = 6
240 | self.right_context = 10
241 |
242 | def forward(
243 | self,
244 | x: torch.Tensor,
245 | x_mask: torch.Tensor,
246 | offset: Union[int, torch.Tensor] = 0
247 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
248 | """Subsample x.
249 | Args:
250 | x (torch.Tensor): Input tensor (#batch, time, idim).
251 | x_mask (torch.Tensor): Input mask (#batch, 1, time).
252 |
253 | Returns:
254 | torch.Tensor: Subsampled tensor (#batch, time', odim),
255 | where time' = time // 6.
256 | torch.Tensor: Subsampled mask (#batch, 1, time'),
257 | where time' = time // 6.
258 | torch.Tensor: positional encoding
259 | """
260 | x = x.unsqueeze(1) # (b, c, t, f)
261 | x = self.conv(x)
262 | b, c, t, f = x.size()
263 | x = self.linear(x.transpose(1, 2).contiguous().view(b, t, c * f))
264 | x, pos_emb = self.pos_enc(x, offset)
265 | return x, pos_emb, x_mask[:, :, 2::2][:, :, 4::3]
266 |
267 |
268 | class Conv2dSubsampling8(BaseSubsampling):
269 | """Convolutional 2D subsampling (to 1/8 length).
270 |
271 | Args:
272 | idim (int): Input dimension.
273 | odim (int): Output dimension.
274 | dropout_rate (float): Dropout rate.
275 |
276 | """
277 |
278 | def __init__(self, idim: int, odim: int, dropout_rate: float,
279 | pos_enc_class: torch.nn.Module):
280 | """Construct a Conv2dSubsampling8 object."""
281 | super().__init__()
282 | self.conv = torch.nn.Sequential(
283 | torch.nn.Conv2d(1, odim, 3, 2),
284 | torch.nn.ReLU(),
285 | torch.nn.Conv2d(odim, odim, 3, 2),
286 | torch.nn.ReLU(),
287 | torch.nn.Conv2d(odim, odim, 3, 2),
288 | torch.nn.ReLU(),
289 | )
290 | self.linear = torch.nn.Linear(
291 | odim * ((((idim - 1) // 2 - 1) // 2 - 1) // 2), odim)
292 | self.pos_enc = pos_enc_class
293 | self.subsampling_rate = 8
294 | # 14 = (3 - 1) * 1 + (3 - 1) * 2 + (3 - 1) * 4
295 | self.right_context = 14
296 |
297 | def forward(
298 | self,
299 | x: torch.Tensor,
300 | x_mask: torch.Tensor,
301 | offset: Union[int, torch.Tensor] = 0
302 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
303 | """Subsample x.
304 |
305 | Args:
306 | x (torch.Tensor): Input tensor (#batch, time, idim).
307 | x_mask (torch.Tensor): Input mask (#batch, 1, time).
308 |
309 | Returns:
310 | torch.Tensor: Subsampled tensor (#batch, time', odim),
311 | where time' = time // 8.
312 | torch.Tensor: Subsampled mask (#batch, 1, time'),
313 | where time' = time // 8.
314 | torch.Tensor: positional encoding
315 | """
316 | x = x.unsqueeze(1) # (b, c, t, f)
317 | x = self.conv(x)
318 | b, c, t, f = x.size()
319 | x = self.linear(x.transpose(1, 2).contiguous().view(b, t, c * f))
320 | x, pos_emb = self.pos_enc(x, offset)
321 | return x, pos_emb, x_mask[:, :, 2::2][:, :, 2::2][:, :, 2::2]
322 |
323 |
324 | class LegacyLinearNoSubsampling(BaseSubsampling):
325 | """Linear transform the input without subsampling
326 |
327 | Args:
328 | idim (int): Input dimension.
329 | odim (int): Output dimension.
330 | dropout_rate (float): Dropout rate.
331 |
332 | """
333 |
334 | def __init__(self, idim: int, odim: int, dropout_rate: float,
335 | pos_enc_class: torch.nn.Module):
336 | """Construct a linear object."""
337 | super().__init__()
338 | self.out = torch.nn.Sequential(
339 | torch.nn.Linear(idim, odim),
340 | torch.nn.LayerNorm(odim, eps=1e-5),
341 | torch.nn.Dropout(dropout_rate),
342 | torch.nn.ReLU(),
343 | )
344 | self.pos_enc = pos_enc_class
345 | self.right_context = 0
346 | self.subsampling_rate = 1
347 |
348 | def forward(
349 | self,
350 | x: torch.Tensor,
351 | x_mask: torch.Tensor,
352 | offset: Union[int, torch.Tensor] = 0
353 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
354 | """Input x.
355 |
356 | Args:
357 | x (torch.Tensor): Input tensor (#batch, time, idim).
358 | x_mask (torch.Tensor): Input mask (#batch, 1, time).
359 |
360 | Returns:
361 | torch.Tensor: linear input tensor (#batch, time', odim),
362 | where time' = time .
363 | torch.Tensor: linear input mask (#batch, 1, time'),
364 | where time' = time .
365 |
366 | """
367 | x = self.out(x)
368 | x, pos_emb = self.pos_enc(x, offset)
369 | return x, pos_emb, x_mask
370 |
--------------------------------------------------------------------------------
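Conv2dSubsampling4 shortens the time axis to roughly time // 4 (24 frames for a 100-frame input, because of the two stride-2 convolutions without padding) and slices the padding mask to match, while LinearNoSubsampling keeps the length unchanged. A minimal sketch, assuming the viettts package is importable:

import torch
from viettts.transformer.embedding import PositionalEncoding
from viettts.transformer.subsampling import Conv2dSubsampling4, LinearNoSubsampling

idim, odim, T = 80, 256, 100
x = torch.randn(2, T, idim)
x_mask = torch.ones(2, 1, T, dtype=torch.bool)

sub4 = Conv2dSubsampling4(idim, odim, 0.1, PositionalEncoding(odim, 0.1))
y, pos_emb, y_mask = sub4(x, x_mask)
print(y.shape, y_mask.shape)   # torch.Size([2, 24, 256]) torch.Size([2, 1, 24])

lin = LinearNoSubsampling(idim, odim, 0.1, PositionalEncoding(odim, 0.1))
z, _, z_mask = lin(x, x_mask)
print(z.shape, z_mask.shape)   # torch.Size([2, 100, 256]) torch.Size([2, 1, 100])
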
/viettts/transformer/transformer.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict, Optional
2 |
3 | import torch
4 | import torch.nn as nn
5 | from diffusers.models.attention import (
6 | GEGLU,
7 | GELU,
8 | AdaLayerNorm,
9 | AdaLayerNormZero,
10 | ApproximateGELU,
11 | )
12 | from diffusers.models.attention_processor import Attention
13 | from diffusers.models.lora import LoRACompatibleLinear
14 | from diffusers.utils.torch_utils import maybe_allow_in_graph
15 |
16 |
17 | class SnakeBeta(nn.Module):
18 | """
19 | A modified Snake function which uses separate parameters for the magnitude of the periodic components
20 | Shape:
21 | - Input: (B, C, T)
22 | - Output: (B, C, T), same shape as the input
23 | Parameters:
24 | - alpha - trainable parameter that controls frequency
25 | - beta - trainable parameter that controls magnitude
26 | References:
27 | - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
28 | https://arxiv.org/abs/2006.08195
29 | Examples:
30 | >>> a1 = SnakeBeta(256, 256)
31 | >>> x = torch.randn(256)
32 | >>> x = a1(x)
33 | """
34 |
35 | def __init__(self, in_features, out_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True):
36 | """
37 | Initialization.
38 | INPUT:
39 | - in_features: shape of the input
40 | - alpha - trainable parameter that controls frequency
41 | - beta - trainable parameter that controls magnitude
42 | alpha is initialized to 1 by default, higher values = higher-frequency.
43 | beta is initialized to 1 by default, higher values = higher-magnitude.
44 | alpha will be trained along with the rest of your model.
45 | """
46 | super().__init__()
47 | self.in_features = out_features if isinstance(out_features, list) else [out_features]
48 | self.proj = LoRACompatibleLinear(in_features, out_features)
49 |
50 | # initialize alpha
51 | self.alpha_logscale = alpha_logscale
52 | if self.alpha_logscale: # log scale alphas initialized to zeros
53 | self.alpha = nn.Parameter(torch.zeros(self.in_features) * alpha)
54 | self.beta = nn.Parameter(torch.zeros(self.in_features) * alpha)
55 | else: # linear scale alphas initialized to ones
56 | self.alpha = nn.Parameter(torch.ones(self.in_features) * alpha)
57 | self.beta = nn.Parameter(torch.ones(self.in_features) * alpha)
58 |
59 | self.alpha.requires_grad = alpha_trainable
60 | self.beta.requires_grad = alpha_trainable
61 |
62 | self.no_div_by_zero = 0.000000001
63 |
64 | def forward(self, x):
65 | """
66 | Forward pass of the function.
67 | Applies the function to the input elementwise.
68 | SnakeBeta ∶= x + 1/b * sin^2 (xa)
69 | """
70 | x = self.proj(x)
71 | if self.alpha_logscale:
72 | alpha = torch.exp(self.alpha)
73 | beta = torch.exp(self.beta)
74 | else:
75 | alpha = self.alpha
76 | beta = self.beta
77 |
78 | x = x + (1.0 / (beta + self.no_div_by_zero)) * torch.pow(torch.sin(x * alpha), 2)
79 |
80 | return x
81 |
82 |
83 | class FeedForward(nn.Module):
84 | r"""
85 | A feed-forward layer.
86 |
87 | Parameters:
88 | dim (`int`): The number of channels in the input.
89 | dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
90 | mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
91 | dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
92 | activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
93 | final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
94 | """
95 |
96 | def __init__(
97 | self,
98 | dim: int,
99 | dim_out: Optional[int] = None,
100 | mult: int = 4,
101 | dropout: float = 0.0,
102 | activation_fn: str = "geglu",
103 | final_dropout: bool = False,
104 | ):
105 | super().__init__()
106 | inner_dim = int(dim * mult)
107 | dim_out = dim_out if dim_out is not None else dim
108 |
109 | if activation_fn == "gelu":
110 | act_fn = GELU(dim, inner_dim)
111 | if activation_fn == "gelu-approximate":
112 | act_fn = GELU(dim, inner_dim, approximate="tanh")
113 | elif activation_fn == "geglu":
114 | act_fn = GEGLU(dim, inner_dim)
115 | elif activation_fn == "geglu-approximate":
116 | act_fn = ApproximateGELU(dim, inner_dim)
117 | elif activation_fn == "snakebeta":
118 | act_fn = SnakeBeta(dim, inner_dim)
119 |
120 | self.net = nn.ModuleList([])
121 | # project in
122 | self.net.append(act_fn)
123 | # project dropout
124 | self.net.append(nn.Dropout(dropout))
125 | # project out
126 | self.net.append(LoRACompatibleLinear(inner_dim, dim_out))
127 | # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
128 | if final_dropout:
129 | self.net.append(nn.Dropout(dropout))
130 |
131 | def forward(self, hidden_states):
132 | for module in self.net:
133 | hidden_states = module(hidden_states)
134 | return hidden_states
135 |
136 |
137 | @maybe_allow_in_graph
138 | class BasicTransformerBlock(nn.Module):
139 | r"""
140 | A basic Transformer block.
141 |
142 | Parameters:
143 | dim (`int`): The number of channels in the input and output.
144 | num_attention_heads (`int`): The number of heads to use for multi-head attention.
145 | attention_head_dim (`int`): The number of channels in each head.
146 | dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
147 | cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
148 | only_cross_attention (`bool`, *optional*):
149 | Whether to use only cross-attention layers. In this case two cross attention layers are used.
150 | double_self_attention (`bool`, *optional*):
151 | Whether to use two self-attention layers. In this case no cross attention layers are used.
152 | activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
153 | num_embeds_ada_norm (:
154 | obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
155 | attention_bias (:
156 | obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
157 | """
158 |
159 | def __init__(
160 | self,
161 | dim: int,
162 | num_attention_heads: int,
163 | attention_head_dim: int,
164 | dropout=0.0,
165 | cross_attention_dim: Optional[int] = None,
166 | activation_fn: str = "geglu",
167 | num_embeds_ada_norm: Optional[int] = None,
168 | attention_bias: bool = False,
169 | only_cross_attention: bool = False,
170 | double_self_attention: bool = False,
171 | upcast_attention: bool = False,
172 | norm_elementwise_affine: bool = True,
173 | norm_type: str = "layer_norm",
174 | final_dropout: bool = False,
175 | ):
176 | super().__init__()
177 | self.only_cross_attention = only_cross_attention
178 |
179 | self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
180 | self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
181 |
182 | if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
183 | raise ValueError(
184 | f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
185 | f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
186 | )
187 |
188 | # Define 3 blocks. Each block has its own normalization layer.
189 | # 1. Self-Attn
190 | if self.use_ada_layer_norm:
191 | self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
192 | elif self.use_ada_layer_norm_zero:
193 | self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
194 | else:
195 | self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
196 | self.attn1 = Attention(
197 | query_dim=dim,
198 | heads=num_attention_heads,
199 | dim_head=attention_head_dim,
200 | dropout=dropout,
201 | bias=attention_bias,
202 | cross_attention_dim=cross_attention_dim if only_cross_attention else None,
203 | upcast_attention=upcast_attention,
204 | )
205 |
206 | # 2. Cross-Attn
207 | if cross_attention_dim is not None or double_self_attention:
208 | # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
209 | # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
210 | # the second cross attention block.
211 | self.norm2 = (
212 | AdaLayerNorm(dim, num_embeds_ada_norm)
213 | if self.use_ada_layer_norm
214 | else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
215 | )
216 | self.attn2 = Attention(
217 | query_dim=dim,
218 | cross_attention_dim=cross_attention_dim if not double_self_attention else None,
219 | heads=num_attention_heads,
220 | dim_head=attention_head_dim,
221 | dropout=dropout,
222 | bias=attention_bias,
223 | upcast_attention=upcast_attention,
224 | # scale_qk=False, # uncomment this to not use flash attention
225 | ) # is self-attn if encoder_hidden_states is none
226 | else:
227 | self.norm2 = None
228 | self.attn2 = None
229 |
230 | # 3. Feed-forward
231 | self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
232 | self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
233 |
234 | # let chunk size default to None
235 | self._chunk_size = None
236 | self._chunk_dim = 0
237 |
238 | def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
239 | # Sets chunk feed-forward
240 | self._chunk_size = chunk_size
241 | self._chunk_dim = dim
242 |
243 | def forward(
244 | self,
245 | hidden_states: torch.FloatTensor,
246 | attention_mask: Optional[torch.FloatTensor] = None,
247 | encoder_hidden_states: Optional[torch.FloatTensor] = None,
248 | encoder_attention_mask: Optional[torch.FloatTensor] = None,
249 | timestep: Optional[torch.LongTensor] = None,
250 | cross_attention_kwargs: Dict[str, Any] = None,
251 | class_labels: Optional[torch.LongTensor] = None,
252 | ):
253 | # Notice that normalization is always applied before the real computation in the following blocks.
254 | # 1. Self-Attention
255 | if self.use_ada_layer_norm:
256 | norm_hidden_states = self.norm1(hidden_states, timestep)
257 | elif self.use_ada_layer_norm_zero:
258 | norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
259 | hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
260 | )
261 | else:
262 | norm_hidden_states = self.norm1(hidden_states)
263 |
264 | cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
265 |
266 | attn_output = self.attn1(
267 | norm_hidden_states,
268 | encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
269 | attention_mask=encoder_attention_mask if self.only_cross_attention else attention_mask,
270 | **cross_attention_kwargs,
271 | )
272 | if self.use_ada_layer_norm_zero:
273 | attn_output = gate_msa.unsqueeze(1) * attn_output
274 | hidden_states = attn_output + hidden_states
275 |
276 | # 2. Cross-Attention
277 | if self.attn2 is not None:
278 | norm_hidden_states = (
279 | self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
280 | )
281 |
282 | attn_output = self.attn2(
283 | norm_hidden_states,
284 | encoder_hidden_states=encoder_hidden_states,
285 | attention_mask=encoder_attention_mask,
286 | **cross_attention_kwargs,
287 | )
288 | hidden_states = attn_output + hidden_states
289 |
290 | # 3. Feed-forward
291 | norm_hidden_states = self.norm3(hidden_states)
292 |
293 | if self.use_ada_layer_norm_zero:
294 | norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
295 |
296 | if self._chunk_size is not None:
297 | # "feed_forward_chunk_size" can be used to save memory
298 | if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
299 | raise ValueError(
300 | f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
301 | )
302 |
303 | num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
304 | ff_output = torch.cat(
305 | [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)],
306 | dim=self._chunk_dim,
307 | )
308 | else:
309 | ff_output = self.ff(norm_hidden_states)
310 |
311 | if self.use_ada_layer_norm_zero:
312 | ff_output = gate_mlp.unsqueeze(1) * ff_output
313 |
314 | hidden_states = ff_output + hidden_states
315 |
316 | return hidden_states
--------------------------------------------------------------------------------
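BasicTransformerBlock is assembled from diffusers attention primitives (self-attention, optional cross-attention, then the FeedForward defined above), so it can be exercised directly on random hidden states. A minimal sketch, assuming viettts and a diffusers version exposing these modules are installed; all sizes below are arbitrary:

import torch
from viettts.transformer.transformer import BasicTransformerBlock

block = BasicTransformerBlock(
    dim=256,
    num_attention_heads=4,
    attention_head_dim=64,
    cross_attention_dim=512,     # enables the second (cross) attention
    activation_fn="snakebeta",   # uses the SnakeBeta feed-forward activation
)
hidden_states = torch.randn(2, 50, 256)
encoder_hidden_states = torch.randn(2, 30, 512)
out = block(hidden_states, encoder_hidden_states=encoder_hidden_states)
print(out.shape)  # torch.Size([2, 50, 256])
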
/viettts/tts.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from tqdm import tqdm
4 | from loguru import logger
5 | from hyperpyyaml import load_hyperpyyaml
6 |
7 | from viettts.model import TTSModel
8 | from viettts.frontend import TTSFrontEnd
9 | from viettts.utils.file_utils import download_model, save_wav
10 |
11 |
12 | class TTS:
13 | def __init__(
14 | self,
15 | model_dir,
16 | load_jit=False,
17 | load_onnx=False
18 | ):
19 | if not os.path.exists(model_dir):
20 |             logger.info('Downloading model from huggingface [dangvansam/viet-tts]')
21 | download_model(model_dir)
22 |
23 | with open(f'{model_dir}/config.yaml', 'r') as f:
24 | configs = load_hyperpyyaml(f)
25 | self.frontend = TTSFrontEnd(
26 | speech_embedding_model=f'{model_dir}/speech_embedding.onnx',
27 | speech_tokenizer_model=f'{model_dir}/speech_tokenizer.onnx'
28 | )
29 | self.model = TTSModel(
30 | llm=configs['llm'],
31 | flow=configs['flow'],
32 | hift=configs['hift']
33 | )
34 | self.model.load(
35 | llm_model=f'{model_dir}/llm.pt',
36 | flow_model=f'{model_dir}/flow.pt',
37 | hift_model=f'{model_dir}/hift.pt'
38 | )
39 | if load_jit:
40 | self.model.load_jit('{}/llm.text_encoder.fp16.zip'.format(model_dir),
41 | '{}/llm.llm.fp16.zip'.format(model_dir),
42 | '{}/flow.encoder.fp32.zip'.format(model_dir))
43 | logger.success('Loaded jit model from {}'.format(model_dir))
44 |
45 | if load_onnx:
46 | self.model.load_onnx('{}/flow.decoder.estimator.fp32.onnx'.format(model_dir))
47 | logger.success('Loaded onnx model from {}'.format(model_dir))
48 |
49 | logger.success('Loaded model from {}'.format(model_dir))
50 | self.model_dir = model_dir
51 |
52 | def list_avaliable_spks(self):
53 | spks = list(self.frontend.spk2info.keys())
54 | return spks
55 |
56 | def inference_tts(self, tts_text, prompt_speech_16k, stream=False, speed=1.0):
57 | for i in tqdm(self.frontend.preprocess_text(tts_text, split=True)):
58 | model_input = self.frontend.frontend_tts(i, prompt_speech_16k)
59 | for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
60 | yield model_output
61 |
62 | def inference_vc(self, source_speech_16k, prompt_speech_16k, stream=False, speed=1.0):
63 | model_input = self.frontend.frontend_vc(source_speech_16k, prompt_speech_16k)
64 | for model_output in self.model.vc(**model_input, stream=stream, speed=speed):
65 | yield model_output
66 |
67 | def tts_to_wav(self, text, prompt_speech_16k, speed=1.0):
68 | wavs = []
69 | for output in self.inference_tts(text, prompt_speech_16k, stream=False, speed=speed):
70 | wavs.append(output['tts_speech'].squeeze(0).numpy())
71 | return np.concatenate(wavs, axis=0)
72 |
73 | def tts_to_file(self, text, prompt_speech_16k, speed, output_path):
74 | wav = self.tts_to_wav(text, prompt_speech_16k, speed)
75 | save_wav(wav, 22050, output_path)
--------------------------------------------------------------------------------
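A usage sketch for the `TTS` class above (not part of the repository): the model directory name is an assumption, the prompt clip is one of the bundled samples, and `load_prompt_speech_from_file` is defined in `viettts/utils/file_utils.py` further down in this dump.

from viettts.tts import TTS
from viettts.utils.file_utils import load_prompt_speech_from_file

# 'pretrained-models' is illustrative; the checkpoint is downloaded on first use if missing.
tts = TTS(model_dir='pretrained-models')

# Any clean 3-5 second reference clip works; this one ships with the repo.
prompt = load_prompt_speech_from_file('samples/nu-nhe-nhang.wav')

tts.tts_to_file(
    text='Xin chào, đây là giọng nói được tổng hợp.',   # Vietnamese input text
    prompt_speech_16k=prompt,
    speed=1.0,
    output_path='output.wav',
)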
/viettts/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dangvansam/viet-tts/c29b7535d4a81aefbbef175d0db26ac9df4bc028/viettts/utils/__init__.py
--------------------------------------------------------------------------------
/viettts/utils/class_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from viettts.transformer.activation import Swish
4 | from viettts.transformer.subsampling import (
5 | LinearNoSubsampling,
6 | EmbedinigNoSubsampling,
7 | Conv1dSubsampling2,
8 | Conv2dSubsampling4,
9 | Conv2dSubsampling6,
10 | Conv2dSubsampling8,
11 | )
12 | from viettts.transformer.embedding import (
13 | PositionalEncoding,
14 | RelPositionalEncoding,
15 | WhisperPositionalEncoding,
16 | LearnablePositionalEncoding,
17 | NoPositionalEncoding
18 | )
19 | from viettts.transformer.attention import (
20 | MultiHeadedAttention,
21 | RelPositionMultiHeadedAttention
22 | )
23 | from viettts.transformer.embedding import EspnetRelPositionalEncoding
24 | from viettts.transformer.subsampling import LegacyLinearNoSubsampling
25 |
26 |
27 | ACTIVATION_CLASSES = {
28 | "hardtanh": torch.nn.Hardtanh,
29 | "tanh": torch.nn.Tanh,
30 | "relu": torch.nn.ReLU,
31 | "selu": torch.nn.SELU,
32 | "swish": getattr(torch.nn, "SiLU", Swish),
33 | "gelu": torch.nn.GELU,
34 | }
35 |
36 | SUBSAMPLE_CLASSES = {
37 | "linear": LinearNoSubsampling,
38 | "linear_legacy": LegacyLinearNoSubsampling,
39 | "embed": EmbedinigNoSubsampling,
40 | "conv1d2": Conv1dSubsampling2,
41 | "conv2d": Conv2dSubsampling4,
42 | "conv2d6": Conv2dSubsampling6,
43 | "conv2d8": Conv2dSubsampling8,
44 | 'paraformer_dummy': torch.nn.Identity
45 | }
46 |
47 | EMB_CLASSES = {
48 | "embed": PositionalEncoding,
49 | "abs_pos": PositionalEncoding,
50 | "rel_pos": RelPositionalEncoding,
51 | "rel_pos_espnet": EspnetRelPositionalEncoding,
52 | "no_pos": NoPositionalEncoding,
53 | "abs_pos_whisper": WhisperPositionalEncoding,
54 | "embed_learnable_pe": LearnablePositionalEncoding,
55 | }
56 |
57 | ATTENTION_CLASSES = {
58 | "selfattn": MultiHeadedAttention,
59 | "rel_selfattn": RelPositionMultiHeadedAttention,
60 | }
61 |
--------------------------------------------------------------------------------
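These dictionaries act as string-to-class registries, so YAML configs can name components by key. A small sketch of the lookup pattern (illustrative only):

import torch
from viettts.utils.class_utils import ACTIVATION_CLASSES, EMB_CLASSES

# Resolve a class from the string key a config file would carry.
act_cls = ACTIVATION_CLASSES['swish']      # torch.nn.SiLU on recent torch, Swish fallback otherwise
act = act_cls()

x = torch.randn(2, 10, 256)
print(act(x).shape)                        # torch.Size([2, 10, 256])
print(EMB_CLASSES['rel_pos'].__name__)     # RelPositionalEncoding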
/viettts/utils/common.py:
--------------------------------------------------------------------------------
1 | # Modified from ESPnet(https://github.com/espnet/espnet)
2 | """Utility functions for Transformer."""
3 |
4 | import random
5 | from typing import List
6 |
7 | import numpy as np
8 | import torch
9 |
10 | IGNORE_ID = -1
11 |
12 |
13 | def pad_list(xs: List[torch.Tensor], pad_value: int):
14 | """Perform padding for the list of tensors.
15 |
16 | Args:
17 | xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)].
18 | pad_value (float): Value for padding.
19 |
20 | Returns:
21 | Tensor: Padded tensor (B, Tmax, `*`).
22 |
23 | Examples:
24 | >>> x = [torch.ones(4), torch.ones(2), torch.ones(1)]
25 | >>> x
26 | [tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])]
27 | >>> pad_list(x, 0)
28 | tensor([[1., 1., 1., 1.],
29 | [1., 1., 0., 0.],
30 | [1., 0., 0., 0.]])
31 |
32 | """
33 | max_len = max([len(item) for item in xs])
34 | batchs = len(xs)
35 | ndim = xs[0].ndim
36 | if ndim == 1:
37 | pad_res = torch.zeros(batchs,
38 | max_len,
39 | dtype=xs[0].dtype,
40 | device=xs[0].device)
41 | elif ndim == 2:
42 | pad_res = torch.zeros(batchs,
43 | max_len,
44 | xs[0].shape[1],
45 | dtype=xs[0].dtype,
46 | device=xs[0].device)
47 | elif ndim == 3:
48 | pad_res = torch.zeros(batchs,
49 | max_len,
50 | xs[0].shape[1],
51 | xs[0].shape[2],
52 | dtype=xs[0].dtype,
53 | device=xs[0].device)
54 | else:
55 | raise ValueError(f"Unsupported ndim: {ndim}")
56 | pad_res.fill_(pad_value)
57 | for i in range(batchs):
58 | pad_res[i, :len(xs[i])] = xs[i]
59 | return pad_res
60 |
61 |
62 | def th_accuracy(pad_outputs: torch.Tensor, pad_targets: torch.Tensor,
63 | ignore_label: int) -> torch.Tensor:
64 | """Calculate accuracy.
65 |
66 | Args:
67 | pad_outputs (Tensor): Prediction tensors (B * Lmax, D).
68 | pad_targets (LongTensor): Target label tensors (B, Lmax).
69 | ignore_label (int): Ignore label id.
70 |
71 | Returns:
72 | torch.Tensor: Accuracy value (0.0 - 1.0).
73 |
74 | """
75 | pad_pred = pad_outputs.view(pad_targets.size(0), pad_targets.size(1),
76 | pad_outputs.size(1)).argmax(2)
77 | mask = pad_targets != ignore_label
78 | numerator = torch.sum(
79 | pad_pred.masked_select(mask) == pad_targets.masked_select(mask))
80 | denominator = torch.sum(mask)
81 | return (numerator / denominator).detach()
82 |
83 |
84 | def get_padding(kernel_size, dilation=1):
85 | return int((kernel_size * dilation - dilation) / 2)
86 |
87 |
88 | def init_weights(m, mean=0.0, std=0.01):
89 | classname = m.__class__.__name__
90 | if classname.find("Conv") != -1:
91 | m.weight.data.normal_(mean, std)
92 |
93 |
94 | # Repetition Aware Sampling in VALL-E 2
95 | def ras_sampling(weighted_scores, decoded_tokens, sampling, top_p=0.8, top_k=25, win_size=10, tau_r=0.1):
96 | top_ids = nucleus_sampling(weighted_scores, top_p=top_p, top_k=top_k)
97 | rep_num = (torch.tensor(decoded_tokens[-win_size:]).to(weighted_scores.device) == top_ids).sum().item()
98 | if rep_num >= win_size * tau_r:
99 | top_ids = random_sampling(weighted_scores, decoded_tokens, sampling)
100 | return top_ids
101 |
102 |
103 | def nucleus_sampling(weighted_scores, top_p=0.8, top_k=25):
104 | prob, indices = [], []
105 | cum_prob = 0.0
106 | sorted_value, sorted_idx = weighted_scores.softmax(dim=0).sort(descending=True, stable=True)
107 | for i in range(len(sorted_idx)):
108 | # sampling both top-p and numbers.
109 |         if (cum_prob < top_p or len(prob) <= 1) and len(prob) < top_k:
110 |             cum_prob += sorted_value[i]
111 |             prob.append(sorted_value[i])
112 |             indices.append(sorted_idx[i])
113 |         else:
114 |             break
115 |     prob = torch.tensor(prob).to(weighted_scores)
116 |     indices = torch.tensor(indices, dtype=torch.long).to(weighted_scores.device)
117 |     top_ids = indices[prob.multinomial(1, replacement=True)]
118 |     return top_ids
119 | 
120 | 
121 | def random_sampling(weighted_scores, decoded_tokens, sampling):
122 |     top_ids = weighted_scores.softmax(dim=0).multinomial(1, replacement=True)
123 |     return top_ids
124 | 
--------------------------------------------------------------------------------
/viettts/utils/file_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import tempfile
3 | import subprocess
4 | from glob import glob
5 | 
6 | import numpy as np
7 | import soundfile
8 | import torchaudio
9 | from loguru import logger
10 | from huggingface_hub import snapshot_download
11 | 
12 | from viettts.utils.vad import get_speech
13 | 
14 | 
15 | def convert_to_wav(
16 |     input_filepath: str,
17 |     target_sr: int
18 | ) -> str:
19 | """
20 | Convert an input audio file to WAV format with the desired sample rate using FFmpeg.
21 |
22 | Args:
23 | input_filepath (str): Path to the input audio file.
24 | target_sr (int): Target sample rate.
25 |
26 | Returns:
27 | str: Path to the converted WAV file.
28 | """
29 | temp_wav_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
30 | temp_wav_filepath = temp_wav_file.name
31 | temp_wav_file.close()
32 |
33 | ffmpeg_command = [
34 | "ffmpeg", "-y",
35 | "-loglevel", "error",
36 | "-i", input_filepath,
37 | "-ar", str(target_sr),
38 | "-ac", "1",
39 | temp_wav_filepath
40 | ]
41 |
42 | result = subprocess.run(ffmpeg_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
43 | if result.returncode != 0:
44 | os.unlink(temp_wav_filepath)
45 | raise RuntimeError(f"FFmpeg conversion failed: {result.stderr.decode()}")
46 |
47 | return temp_wav_filepath
48 |
49 |
50 | def load_wav(filepath: str, target_sr: int):
51 | """
52 | Load an audio file in any supported format, convert it to WAV, and load as a tensor.
53 |
54 | Args:
55 | filepath (str): Path to the audio file in any format.
56 | target_sr (int): Target sample rate.
57 |
58 | Returns:
59 | Tensor: Loaded audio tensor resampled to the target sample rate.
60 | """
61 | # Check if the file is already in WAV format
62 | if not filepath.lower().endswith(".wav"):
63 | logger.info(f"Converting {filepath} to WAV format")
64 | filepath = convert_to_wav(filepath, target_sr)
65 |
66 | # Load the WAV file
67 | speech, sample_rate = torchaudio.load(filepath)
68 | speech = speech.mean(dim=0, keepdim=True) # Convert to mono if not already
69 | if sample_rate != target_sr:
70 | assert sample_rate > target_sr, f'WAV sample rate {sample_rate} must be greater than {target_sr}'
71 | speech = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sr)(speech)
72 |
73 | return speech
74 |
75 |
76 | def save_wav(wav: np.ndarray, sr: int, filepath: str):
77 | soundfile.write(filepath, wav, sr)
78 |
79 |
80 | def load_prompt_speech_from_file(filepath: str, min_duration: float=3, max_duration: float=5, return_numpy: bool=False):
81 | wav = load_wav(filepath, 16000)
82 |
83 | if wav.abs().max() > 0.9:
84 | wav = wav / wav.abs().max() * 0.9
85 |
86 | wav = get_speech(
87 | audio_input=wav.squeeze(0),
88 | min_duration=min_duration,
89 | max_duration=max_duration,
90 | return_numpy=return_numpy
91 | )
92 | return wav
93 |
94 |
95 | def load_voices(voice_dir: str):
96 | files = glob(os.path.join(voice_dir, '*.wav')) + glob(os.path.join(voice_dir, '*.mp3'))
97 | voice_name_map = {
98 | os.path.basename(f).split('.')[0]: f
99 | for f in files
100 | }
101 | return voice_name_map
102 |
103 |
104 | def download_model(save_dir: str):
105 | snapshot_download(repo_id="dangvansam/viet-tts", local_dir=save_dir)
--------------------------------------------------------------------------------
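A sketch of how these helpers compose (illustrative; `ffmpeg` must be on PATH for non-WAV inputs, and the output filename is an assumption):

from viettts.utils.file_utils import load_wav, save_wav, load_voices

# Load any supported format as mono audio at 16 kHz (non-WAV inputs go through ffmpeg).
speech = load_wav('samples/doremon.mp3', 16000)
print(speech.shape)                         # (1, num_samples)

# Round-trip back to disk.
save_wav(speech.squeeze(0).numpy(), 16000, 'doremon_16k.wav')

# Map voice names to prompt files in a directory.
voices = load_voices('samples')
print(sorted(voices)[:3])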
/viettts/utils/frontend_utils.py:
--------------------------------------------------------------------------------
1 | import re
2 | import torch
3 | import numpy as np
4 | import torch.utils.data
5 | from vinorm import TTSnorm
6 | from librosa.filters import mel as librosa_mel_fn
7 | from scipy.io.wavfile import read
8 |
9 | MAX_WAV_VALUE = 32768.0
10 |
11 |
12 | def remove_urls_and_links(text):
13 | url_pattern = r"http[s]?:\/\/(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\\(\\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+|www\.[a-zA-Z0-9.\/]+"
14 | markdown_image_pattern = r"!\[.*?\]\(http[s]?:\/\/.*?\)"
15 | text = re.sub(markdown_image_pattern, '', text, 0, re.MULTILINE)
16 | text = re.sub(url_pattern, '', text, 0, re.MULTILINE)
17 | return text
18 |
19 |
20 | def remove_emojis(text):
21 | emoji_pattern = re.compile(
22 | "["
23 | "\U0001F600-\U0001F64F" # emoticons
24 | "\U0001F300-\U0001F5FF" # symbols & pictographs
25 | "\U0001F680-\U0001F6FF" # transport & map symbols
26 | "\U0001F1E0-\U0001F1FF" # flags (iOS)
27 | "\U00002702-\U000027B0" # other miscellaneous symbols
28 | "\U000024C2-\U0001F251"
29 | "\U0001F900-\U0001F9FF" # Supplemental Symbols and Pictographs
30 | "\U0001FA70-\U0001FAFF" # Symbols and Pictographs Extended-A
31 | "\U0001F004-\U0001F0CF" # Mahjong and Playing Cards
32 | "]+",
33 | flags=re.UNICODE
34 | )
35 | return emoji_pattern.sub(r'', text)
36 |
37 |
38 | def remove_punc(text):
39 | text = (text
40 | .replace(' ', '')
41 | .replace("..", ".")
42 | .replace("!.", "!")
43 | .replace('!', ".")
44 | .replace("?.", "?")
45 | .replace("?", ".")
46 | .replace(" .", ".")
47 | .replace(" ,", ",")
48 | .replace('"', "")
49 | .replace("'", "")
50 | .replace("AI", "Ây Ai")
51 | .replace("A.I", "Ây Ai")
52 | .replace("$", "")
53 | .replace("(", "")
54 | .replace(")", "")
55 | .replace("**", "")
56 | .replace(" = ", " bằng ")
57 | .replace("#", "")
58 | .replace('\\', '')
59 | .replace('```', '')
60 | .replace('- ', '')
61 | .replace('+ ', '')
62 | .replace(":", "")
63 | .replace(",,", ",")
64 | .replace(", ,", ",")
65 | .replace(",.", ".")
66 | .replace(".,", ".")
67 | .replace("..", ".")
68 | .replace(". .", ".")
69 | )
70 | text = re.sub(r'\n+', ' ', text)
71 | text = ' '.join([t for t in text.split() if t.strip()])
72 | text = text.strip()
73 | return text
74 |
75 |
76 | def normalize_text(text: str) -> str:
77 | text = text.strip()
78 | text = remove_urls_and_links(text)
79 | text = remove_emojis(text)
80 | text = remove_punc(text)
81 | text = TTSnorm(text, lower=False)
82 | return text
83 |
84 |
85 | def split_text(text: str, tokenize, token_max_n=80, token_min_n=60, merge_len=20, comma_split=False):
86 | def calc_utt_length(_text: str):
87 | return len(tokenize(_text))
88 |
89 | def should_merge(_text: str):
90 | return len(tokenize(_text)) < merge_len
91 |
92 | pounc = ['.', '?', '!', ';', ':']
93 | if comma_split:
94 | pounc.extend([',', ','])
95 |
96 | if text[-1] not in pounc:
97 | text += "."
98 |
99 | st = 0
100 | utts = []
101 | for i, c in enumerate(text):
102 | if c in pounc:
103 | if len(text[st: i]) > 0:
104 | utts.append(text[st: i] + c)
105 | if i + 1 < len(text) and text[i + 1] in ['"', '”']:
106 | tmp = utts.pop(-1)
107 | utts.append(tmp + text[i + 1])
108 | st = i + 2
109 | else:
110 | st = i + 1
111 |
112 | final_utts = []
113 | cur_utt = ""
114 | for utt in utts:
115 | if calc_utt_length(cur_utt + utt) > token_max_n and calc_utt_length(cur_utt) > token_min_n:
116 | final_utts.append(cur_utt)
117 | cur_utt = ""
118 | cur_utt = cur_utt + utt
119 | if len(cur_utt) > 0:
120 | if should_merge(cur_utt) and len(final_utts) != 0:
121 | final_utts[-1] = final_utts[-1] + cur_utt
122 | else:
123 | final_utts.append(cur_utt)
124 |
125 | final_utts = [utt.strip() for utt in final_utts]
126 | return final_utts
127 |
128 |
129 | def dynamic_range_compression(x, C=1, clip_val=1e-5):
130 | return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
131 |
132 |
133 | def dynamic_range_decompression(x, C=1):
134 | return np.exp(x) / C
135 |
136 |
137 | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
138 | return torch.log(torch.clamp(x, min=clip_val) * C)
139 |
140 |
141 | def dynamic_range_decompression_torch(x, C=1):
142 | return torch.exp(x) / C
143 |
144 |
145 | def spectral_normalize_torch(magnitudes):
146 | output = dynamic_range_compression_torch(magnitudes)
147 | return output
148 |
149 |
150 | def spectral_de_normalize_torch(magnitudes):
151 | output = dynamic_range_decompression_torch(magnitudes)
152 | return output
153 |
154 |
155 | mel_basis = {}
156 | hann_window = {}
157 |
158 |
159 | def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
160 | if torch.min(y) < -1.0:
161 | print("min value is ", torch.min(y))
162 | if torch.max(y) > 1.0:
163 | print("max value is ", torch.max(y))
164 |
165 | global mel_basis, hann_window # pylint: disable=global-statement
166 | if f"{str(fmax)}_{str(y.device)}" not in mel_basis:
167 | mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
168 | mel_basis[str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device)
169 | hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)
170 |
171 | y = torch.nn.functional.pad(
172 | y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect"
173 | )
174 | y = y.squeeze(1)
175 |
176 | spec = torch.view_as_real(
177 | torch.stft(
178 | y,
179 | n_fft,
180 | hop_length=hop_size,
181 | win_length=win_size,
182 | window=hann_window[str(y.device)],
183 | center=center,
184 | pad_mode="reflect",
185 | normalized=False,
186 | onesided=True,
187 | return_complex=True,
188 | )
189 | )
190 |
191 | spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
192 |
193 | spec = torch.matmul(mel_basis[str(fmax) + "_" + str(y.device)], spec)
194 | spec = spectral_normalize_torch(spec)
195 |
196 | return spec
197 |
198 |
199 | # def tokenize(data, get_tokenizer, allowed_special):
200 | # """ Decode text to chars or BPE
201 | # Inplace operation
202 |
203 | # Args:
204 | # data: Iterable[{key, wav, txt, sample_rate}]
205 |
206 | # Returns:
207 | # Iterable[{key, wav, txt, tokens, label, sample_rate}]
208 | # """
209 | # tokenizer = get_tokenizer()
210 | # for sample in data:
211 | # assert 'text' in sample
212 | # sample['text_token'] = tokenizer.encode(sample['text'], allowed_special=allowed_special)
213 | # sample['tts_text_token'] = tokenizer.encode(sample['tts_text'], allowed_special=allowed_special)
214 |
--------------------------------------------------------------------------------
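A sketch of the text front-end above (illustrative): the whitespace `tokenize` passed to `split_text` is a stand-in for the repository's BPE tokenizer, and `normalize_text` requires the `vinorm` package.

from viettts.utils.frontend_utils import normalize_text, split_text

raw = 'Hôm nay trời đẹp quá!!! Xem thêm tại https://example.com 😀'
clean = normalize_text(raw)     # strips URLs/emojis, normalizes punctuation, then runs TTSnorm
print(clean)

# split_text only needs a callable that returns a token list; a whitespace
# split stands in for the real BPE tokenizer here.
chunks = split_text(clean, tokenize=lambda s: s.split(),
                    token_max_n=20, token_min_n=10, merge_len=5)
print(chunks)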
/viettts/utils/mask.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def subsequent_mask(
5 | size: int,
6 | device: torch.device = torch.device("cpu"),
7 | ) -> torch.Tensor:
8 | """Create mask for subsequent steps (size, size).
9 |
10 | This mask is used only in decoder which works in an auto-regressive mode.
11 | This means the current step could only do attention with its left steps.
12 |
13 |     In encoder, full attention is used when streaming is not necessary and
14 | the sequence is not long. In this case, no attention mask is needed.
15 |
16 |     When streaming is needed, chunk-based attention is used in the encoder. See
17 | subsequent_chunk_mask for the chunk-based attention mask.
18 |
19 | Args:
20 | size (int): size of mask
21 |         device (torch.device): device on which the mask is created;
22 |             "cpu", "cuda" or an existing torch.Tensor.device
23 |
24 | Returns:
25 | torch.Tensor: mask
26 |
27 | Examples:
28 | >>> subsequent_mask(3)
29 | [[1, 0, 0],
30 | [1, 1, 0],
31 | [1, 1, 1]]
32 | """
33 | arange = torch.arange(size, device=device)
34 | mask = arange.expand(size, size)
35 | arange = arange.unsqueeze(-1)
36 | mask = mask <= arange
37 | return mask
38 |
39 |
40 | def subsequent_chunk_mask(
41 | size: int,
42 | chunk_size: int,
43 | num_left_chunks: int = -1,
44 | device: torch.device = torch.device("cpu"),
45 | ) -> torch.Tensor:
46 | """Create mask for subsequent steps (size, size) with chunk size,
47 | this is for streaming encoder
48 |
49 | Args:
50 | size (int): size of mask
51 | chunk_size (int): size of chunk
52 | num_left_chunks (int): number of left chunks
53 | <0: use full chunk
54 | >=0: use num_left_chunks
55 | device (torch.device): "cpu" or "cuda" or torch.Tensor.device
56 |
57 | Returns:
58 | torch.Tensor: mask
59 |
60 | Examples:
61 | >>> subsequent_chunk_mask(4, 2)
62 | [[1, 1, 0, 0],
63 | [1, 1, 0, 0],
64 | [1, 1, 1, 1],
65 | [1, 1, 1, 1]]
66 | """
67 | ret = torch.zeros(size, size, device=device, dtype=torch.bool)
68 | for i in range(size):
69 | if num_left_chunks < 0:
70 | start = 0
71 | else:
72 | start = max((i // chunk_size - num_left_chunks) * chunk_size, 0)
73 | ending = min((i // chunk_size + 1) * chunk_size, size)
74 | ret[i, start:ending] = True
75 | return ret
76 |
77 |
78 | def add_optional_chunk_mask(xs: torch.Tensor,
79 | masks: torch.Tensor,
80 | use_dynamic_chunk: bool,
81 | use_dynamic_left_chunk: bool,
82 | decoding_chunk_size: int,
83 | static_chunk_size: int,
84 | num_decoding_left_chunks: int,
85 | enable_full_context: bool = True):
86 | """ Apply optional mask for encoder.
87 |
88 | Args:
89 | xs (torch.Tensor): padded input, (B, L, D), L for max length
90 | mask (torch.Tensor): mask for xs, (B, 1, L)
91 | use_dynamic_chunk (bool): whether to use dynamic chunk or not
92 | use_dynamic_left_chunk (bool): whether to use dynamic left chunk for
93 | training.
94 | decoding_chunk_size (int): decoding chunk size for dynamic chunk, it's
95 | 0: default for training, use random dynamic chunk.
96 | <0: for decoding, use full chunk.
97 | >0: for decoding, use fixed chunk size as set.
98 | static_chunk_size (int): chunk size for static chunk training/decoding
99 | if it's greater than 0, if use_dynamic_chunk is true,
100 | this parameter will be ignored
101 | num_decoding_left_chunks: number of left chunks, this is for decoding,
102 | the chunk size is decoding_chunk_size.
103 | >=0: use num_decoding_left_chunks
104 | <0: use all left chunks
105 | enable_full_context (bool):
106 | True: chunk size is either [1, 25] or full context(max_len)
107 | False: chunk size ~ U[1, 25]
108 |
109 | Returns:
110 | torch.Tensor: chunk mask of the input xs.
111 | """
112 | # Whether to use chunk mask or not
113 | if use_dynamic_chunk:
114 | max_len = xs.size(1)
115 | if decoding_chunk_size < 0:
116 | chunk_size = max_len
117 | num_left_chunks = -1
118 | elif decoding_chunk_size > 0:
119 | chunk_size = decoding_chunk_size
120 | num_left_chunks = num_decoding_left_chunks
121 | else:
122 | # chunk size is either [1, 25] or full context(max_len).
123 | # Since we use 4 times subsampling and allow up to 1s(100 frames)
124 | # delay, the maximum frame is 100 / 4 = 25.
125 | chunk_size = torch.randint(1, max_len, (1, )).item()
126 | num_left_chunks = -1
127 | if chunk_size > max_len // 2 and enable_full_context:
128 | chunk_size = max_len
129 | else:
130 | chunk_size = chunk_size % 25 + 1
131 | if use_dynamic_left_chunk:
132 | max_left_chunks = (max_len - 1) // chunk_size
133 | num_left_chunks = torch.randint(0, max_left_chunks,
134 | (1, )).item()
135 | chunk_masks = subsequent_chunk_mask(xs.size(1), chunk_size,
136 | num_left_chunks,
137 | xs.device) # (L, L)
138 | chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L)
139 | chunk_masks = masks & chunk_masks # (B, L, L)
140 | elif static_chunk_size > 0:
141 | num_left_chunks = num_decoding_left_chunks
142 | chunk_masks = subsequent_chunk_mask(xs.size(1), static_chunk_size,
143 | num_left_chunks,
144 | xs.device) # (L, L)
145 | chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L)
146 | chunk_masks = masks & chunk_masks # (B, L, L)
147 | else:
148 | chunk_masks = masks
149 | return chunk_masks
150 |
151 |
152 | def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
153 | """Make mask tensor containing indices of padded part.
154 |
155 | See description of make_non_pad_mask.
156 |
157 | Args:
158 | lengths (torch.Tensor): Batch of lengths (B,).
159 | Returns:
160 | torch.Tensor: Mask tensor containing indices of padded part.
161 |
162 | Examples:
163 | >>> lengths = [5, 3, 2]
164 | >>> make_pad_mask(lengths)
165 |         masks = [[0, 0, 0, 0, 0],
166 | [0, 0, 0, 1, 1],
167 | [0, 0, 1, 1, 1]]
168 | """
169 | batch_size = lengths.size(0)
170 | max_len = max_len if max_len > 0 else lengths.max().item()
171 | seq_range = torch.arange(0,
172 | max_len,
173 | dtype=torch.int64,
174 | device=lengths.device)
175 | seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
176 | seq_length_expand = lengths.unsqueeze(-1)
177 | mask = seq_range_expand >= seq_length_expand
178 | return mask
179 |
--------------------------------------------------------------------------------
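The mask helpers can be sanity-checked directly against their docstring examples; a short illustrative snippet:

import torch
from viettts.utils.mask import subsequent_mask, subsequent_chunk_mask, make_pad_mask

print(subsequent_mask(3).int())
# tensor([[1, 0, 0],
#         [1, 1, 0],
#         [1, 1, 1]])

print(subsequent_chunk_mask(4, 2).int())
# tensor([[1, 1, 0, 0],
#         [1, 1, 0, 0],
#         [1, 1, 1, 1],
#         [1, 1, 1, 1]])

print(make_pad_mask(torch.tensor([5, 3, 2])).int())
# tensor([[0, 0, 0, 0, 0],
#         [0, 0, 0, 1, 1],
#         [0, 0, 1, 1, 1]])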
/viettts/utils/vad.py:
--------------------------------------------------------------------------------
1 | from typing import Union
2 | import numpy as np
3 | import torch
4 | from silero_vad import load_silero_vad, read_audio, get_speech_timestamps
5 |
6 | model = load_silero_vad()
7 |
8 | def get_speech(
9 | audio_input: Union[str, np.ndarray, torch.Tensor],
10 | return_numpy: bool=False,
11 | min_duration: float=3,
12 | max_duration: float=5
13 | ) -> Union[torch.Tensor, np.ndarray]:
14 |
15 | if isinstance(audio_input, str):
16 | audio_input = read_audio(audio_input)
17 | speech_timestamps = get_speech_timestamps(audio_input, model)
18 | speech = [audio_input[t['start']:t['end']] \
19 | for t in speech_timestamps \
20 | if (t['end'] - t['start']) >= 16000 * min_duration \
21 | and (t['end'] - t['start']) <= 16000 * max_duration]
22 | if not speech:
23 | speech = audio_input[:int(max_duration*16000)].unsqueeze(0)
24 | else:
25 | speech = speech[0].unsqueeze(0)
26 | if return_numpy:
27 | speech = speech.cpu().numpy()
28 | return speech
29 |
30 |
31 | if __name__ == '__main__':
32 | print(get_speech('samples/diep-chi.wav'))
--------------------------------------------------------------------------------
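Besides the file-path usage in `__main__`, `get_speech` also accepts an in-memory 16 kHz waveform. A small sketch; the quiet random-noise input is only meant to exercise the fallback branch, since the VAD will most likely detect no speech in it:

import torch
from viettts.utils.vad import get_speech

wav = torch.randn(16000 * 8) * 0.01     # 8 s of quiet noise at 16 kHz
clip = get_speech(wav, min_duration=3, max_duration=5, return_numpy=True)
print(clip.shape)                       # second dim is at most 5 s * 16000 = 80000 samples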
/web/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dangvansam/viet-tts/c29b7535d4a81aefbbef175d0db26ac9df4bc028/web/.gitkeep
--------------------------------------------------------------------------------