├── LICENSE
├── README-ZH.md
├── README.md
├── customs
└── ph.txt
├── data
├── __init__.py
├── collation.py
├── datamodule.py
├── dataset.py
├── fbank.py
├── input_strategies.py
└── tokenizer.py
├── descriptions.py
├── examples.py
├── images
└── vallex_framework.jpg
├── launch-ui.py
├── macros.py
├── model-card.md
├── models
├── __init__.py
├── macros.py
├── transformer.py
├── vallex.py
└── visualizer.py
├── modules
├── __init__.py
├── activation.py
├── embedding.py
├── optim.py
├── scaling.py
├── scheduler.py
└── transformer.py
├── nltk_data
└── tokenizers
│ └── punkt
│ ├── .DS_Store
│ ├── PY3
│ ├── README
│ ├── czech.pickle
│ ├── danish.pickle
│ ├── dutch.pickle
│ ├── english.pickle
│ ├── estonian.pickle
│ ├── finnish.pickle
│ ├── french.pickle
│ ├── german.pickle
│ ├── greek.pickle
│ ├── italian.pickle
│ ├── malayalam.pickle
│ ├── norwegian.pickle
│ ├── polish.pickle
│ ├── portuguese.pickle
│ ├── russian.pickle
│ ├── slovene.pickle
│ ├── spanish.pickle
│ ├── swedish.pickle
│ └── turkish.pickle
│ ├── README
│ ├── czech.pickle
│ ├── danish.pickle
│ ├── dutch.pickle
│ ├── english.pickle
│ ├── estonian.pickle
│ ├── finnish.pickle
│ ├── french.pickle
│ ├── german.pickle
│ ├── greek.pickle
│ ├── italian.pickle
│ ├── malayalam.pickle
│ ├── norwegian.pickle
│ ├── polish.pickle
│ ├── portuguese.pickle
│ ├── russian.pickle
│ ├── slovene.pickle
│ ├── spanish.pickle
│ ├── swedish.pickle
│ └── turkish.pickle
├── presets
├── acou_1.npz
├── acou_2.npz
├── acou_3.npz
├── acou_4.npz
├── alan.npz
├── amused.npz
├── anger.npz
├── babara.npz
├── bronya.npz
├── cafe.npz
├── dingzhen.npz
├── disgust.npz
├── emo_amused.npz
├── emo_anger.npz
├── emo_neutral.npz
├── emo_sleepy.npz
├── emotion_sleepiness.npz
├── en2zh_tts_1.npz
├── en2zh_tts_2.npz
├── en2zh_tts_3.npz
├── en2zh_tts_4.npz
├── esta.npz
├── fuxuan.npz
├── librispeech_1.npz
├── librispeech_2.npz
├── librispeech_3.npz
├── librispeech_4.npz
├── neutral.npz
├── paimon.npz
├── rosalia.npz
├── seel.npz
├── sleepiness.npz
├── vctk_1.npz
├── vctk_2.npz
├── vctk_3.npz
├── vctk_4.npz
├── yaesakura.npz
├── zh2en_tts_1.npz
├── zh2en_tts_2.npz
├── zh2en_tts_3.npz
└── zh2en_tts_4.npz
├── prompts
├── en-1.wav
├── en-2.wav
├── ja-1.wav
├── ja-2.ogg
├── ph.txt
├── zh-1.wav
└── zh-2.wav
├── requirements.txt
└── utils
├── __init__.py
├── download.py
├── g2p
├── __init__.py
├── bpe_1024.json
├── bpe_69.json
├── cleaners.py
├── english.py
├── japanese.py
├── mandarin.py
└── symbols.py
├── generation.py
├── prompt_making.py
├── sentence_cutter.py
└── symbol_table.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Songting
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README-ZH.md:
--------------------------------------------------------------------------------
1 | # VALL-E X: 多语言文本到语音合成与语音克隆 🔊
2 | [](https://discord.gg/qCBRmAnTxg)
3 |
4 | [English](README.md) | 中文
5 |
6 | 微软[VALL-E X](https://arxiv.org/pdf/2303.03926) 零样本语音合成模型的开源实现.
7 | **预训练模型现已向公众开放,供研究或应用使用。**
8 | 
9 |
10 | VALL-E X 是一个强大而创新的多语言文本转语音(TTS)模型,最初由微软发布。虽然微软最初在他们的研究论文中提出了该概念,但并未发布任何代码或预训练模型。我们认识到了这项技术的潜力和价值,复现并训练了一个开源可用的VALL-E X模型。我们很乐意与社区分享我们的预训练模型,让每个人都能体验到次世代TTS的威力。 🎧
11 |
12 | 更多细节请查看 [model card](./model-card.md).
13 |
14 | ## 📖 目录
15 | * [🚀 更新日志](#-更新日志)
16 | * [📢 功能特点](#-功能特点)
17 | * [💻 本地安装](#-本地安装)
18 | * [🎧 在线Demo](#-在线Demo)
19 | * [🐍 使用方法](#-Python中的使用方法)
20 | * [❓ FAQ](#-faq)
21 | * [🧠 TODO](#-todo)
22 |
23 | ## 🚀 Updates
24 | **2023.09.10**
25 | - 支持AR decoder的batch decoding以实现更稳定的生成结果
26 |
27 | **2023.08.30**
28 | - 将EnCodec解码器替换成了Vocos解码器,提升了音质。 (感谢[@v0xie](https://github.com/v0xie))
29 |
30 | **2023.08.23**
31 | - 加入了长文本生成功能
32 |
33 | **2023.08.20**
34 | - 加入了中文版README
35 |
36 | **2023.08.14**
37 | - 预训练模型权重已发布,从[这里](https://drive.google.com/file/d/10gdQWvP-K_e1undkvv0p2b7SU6I4Egyl/view?usp=sharing)下载。
38 |
39 | ## 💻 本地安装
40 | ### 使用pip安装,必须使用Python 3.10,CUDA 11.7 ~ 12.0,PyTorch 2.0+
41 | ```commandline
42 | git clone https://github.com/Plachtaa/VALL-E-X.git
43 | cd VALL-E-X
44 | pip install -r requirements.txt
45 | ```
46 |
47 | > 注意:如果需要制作prompt,需要安装 ffmpeg 并将其所在文件夹加入到环境变量PATH中
48 |
49 | 第一次运行程序时,会自动下载相应的模型。如果下载失败并报错,请按照以下步骤手动下载模型。
50 |
51 | (请注意目录和文件夹的大小写)
52 |
53 | 1.检查安装目录下是否存在`checkpoints`文件夹,如果没有,在安装目录下手动创建`checkpoints`文件夹(`./checkpoints/`)。
54 |
55 | 2.检查`checkpoints`文件夹中是否有`vallex-checkpoint.pt`文件。如果没有,请从[这里](https://huggingface.co/Plachta/VALL-E-X/resolve/main/vallex-checkpoint.pt)
56 | 手动下载`vallex-checkpoint.pt`文件并放到`checkpoints`文件夹里。
57 |
58 | 3.检查安装目录下是否存在`whisper`文件夹,如果没有,在安装目录下手动创建`whisper`文件夹(`./whisper/`)。
59 |
60 | 4.检查`whisper`文件夹中是否有`medium.pt`文件。如果没有,请从[这里](https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt)
61 | 手动下载`medium.pt`文件并放到`whisper`文件夹里。
62 |
63 | ## 🎧 在线Demo
64 | 如果你不想在本地安装,你可以在线体验VALL-E X的功能,点击下面的任意一个链接即可开始体验。
65 |
66 | [](https://huggingface.co/spaces/Plachta/VALL-E-X)
67 | [](https://colab.research.google.com/drive/1yyD_sz531QntLKowMHo-XxorsFBCfKul?usp=sharing)
68 |
69 |
70 | ## 📢 功能特点
71 |
72 | VALL-E X 配备有一系列尖端功能:
73 |
74 | 1. **多语言 TTS**: 可使用三种语言 - 英语、中文和日语 - 进行自然、富有表现力的语音合成。
75 |
76 | 2. **零样本语音克隆**: 仅需录制任意说话人的短短的 3~10 秒录音,VALL-E X 就能生成个性化、高质量的语音,完美还原他们的声音。
77 |
78 |
79 | 查看示例
80 |
81 | [prompt.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/a7baa51d-a53a-41cc-a03d-6970f25fcca7)
82 |
83 |
84 | [output.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/b895601a-d126-4138-beff-061aabdc7985)
85 |
86 |
87 |
88 | 3. **语音情感控制**: VALL-E X 可以合成与给定说话人录音相同情感的语音,为音频增添更多表现力。
89 |
90 |
91 | 查看示例
92 |
93 | https://github.com/Plachtaa/VALL-E-X/assets/112609742/56fa9988-925e-4757-82c5-83ecb0df6266
94 |
95 |
96 | https://github.com/Plachtaa/VALL-E-X/assets/112609742/699c47a3-d502-4801-8364-bd89bcc0b8f1
97 |
98 |
99 |
100 | 4. **零样本跨语言语音合成**: VALL-E X 可以合成与给定说话人母语不同的另一种语言,在不影响口音和流利度的同时,保留该说话人的音色与情感。以下是一个使用日语母语者进行英文与中文合成的样例: 🇯🇵 🗣
101 |
102 |
103 | 查看示例
104 |
105 | [jp-prompt.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/ea6e2ee4-139a-41b4-837e-0bd04dda6e19)
106 |
107 |
108 | [en-output.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/db8f9782-923f-425e-ba94-e8c1bd48f207)
109 |
110 |
111 | [zh-output.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/15829d79-e448-44d3-8965-fafa7a3f8c28)
112 |
113 |
114 |
115 | 5. **口音控制**: VALL-E X 允许您控制所合成音频的口音,比如说中文带英语口音或反之。 🇨🇳 💬
116 |
117 |
118 | 查看示例
119 |
120 | [en-prompt.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/f688d7f6-70ef-46ec-b1cc-355c31e78b3b)
121 |
122 |
123 | [zh-accent-output.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/be59c7ca-b45b-44ca-a30d-4d800c950ccc)
124 |
125 |
126 | [en-accent-output.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/8b4f4f9b-f299-4ea4-a548-137437b71738)
127 |
128 |
129 |
130 | 6. **声学环境保留**: 当给定说话人的录音在不同的声学环境下录制时,VALL-E X 可以保留该声学环境,使合成语音听起来更加自然。
131 |
132 |
133 | 查看示例
134 |
135 | [noise-prompt.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/68986d88-abd0-4d1d-96e4-4f893eb9259e)
136 |
137 |
138 | [noise-output.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/96c4c612-4516-4683-8804-501b70938608)
139 |
140 |
141 |
142 |
143 | 你可以访问我们的[demo页面](https://plachtaa.github.io/) 来浏览更多示例!
144 |
145 | ## 💻 Python中的使用方法
146 |
147 |
148 | 🪑 基本使用
149 |
150 | ```python
151 | from utils.generation import SAMPLE_RATE, generate_audio, preload_models
152 | from scipy.io.wavfile import write as write_wav
153 | from IPython.display import Audio
154 |
155 | # download and load all models
156 | preload_models()
157 |
158 | # generate audio from text
159 | text_prompt = """
160 | Hello, my name is Nose. And uh, and I like hamburger. Hahaha... But I also have other interests such as playing tactic toast.
161 | """
162 | audio_array = generate_audio(text_prompt)
163 |
164 | # save audio to disk
165 | write_wav("vallex_generation.wav", SAMPLE_RATE, audio_array)
166 |
167 | # play text in notebook
168 | Audio(audio_array, rate=SAMPLE_RATE)
169 | ```
170 |
171 | [hamburger.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/578d7bbe-cda9-483e-898c-29646edc8f2e)
172 |
173 |
174 |
175 |
176 | 🌎 多语言
177 |
178 | 该VALL-E X实现支持三种语言:英语、中文和日语。您可以通过设置`language`参数来指定语言。默认情况下,该模型将自动检测语言。
179 |
180 |
181 | ```python
182 |
183 | text_prompt = """
184 | チュソクは私のお気に入りの祭りです。 私は数日間休んで、友人や家族との時間を過ごすことができます。
185 | """
186 | audio_array = generate_audio(text_prompt)
187 | ```
188 |
189 | [vallex_japanese.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/ee57a688-3e83-4be5-b0fe-019d16eec51c)
190 |
191 | *注意:即使在一句话中混合多种语言的情况下,VALL-E X也能完美地控制口音,但是您需要手动标记各个句子对应的语言以便于我们的G2P工具识别它们。*
192 | ```python
193 | text_prompt = """
194 | [EN]The Thirty Years' War was a devastating conflict that had a profound impact on Europe.[EN]
195 | [ZH]这是历史的开始。 如果您想听更多,请继续。[ZH]
196 | """
197 | audio_array = generate_audio(text_prompt, language='mix')
198 | ```
199 |
200 | [vallex_codeswitch.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/d8667abf-bd08-499f-a383-a861d852f98a)
201 |
202 |
203 |
204 |
205 | 📼 预设音色
206 |
207 | 我们提供十几种说话人音色可直接VALL-E X使用! 在[这里](/presets)浏览所有可用音色。
208 |
209 | > VALL-E X 尝试匹配给定预设音色的音调、音高、情感和韵律。该模型还尝试保留音乐、环境噪声等。
210 | ```python
211 | text_prompt = """
212 | I am an innocent boy with a smoky voice. It is a great honor for me to speak at the United Nations today.
213 | """
214 | audio_array = generate_audio(text_prompt, prompt="dingzhen")
215 | ```
216 |
217 | [smoky.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/d3f55732-b1cd-420f-87d6-eab60db14dc5)
218 |
219 |
220 |
221 |
222 | 🎙声音克隆
223 |
224 | VALL-E X 支持声音克隆!你可以使用任何人,角色,甚至是你自己的声音,来制作一个音频提示。在你使用该音频提示时,VALL-E X 将会使用与其相似的声音来合成文本。
225 |
226 | 你需要提供一段3~10秒长的语音,以及该语音对应的文本,来制作音频提示。你也可以将文本留空,让[Whisper](https://github.com/openai/whisper)模型为你生成文本。
227 | > VALL-E X 尝试匹配给定音频提示的音调、音高、情感和韵律。该模型还尝试保留音乐、环境噪声等。
228 |
229 | ```python
230 | from utils.prompt_making import make_prompt
231 |
232 | ### Use given transcript
233 | make_prompt(name="paimon", audio_prompt_path="paimon_prompt.wav",
234 | transcript="Just, what was that? Paimon thought we were gonna get eaten.")
235 |
236 | ### Alternatively, use whisper
237 | make_prompt(name="paimon", audio_prompt_path="paimon_prompt.wav")
238 | ```
239 | 来尝试一下刚刚做好的音频提示吧!
240 | ```python
241 | from utils.generation import SAMPLE_RATE, generate_audio, preload_models
242 | from scipy.io.wavfile import write as write_wav
243 |
244 | # download and load all models
245 | preload_models()
246 |
247 | text_prompt = """
248 | Hey, Traveler, Listen to this, This machine has taken my voice, and now it can talk just like me!
249 | """
250 | audio_array = generate_audio(text_prompt, prompt="paimon")
251 |
252 | write_wav("paimon_cloned.wav", SAMPLE_RATE, audio_array)
253 |
254 | ```
255 |
256 | [paimon_prompt.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/e7922859-9d12-4e2a-8651-e156e4280311)
257 |
258 |
259 | [paimon_cloned.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/60d3b7e9-5ead-4024-b499-a897ce5f3d5e)
260 |
261 |
262 |
263 |
264 |
265 |
266 | 🎢用户界面
267 |
268 | 如果你不擅长代码,我们还为VALL-E X创建了一个用户友好的图形界面。它可以让您轻松地与模型进行交互,使语音克隆和多语言语音合成变得轻而易举。
269 |
270 | 使用以下命令启动用户界面:
271 | ```commandline
272 | python -X utf8 launch-ui.py
273 | ```
274 |
275 |
276 | ## 🛠️ 硬件要求及推理速度
277 |
278 | VALL-E X 可以在CPU或GPU上运行 (`pytorch 2.0+`, CUDA 11.7 ~ CUDA 12.0).
279 |
280 | 若使用GPU运行,你需要至少6GB的显存。
281 |
282 | ## ⚙️ Details
283 |
284 | VALL-E X 与 [Bark](https://github.com/suno-ai/bark), [VALL-E](https://arxiv.org/abs/2301.02111) and [AudioLM](https://arxiv.org/abs/2209.03143)类似, 使用GPT风格的模型以自回归方式预测量化音频token,并由[EnCodec](https://github.com/facebookresearch/encodec)解码.
285 |
286 | 与 [Bark](https://github.com/suno-ai/bark) 相比:
287 | - ✔ **轻量**: 3️⃣ ✖ 更小,
288 | - ✔ **快速**: 4️⃣ ✖ 更快,
289 | - ✔ **中文&日文的更高质量**
290 | - ✔ **跨语言合成时没有外国口音**
291 | - ✔ **开放且易于操作的声音克隆**
292 | - ❌ **支持的语言较少**
293 | - ❌ **没有用于合成音乐及特殊音效的token**
294 |
295 | ### 支持的语言
296 |
297 | | 语言 | 状态 |
298 | |---------| :---: |
299 | | 英语 (en) | ✅ |
300 | | 日语 (ja) | ✅ |
301 | | 中文 (zh) | ✅ |
302 |
303 | ## ❓ FAQ
304 |
305 | #### 在哪里可以下载checkpoint?
306 | * 当您第一次运行程序时,我们使用`wget`将模型下载到`./checkpoints/`目录里。
307 | * 如果第一次运行时下载失败,请从[这里](https://huggingface.co/Plachta/VALL-E-X/resolve/main/vallex-checkpoint.pt)手动下载模型,并将文件放在`./checkpoints/`里。
308 |
309 | #### 需要多少显存?
310 | * 6GB 显存(GPU VRAM) - 几乎所有NVIDIA GPU都满足要求.
311 |
312 | #### 为什么模型无法生成长文本?
313 | * 当序列长度增加时,Transformer的计算复杂度呈二次方增长。因此,所有训练音频都保持在22秒以下。请确保音频提示(audio prompt)和生成的音频的总长度小于22秒以确保可接受的性能。
314 |
315 | #### 更多...
316 |
317 | ## 🧠 待办事项
318 | - [x] 添加中文 README
319 | - [x] 长文本生成
320 | - [x] 用Vocos解码器替换Encodec解码器
321 | - [ ] 微调以实现更好的语音自适应
322 | - [ ] 给非python用户的`.bat`脚本
323 | - [ ] 更多...
324 |
325 | ## 🙏 感谢
326 | - [VALL-E X paper](https://arxiv.org/pdf/2303.03926) for the brilliant idea
327 | - [lifeiteng's vall-e](https://github.com/lifeiteng/vall-e) for related training code
328 | - [bark](https://github.com/suno-ai/bark) for the amazing pioneering work in neuro-codec TTS model
329 |
330 | ## ⭐️ 表示出你的支持
331 |
332 | 如果您觉得VALL-E X有趣且有用,请在GitHub上给我们一颗星! ⭐️ 它鼓励我们不断改进模型并添加令人兴奋的功能。
333 |
334 | ## 📜 License
335 |
336 | VALL-E X 使用 [MIT License](./LICENSE).
337 |
338 | ---
339 |
340 | 有问题或需要帮助? 可以随便 [open an issue](https://github.com/Plachtaa/VALL-E-X/issues/new) 或加入我们的 [Discord](https://discord.gg/qCBRmAnTxg)
341 |
342 | Happy voice cloning! 🎤
343 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # VALL-E X: Multilingual Text-to-Speech Synthesis and Voice Cloning 🔊
2 | [](https://discord.gg/qCBRmAnTxg)
3 |
4 | English | [中文](README-ZH.md)
5 |
6 | An open source implementation of Microsoft's [VALL-E X](https://arxiv.org/pdf/2303.03926) zero-shot TTS model.
7 | **We release our trained model to the public for research or application usage.**
8 |
9 | 
10 |
11 | VALL-E X is an amazing multilingual text-to-speech (TTS) model proposed by Microsoft. While Microsoft initially publish in their research paper, they did not release any code or pretrained models. Recognizing the potential and value of this technology, our team took on the challenge to reproduce the results and train our own model. We are glad to share our trained VALL-E X model with the community, allowing everyone to experience the power next-generation TTS! 🎧
12 |
13 |
14 | More details about the model are presented in [model card](./model-card.md).
15 |
16 | ## 📖 Quick Index
17 | * [🚀 Updates](#-updates)
18 | * [📢 Features](#-features)
19 | * [💻 Installation](#-installation)
20 | * [🎧 Demos](#-demos)
21 | * [🐍 Usage](#-usage-in-python)
22 | * [❓ FAQ](#-faq)
23 | * [🧠 TODO](#-todo)
24 |
25 | ## 🚀 Updates
26 | **2023.09.10**
27 | - Added AR decoder batch decoding for more stable generation result.
28 |
29 | **2023.08.30**
30 | - Replaced EnCodec decoder with Vocos decoder, improved audio quality. (Thanks to [@v0xie](https://github.com/v0xie))
31 |
32 | **2023.08.23**
33 | - Added long text generation.
34 |
35 | **2023.08.20**
36 | - Added [Chinese README](README-ZH.md).
37 |
38 | **2023.08.14**
39 | - Pretrained VALL-E X checkpoint is now released. Download it [here](https://drive.google.com/file/d/10gdQWvP-K_e1undkvv0p2b7SU6I4Egyl/view?usp=sharing)
40 |
41 | ## 💻 Installation
42 | ### Install with pip, Python 3.10, CUDA 11.7 ~ 12.0, PyTorch 2.0+
43 | ```commandline
44 | git clone https://github.com/Plachtaa/VALL-E-X.git
45 | cd VALL-E-X
46 | pip install -r requirements.txt
47 | ```
48 |
49 | > Note: If you want to make prompt, you need to install ffmpeg and add its folder to the environment variable PATH.
50 |
51 | When you run the program for the first time, it will automatically download the corresponding model.
52 |
53 | If the download fails and reports an error, please follow the steps below to manually download the model.
54 |
55 | (Please pay attention to the capitalization of folders)
56 |
57 | 1. Check whether there is a `checkpoints` folder in the installation directory.
58 | If not, manually create a `checkpoints` folder (`./checkpoints/`) in the installation directory.
59 |
60 | 2. Check whether there is a `vallex-checkpoint.pt` file in the `checkpoints` folder.
61 | If not, please manually download the `vallex-checkpoint.pt` file from [here](https://huggingface.co/Plachta/VALL-E-X/resolve/main/vallex-checkpoint.pt) and put it in the `checkpoints` folder.
62 |
63 | 3. Check whether there is a `whisper` folder in the installation directory.
64 | If not, manually create a `whisper` folder (`./whisper/`) in the installation directory.
65 |
66 | 4. Check whether there is a `medium.pt` file in the `whisper` folder.
67 | If not, please manually download the `medium.pt` file from [here](https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt) and put it in the `whisper` folder.
68 |
69 | ## 🎧 Demos
70 | Not ready to set up the environment on your local machine just yet? No problem! We've got you covered with our online demos. You can try out VALL-E X directly on Hugging Face or Google Colab, experiencing the model's capabilities hassle-free!
71 |
72 | [](https://huggingface.co/spaces/Plachta/VALL-E-X)
73 | [](https://colab.research.google.com/drive/1yyD_sz531QntLKowMHo-XxorsFBCfKul?usp=sharing)
74 |
75 |
76 | ## 📢 Features
77 |
78 | VALL-E X comes packed with cutting-edge functionalities:
79 |
80 | 1. **Multilingual TTS**: Speak in three languages - English, Chinese, and Japanese - with natural and expressive speech synthesis.
81 |
82 | 2. **Zero-shot Voice Cloning**: Enroll a short 3~10 seconds recording of an unseen speaker, and watch VALL-E X create personalized, high-quality speech that sounds just like them!
83 |
84 |
85 | see example
86 |
87 | [prompt.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/a7baa51d-a53a-41cc-a03d-6970f25fcca7)
88 |
89 |
90 | [output.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/b895601a-d126-4138-beff-061aabdc7985)
91 |
92 |
93 |
94 | 3. **Speech Emotion Control**: Experience the power of emotions! VALL-E X can synthesize speech with the same emotion as the acoustic prompt provided, adding an extra layer of expressiveness to your audio.
95 |
96 |
97 | see example
98 |
99 | https://github.com/Plachtaa/VALL-E-X/assets/112609742/56fa9988-925e-4757-82c5-83ecb0df6266
100 |
101 |
102 | https://github.com/Plachtaa/VALL-E-X/assets/112609742/699c47a3-d502-4801-8364-bd89bcc0b8f1
103 |
104 |
105 |
106 | 4. **Zero-shot Cross-Lingual Speech Synthesis**: Take monolingual speakers on a linguistic journey! VALL-E X can produce personalized speech in another language without compromising on fluency or accent. Below is a Japanese speaker talk in Chinese & English. 🇯🇵 🗣
107 |
108 |
109 | see example
110 |
111 | [jp-prompt.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/ea6e2ee4-139a-41b4-837e-0bd04dda6e19)
112 |
113 |
114 | [en-output.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/db8f9782-923f-425e-ba94-e8c1bd48f207)
115 |
116 |
117 | [zh-output.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/15829d79-e448-44d3-8965-fafa7a3f8c28)
118 |
119 |
120 |
121 | 5. **Accent Control**: Get creative with accents! VALL-E X allows you to experiment with different accents, like speaking Chinese with an English accent or vice versa. 🇨🇳 💬
122 |
123 |
124 | see example
125 |
126 | [en-prompt.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/f688d7f6-70ef-46ec-b1cc-355c31e78b3b)
127 |
128 |
129 | [zh-accent-output.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/be59c7ca-b45b-44ca-a30d-4d800c950ccc)
130 |
131 |
132 | [en-accent-output.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/8b4f4f9b-f299-4ea4-a548-137437b71738)
133 |
134 |
135 |
136 | 6. **Acoustic Environment Maintenance**: No need for perfectly clean audio prompts! VALL-E X adapts to the acoustic environment of the input, making speech generation feel natural and immersive.
137 |
138 |
139 | see example
140 |
141 | [noise-prompt.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/68986d88-abd0-4d1d-96e4-4f893eb9259e)
142 |
143 |
144 | [noise-output.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/96c4c612-4516-4683-8804-501b70938608)
145 |
146 |
147 |
148 |
149 | Explore our [demo page](https://plachtaa.github.io/) for a lot more examples!
150 |
151 | ## 🐍 Usage in Python
152 |
153 |
154 | 🪑 Basics
155 |
156 | ```python
157 | from utils.generation import SAMPLE_RATE, generate_audio, preload_models
158 | from scipy.io.wavfile import write as write_wav
159 | from IPython.display import Audio
160 |
161 | # download and load all models
162 | preload_models()
163 |
164 | # generate audio from text
165 | text_prompt = """
166 | Hello, my name is Nose. And uh, and I like hamburger. Hahaha... But I also have other interests such as playing tactic toast.
167 | """
168 | audio_array = generate_audio(text_prompt)
169 |
170 | # save audio to disk
171 | write_wav("vallex_generation.wav", SAMPLE_RATE, audio_array)
172 |
173 | # play text in notebook
174 | Audio(audio_array, rate=SAMPLE_RATE)
175 | ```
176 |
177 | [hamburger.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/578d7bbe-cda9-483e-898c-29646edc8f2e)
178 |
179 |
180 |
181 |
182 | 🌎 Foreign Language
183 |
184 | This VALL-E X implementation also supports Chinese and Japanese. All three languages have equally awesome performance!
185 |
186 |
187 | ```python
188 |
189 | text_prompt = """
190 | チュソクは私のお気に入りの祭りです。 私は数日間休んで、友人や家族との時間を過ごすことができます。
191 | """
192 | audio_array = generate_audio(text_prompt)
193 | ```
194 |
195 | [vallex_japanese.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/ee57a688-3e83-4be5-b0fe-019d16eec51c)
196 |
197 | *Note: VALL-E X controls accent perfectly even when synthesizing code-switch text. However, you need to manually denote language of respective sentences (since our g2p tool is rule-base)*
198 | ```python
199 | text_prompt = """
200 | [EN]The Thirty Years' War was a devastating conflict that had a profound impact on Europe.[EN]
201 | [ZH]这是历史的开始。 如果您想听更多,请继续。[ZH]
202 | """
203 | audio_array = generate_audio(text_prompt, language='mix')
204 | ```
205 |
206 | [vallex_codeswitch.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/d8667abf-bd08-499f-a383-a861d852f98a)
207 |
208 |
209 |
210 |
211 | 📼 Voice Presets
212 |
213 | VALL-E X provides tens of speaker voices which you can directly used for inference! Browse all voices in the [code](/presets)
214 |
215 | > VALL-E X tries to match the tone, pitch, emotion and prosody of a given preset. The model also attempts to preserve music, ambient noise, etc.
216 |
217 | ```python
218 | text_prompt = """
219 | I am an innocent boy with a smoky voice. It is a great honor for me to speak at the United Nations today.
220 | """
221 | audio_array = generate_audio(text_prompt, prompt="dingzhen")
222 | ```
223 |
224 | [smoky.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/d3f55732-b1cd-420f-87d6-eab60db14dc5)
225 |
226 |
227 |
228 |
229 | 🎙Voice Cloning
230 |
231 | VALL-E X supports voice cloning! You can make a voice prompt with any person, character or even your own voice, and use it like other voice presets.
232 | To make a voice prompt, you need to provide a speech of 3~10 seconds long, as well as the transcript of the speech.
233 | You can also leave the transcript blank to let the [Whisper](https://github.com/openai/whisper) model to generate the transcript.
234 | > VALL-E X tries to match the tone, pitch, emotion and prosody of a given prompt. The model also attempts to preserve music, ambient noise, etc.
235 |
236 | ```python
237 | from utils.prompt_making import make_prompt
238 |
239 | ### Use given transcript
240 | make_prompt(name="paimon", audio_prompt_path="paimon_prompt.wav",
241 | transcript="Just, what was that? Paimon thought we were gonna get eaten.")
242 |
243 | ### Alternatively, use whisper
244 | make_prompt(name="paimon", audio_prompt_path="paimon_prompt.wav")
245 | ```
246 | Now let's try out the prompt we've just made!
247 | ```python
248 | from utils.generation import SAMPLE_RATE, generate_audio, preload_models
249 | from scipy.io.wavfile import write as write_wav
250 |
251 | # download and load all models
252 | preload_models()
253 |
254 | text_prompt = """
255 | Hey, Traveler, Listen to this, This machine has taken my voice, and now it can talk just like me!
256 | """
257 | audio_array = generate_audio(text_prompt, prompt="paimon")
258 |
259 | write_wav("paimon_cloned.wav", SAMPLE_RATE, audio_array)
260 |
261 | ```
262 |
263 | [paimon_prompt.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/e7922859-9d12-4e2a-8651-e156e4280311)
264 |
265 |
266 | [paimon_cloned.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/60d3b7e9-5ead-4024-b499-a897ce5f3d5e)
267 |
268 |
269 |
270 |
271 |
272 |
273 | 🎢User Interface
274 |
275 | Not comfortable with codes? No problem! We've also created a user-friendly graphical interface for VALL-E X. It allows you to interact with the model effortlessly, making voice cloning and multilingual speech synthesis a breeze.
276 |
277 | You can launch the UI by the following command:
278 | ```commandline
279 | python -X utf8 launch-ui.py
280 | ```
281 |
282 |
283 | ## 🛠️ Hardware and Inference Speed
284 |
285 | VALL-E X works well on both CPU and GPU (`pytorch 2.0+`, CUDA 11.7 and CUDA 12.0).
286 |
287 | A GPU VRAM of 6GB is enough for running VALL-E X without offloading.
288 |
289 | ## ⚙️ Details
290 |
291 | VALL-E X is similar to [Bark](https://github.com/suno-ai/bark), [VALL-E](https://arxiv.org/abs/2301.02111) and [AudioLM](https://arxiv.org/abs/2209.03143), which generates audio in GPT-style by predicting audio tokens quantized by [EnCodec](https://github.com/facebookresearch/encodec).
292 |
293 | Comparing to [Bark](https://github.com/suno-ai/bark):
294 | - ✔ **Light-weighted**: 3️⃣ ✖ smaller,
295 | - ✔ **Efficient**: 4️⃣ ✖ faster,
296 | - ✔ **Better quality on Chinese & Japanese**
297 | - ✔ **Cross-lingual speech without foreign accent**
298 | - ✔ **Easy voice-cloning**
299 | - ❌ **Less languages**
300 | - ❌ **No special tokens for music / sound effects**
301 |
302 | ### Supported Languages
303 |
304 | | Language | Status |
305 | | --- | :---: |
306 | | English (en) | ✅ |
307 | | Japanese (ja) | ✅ |
308 | | Chinese, simplified (zh) | ✅ |
309 |
310 | ## ❓ FAQ
311 |
312 | #### Where is code for training?
313 | * [lifeiteng's vall-e](https://github.com/lifeiteng/vall-e) has almost everything. There is no plan to release our training code because there is no difference between lifeiteng's implementation.
314 |
315 | #### Where can I download the model checkpoint?
316 | * We use `wget` to download the model to directory `./checkpoints/` when you run the program for the first time.
317 | * If the download fails on the first run, please manually download from [this link](https://huggingface.co/Plachta/VALL-E-X/resolve/main/vallex-checkpoint.pt), and put the file under directory `./checkpoints/`.
318 |
319 | #### How much VRAM do I need?
320 | * 6GB GPU VRAM - Almost all NVIDIA GPUs satisfy the requirement.
321 |
322 | #### Why the model fails to generate long text?
323 | * Transformer's computation complexity increases quadratically while the sequence length increases. Hence, all training
324 | are kept under 22 seconds. Please make sure the total length of audio prompt and generated audio is less than 22 seconds
325 | to ensure acceptable performance.
326 |
327 |
328 | #### MORE TO BE ADDED...
329 |
330 | ## 🧠 TODO
331 | - [x] Add Chinese README
332 | - [x] Long text generation
333 | - [x] Replace Encodec decoder with Vocos decoder
334 | - [ ] Fine-tuning for better voice adaptation
335 | - [ ] `.bat` scripts for non-python users
336 | - [ ] To be added...
337 |
338 | ## 🙏 Appreciation
339 | - [VALL-E X paper](https://arxiv.org/pdf/2303.03926) for the brilliant idea
340 | - [lifeiteng's vall-e](https://github.com/lifeiteng/vall-e) for related training code
341 | - [bark](https://github.com/suno-ai/bark) for the amazing pioneering work in neuro-codec TTS model
342 |
343 | ## ⭐️ Show Your Support
344 |
345 | If you find VALL-E X interesting and useful, give us a star on GitHub! ⭐️ It encourages us to keep improving the model and adding exciting features.
346 |
347 | ## 📜 License
348 |
349 | VALL-E X is licensed under the [MIT License](./LICENSE).
350 |
351 | ---
352 |
353 | Have questions or need assistance? Feel free to [open an issue](https://github.com/Plachtaa/VALL-E-X/issues/new) or join our [Discord](https://discord.gg/qCBRmAnTxg)
354 |
355 | Happy voice cloning! 🎤
356 |
--------------------------------------------------------------------------------
/customs/ph.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Plachtaa/VALL-E-X/3faaf8ccadb154d63b38070caf518ce9309ea0f4/customs/ph.txt
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
1 | # from .datamodule import *
2 | # from .tokenizer import *
3 | from .collation import *
4 |
--------------------------------------------------------------------------------
/data/collation.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from typing import List, Tuple
3 |
4 | import numpy as np
5 | import torch
6 |
7 | from utils import SymbolTable
8 |
9 |
10 | class TextTokenCollater:
11 | """Collate list of text tokens
12 |
13 | Map sentences to integers. Sentences are padded to equal length.
14 | Beginning and end-of-sequence symbols can be added.
15 |
16 | Example:
17 | >>> token_collater = TextTokenCollater(text_tokens)
18 | >>> tokens_batch, tokens_lens = token_collater(text)
19 |
20 | Returns:
21 | tokens_batch: IntTensor of shape (B, L)
22 | B: batch dimension, number of input sentences
23 | L: length of the longest sentence
24 | tokens_lens: IntTensor of shape (B,)
25 | Length of each sentence after adding and
26 | but before padding.
27 | """
28 |
29 | def __init__(
30 | self,
31 | text_tokens: List[str],
32 | add_eos: bool = True,
33 | add_bos: bool = True,
34 | pad_symbol: str = "",
35 | bos_symbol: str = "",
36 | eos_symbol: str = "",
37 | ):
38 | self.pad_symbol = pad_symbol
39 |
40 | self.add_eos = add_eos
41 | self.add_bos = add_bos
42 |
43 | self.bos_symbol = bos_symbol
44 | self.eos_symbol = eos_symbol
45 |
46 | unique_tokens = (
47 | [pad_symbol]
48 | + ([bos_symbol] if add_bos else [])
49 | + ([eos_symbol] if add_eos else [])
50 | + sorted(text_tokens)
51 | )
52 |
53 | self.token2idx = {token: idx for idx, token in enumerate(unique_tokens)}
54 | self.idx2token = [token for token in unique_tokens]
55 |
56 | def index(
57 | self, tokens_list: List[str]
58 | ) -> Tuple[torch.Tensor, torch.Tensor]:
59 | seqs, seq_lens = [], []
60 | for tokens in tokens_list:
61 | assert (
62 | all([True if s in self.token2idx else False for s in tokens])
63 | is True
64 | )
65 | seq = (
66 | ([self.bos_symbol] if self.add_bos else [])
67 | + list(tokens)
68 | + ([self.eos_symbol] if self.add_eos else [])
69 | )
70 | seqs.append(seq)
71 | seq_lens.append(len(seq))
72 |
73 | max_len = max(seq_lens)
74 | for k, (seq, seq_len) in enumerate(zip(seqs, seq_lens)):
75 | seq.extend([self.pad_symbol] * (max_len - seq_len))
76 |
77 | tokens = torch.from_numpy(
78 | np.array(
79 | [[self.token2idx[token] for token in seq] for seq in seqs],
80 | dtype=np.int64,
81 | )
82 | )
83 | tokens_lens = torch.IntTensor(seq_lens)
84 |
85 | return tokens, tokens_lens
86 |
87 | def __call__(self, texts: List[str]) -> Tuple[torch.Tensor, torch.Tensor]:
88 | tokens_seqs = [[p for p in text] for text in texts]
89 | max_len = len(max(tokens_seqs, key=len))
90 |
91 | seqs = [
92 | ([self.bos_symbol] if self.add_bos else [])
93 | + list(seq)
94 | + ([self.eos_symbol] if self.add_eos else [])
95 | + [self.pad_symbol] * (max_len - len(seq))
96 | for seq in tokens_seqs
97 | ]
98 |
99 | tokens_batch = torch.from_numpy(
100 | np.array(
101 | [seq for seq in seqs],
102 | dtype=np.int64,
103 | )
104 | )
105 |
106 | tokens_lens = torch.IntTensor(
107 | [
108 | len(seq) + int(self.add_eos) + int(self.add_bos)
109 | for seq in tokens_seqs
110 | ]
111 | )
112 |
113 | return tokens_batch, tokens_lens
114 |
115 |
116 | def get_text_token_collater() -> TextTokenCollater:
117 | collater = TextTokenCollater(
118 | ['0'], add_bos=False, add_eos=False
119 | )
120 | return collater
121 |
--------------------------------------------------------------------------------
/data/datamodule.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 (authors: Feiteng Li)
2 | #
3 | # See ../../../../LICENSE for clarification regarding multiple authors
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 |
18 | import argparse
19 | import inspect
20 | import logging
21 | from functools import lru_cache
22 | from pathlib import Path
23 | from typing import Any, Dict, Optional
24 |
25 | import torch
26 | # from icefall.utils import str2bool
27 | # from lhotse import CutSet, load_manifest_lazy
28 | # from lhotse.dataset import (
29 | # CutConcatenate,
30 | # DynamicBucketingSampler,
31 | # PrecomputedFeatures,
32 | # SingleCutSampler,
33 | # SpecAugment,
34 | # )
35 | # from lhotse.dataset.input_strategies import OnTheFlyFeatures
36 | # from lhotse.utils import fix_random_seed
37 | from torch.utils.data import DataLoader
38 |
39 | from data.collation import get_text_token_collater
40 | # from data.dataset import SpeechSynthesisDataset
41 | from data.fbank import get_fbank_extractor
42 | from data.input_strategies import PromptedPrecomputedFeatures
43 |
44 | # PrecomputedFeatures = PrecomputedFeatures
45 |
46 |
47 | class _SeedWorkers:
48 | def __init__(self, seed: int):
49 | self.seed = seed
50 |
51 | def __call__(self, worker_id: int):
52 | fix_random_seed(self.seed + worker_id)
53 |
54 |
55 | def _get_input_strategy(input_strategy, dataset, cuts):
56 | if input_strategy == "PromptedPrecomputedFeatures":
57 | return PromptedPrecomputedFeatures(dataset, cuts)
58 |
59 | return eval(input_strategy)()
60 |
61 |
62 | class TtsDataModule:
63 | """
64 | DataModule for VALL-E TTS experiments.
65 | It assumes there is always one train and valid dataloader.
66 |
67 | It contains all the common data pipeline modules used in TTS
68 | experiments, e.g.:
69 | - dynamic batch size,
70 | - bucketing samplers,
71 | - cut concatenation[not used & tested yet],
72 | - augmentation[not used & tested yet],
73 | - on-the-fly feature extraction[not used & tested yet]
74 |
75 | This class should be derived for specific corpora used in TTS tasks.
76 | """
77 |
78 | def __init__(self, args: argparse.Namespace):
79 | self.args = args
80 |
81 | @classmethod
82 | def add_arguments(cls, parser: argparse.ArgumentParser):
83 | group = parser.add_argument_group(
84 | title="TTS data related options",
85 | description="These options are used for the preparation of "
86 | "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
87 | "effective batch sizes, sampling strategies, applied data "
88 | "augmentations, etc.",
89 | )
90 | group.add_argument(
91 | "--manifest-dir",
92 | type=Path,
93 | default=Path("data/tokenized"),
94 | help="Path to directory with train/valid/test cuts.",
95 | )
96 | group.add_argument(
97 | "--max-duration",
98 | type=int,
99 | default=40.0,
100 | help="Maximum pooled recordings duration (seconds) in a "
101 | "single batch. You can reduce it if it causes CUDA OOM.",
102 | )
103 | group.add_argument(
104 | "--bucketing-sampler",
105 | type=str2bool,
106 | default=True,
107 | help="When enabled, the batches will come from buckets of "
108 | "similar duration (saves padding frames).",
109 | )
110 | group.add_argument(
111 | "--num-buckets",
112 | type=int,
113 | default=10,
114 | help="The number of buckets for the DynamicBucketingSampler"
115 | "(you might want to increase it for larger datasets).",
116 | )
117 | group.add_argument(
118 | "--concatenate-cuts",
119 | type=str2bool,
120 | default=False,
121 | help="When enabled, utterances (cuts) will be concatenated "
122 | "to minimize the amount of padding.",
123 | )
124 | group.add_argument(
125 | "--duration-factor",
126 | type=float,
127 | default=1.0,
128 | help="Determines the maximum duration of a concatenated cut "
129 | "relative to the duration of the longest cut in a batch.",
130 | )
131 | group.add_argument(
132 | "--gap",
133 | type=float,
134 | default=0.1,
135 | help="The amount of padding (in seconds) inserted between "
136 | "concatenated cuts. This padding is filled with noise when "
137 | "noise augmentation is used.",
138 | )
139 | group.add_argument(
140 | "--on-the-fly-feats",
141 | type=str2bool,
142 | default=False,
143 | help="When enabled, use on-the-fly cut mixing and feature "
144 | "extraction. Will drop existing precomputed feature manifests "
145 | "if available.",
146 | )
147 | group.add_argument(
148 | "--shuffle",
149 | type=str2bool,
150 | default=True,
151 | help="When enabled (=default), the examples will be "
152 | "shuffled for each epoch.",
153 | )
154 | group.add_argument(
155 | "--drop-last",
156 | type=str2bool,
157 | default=False,
158 | help="Whether to drop last batch. Used by sampler.",
159 | )
160 | group.add_argument(
161 | "--return-cuts",
162 | type=str2bool,
163 | default=True,
164 | help="When enabled, each batch will have the "
165 | "field: batch['supervisions']['cut'] with the cuts that "
166 | "were used to construct it.",
167 | )
168 |
169 | group.add_argument(
170 | "--num-workers",
171 | type=int,
172 | default=8,
173 | help="The number of training dataloader workers that "
174 | "collect the batches.",
175 | )
176 |
177 | group.add_argument(
178 | "--enable-spec-aug",
179 | type=str2bool,
180 | default=False,
181 | help="When enabled, use SpecAugment for training dataset.",
182 | )
183 |
184 | group.add_argument(
185 | "--spec-aug-time-warp-factor",
186 | type=int,
187 | default=80,
188 | help="Used only when --enable-spec-aug is True. "
189 | "It specifies the factor for time warping in SpecAugment. "
190 | "Larger values mean more warping. "
191 | "A value less than 1 means to disable time warp.",
192 | )
193 |
194 | group.add_argument(
195 | "--input-strategy",
196 | type=str,
197 | default="PrecomputedFeatures",
198 | help="AudioSamples or PrecomputedFeatures or PromptedPrecomputedFeatures",
199 | )
200 |
201 | group.add_argument(
202 | "--dataset",
203 | type=str,
204 | default="ljspeech",
205 | help="--input-strategy PromptedPrecomputedFeatures needs dataset name to prepare prompts.",
206 | )
207 |
208 | parser.add_argument(
209 | "--text-tokens",
210 | type=str,
211 | default="data/tokenized/unique_text_tokens.k2symbols",
212 | help="Path to the unique text tokens file",
213 | )
214 |
215 | parser.add_argument(
216 | "--sampling-rate",
217 | type=int,
218 | default=24000,
219 | help="""Audio sampling rate.""",
220 | )
221 |
222 | def train_dataloaders(
223 | self,
224 | cuts_train: CutSet,
225 | sampler_state_dict: Optional[Dict[str, Any]] = None,
226 | ) -> DataLoader:
227 | """
228 | Args:
229 | cuts_train:
230 | CutSet for training.
231 | sampler_state_dict:
232 | The state dict for the training sampler.
233 | """
234 | transforms = []
235 |
236 | if self.args.concatenate_cuts:
237 | logging.info(
238 | f"Using cut concatenation with duration factor "
239 | f"{self.args.duration_factor} and gap {self.args.gap}."
240 | )
241 | # Cut concatenation should be the first transform in the list,
242 | # so that if we e.g. mix noise in, it will fill the gaps between
243 | # different utterances.
244 | transforms = [
245 | CutConcatenate(
246 | duration_factor=self.args.duration_factor, gap=self.args.gap
247 | )
248 | ] + transforms
249 |
250 | input_transforms = []
251 | if self.args.enable_spec_aug:
252 | logging.info("Enable SpecAugment")
253 | logging.info(
254 | f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
255 | )
256 | # Set the value of num_frame_masks according to Lhotse's version.
257 | # In different Lhotse's versions, the default of num_frame_masks is
258 | # different.
259 | num_frame_masks = 10
260 | num_frame_masks_parameter = inspect.signature(
261 | SpecAugment.__init__
262 | ).parameters["num_frame_masks"]
263 | if num_frame_masks_parameter.default == 1:
264 | num_frame_masks = 2
265 | logging.info(f"Num frame mask: {num_frame_masks}")
266 | input_transforms.append(
267 | SpecAugment(
268 | time_warp_factor=self.args.spec_aug_time_warp_factor,
269 | num_frame_masks=num_frame_masks,
270 | features_mask_size=27,
271 | num_feature_masks=2,
272 | frames_mask_size=100,
273 | )
274 | )
275 | else:
276 | logging.info("Disable SpecAugment")
277 |
278 | logging.info("About to create train dataset")
279 | if self.args.on_the_fly_feats:
280 | # NOTE: the PerturbSpeed transform should be added only if we
281 | # remove it from data prep stage.
282 | # Add on-the-fly speed perturbation; since originally it would
283 | # have increased epoch size by 3, we will apply prob 2/3 and use
284 | # 3x more epochs.
285 | # Speed perturbation probably should come first before
286 | # concatenation, but in principle the transforms order doesn't have
287 | # to be strict (e.g. could be randomized)
288 | # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa
289 | # Drop feats to be on the safe side.
290 | train = SpeechSynthesisDataset(
291 | get_text_token_collater(self.args.text_tokens),
292 | cut_transforms=transforms,
293 | feature_input_strategy=OnTheFlyFeatures(get_fbank_extractor()),
294 | feature_transforms=input_transforms,
295 | )
296 | else:
297 | train = SpeechSynthesisDataset(
298 | get_text_token_collater(self.args.text_tokens),
299 | feature_input_strategy=_get_input_strategy(
300 | self.args.input_strategy, self.args.dataset, cuts_train
301 | ),
302 | cut_transforms=transforms,
303 | feature_transforms=input_transforms,
304 | )
305 |
306 | if self.args.bucketing_sampler:
307 | logging.info("Using DynamicBucketingSampler")
308 | train_sampler = DynamicBucketingSampler(
309 | cuts_train,
310 | max_duration=self.args.max_duration,
311 | shuffle=self.args.shuffle,
312 | num_buckets=self.args.num_buckets,
313 | drop_last=self.args.drop_last,
314 | )
315 | else:
316 | logging.info(
317 | "Using SingleCutSampler and sort by duraton(ascending=True)."
318 | )
319 | cuts_train = cuts_train.to_eager().sort_by_duration(ascending=True)
320 | train_sampler = SingleCutSampler(
321 | cuts_train,
322 | max_duration=self.args.max_duration,
323 | shuffle=self.args.shuffle,
324 | )
325 | logging.info("About to create train dataloader")
326 |
327 | if sampler_state_dict is not None:
328 | logging.info("Loading sampler state dict")
329 | train_sampler.load_state_dict(sampler_state_dict)
330 |
331 | # 'seed' is derived from the current random state, which will have
332 | # previously been set in the main process.
333 | seed = torch.randint(0, 100000, ()).item()
334 | worker_init_fn = _SeedWorkers(seed)
335 |
336 | train_dl = DataLoader(
337 | train,
338 | sampler=train_sampler,
339 | batch_size=None,
340 | num_workers=self.args.num_workers,
341 | persistent_workers=False,
342 | worker_init_fn=worker_init_fn,
343 | )
344 |
345 | return train_dl
346 |
347 | def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
348 | logging.info("About to create dev dataset")
349 | if self.args.on_the_fly_feats:
350 | validate = SpeechSynthesisDataset(
351 | get_text_token_collater(self.args.text_tokens),
352 | feature_input_strategy=OnTheFlyFeatures(get_fbank_extractor()),
353 | cut_transforms=[],
354 | )
355 | else:
356 | validate = SpeechSynthesisDataset(
357 | get_text_token_collater(self.args.text_tokens),
358 | feature_input_strategy=_get_input_strategy(
359 | self.args.input_strategy, self.args.dataset, cuts_valid
360 | ),
361 | cut_transforms=[],
362 | )
363 | valid_sampler = DynamicBucketingSampler(
364 | cuts_valid,
365 | max_duration=self.args.max_duration,
366 | shuffle=False,
367 | )
368 | logging.info("About to create dev dataloader")
369 | valid_dl = DataLoader(
370 | validate,
371 | sampler=valid_sampler,
372 | batch_size=None,
373 | num_workers=4,
374 | persistent_workers=False,
375 | )
376 |
377 | return valid_dl
378 |
379 | def test_dataloaders(self, cuts: CutSet) -> DataLoader:
380 | logging.debug("About to create test dataset")
381 | test = SpeechSynthesisDataset(
382 | get_text_token_collater(self.args.text_tokens),
383 | feature_input_strategy=OnTheFlyFeatures(get_fbank_extractor())
384 | if self.args.on_the_fly_feats
385 | else _get_input_strategy(
386 | self.args.input_strategy, self.args.dataset, cuts
387 | ),
388 | cut_transforms=[],
389 | )
390 | sampler = DynamicBucketingSampler(
391 | cuts,
392 | max_duration=self.args.max_duration,
393 | shuffle=False,
394 | )
395 | logging.debug("About to create test dataloader")
396 | test_dl = DataLoader(
397 | test,
398 | batch_size=None,
399 | sampler=sampler,
400 | num_workers=self.args.num_workers,
401 | )
402 | return test_dl
403 |
404 | @lru_cache()
405 | def train_cuts(self) -> CutSet:
406 | logging.info("About to get train cuts")
407 | return load_manifest_lazy(
408 | self.args.manifest_dir / "cuts_train.jsonl.gz"
409 | )
410 |
411 | @lru_cache()
412 | def dev_cuts(self) -> CutSet:
413 | logging.info("About to get dev cuts")
414 | return load_manifest_lazy(self.args.manifest_dir / "cuts_dev.jsonl.gz")
415 |
416 | @lru_cache()
417 | def test_cuts(self) -> CutSet:
418 | logging.info("About to get test cuts")
419 | return load_manifest_lazy(self.args.manifest_dir / "cuts_test.jsonl.gz")
420 |
--------------------------------------------------------------------------------
/data/dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 (authors: Feiteng Li)
2 | #
3 | # See ../../../../LICENSE for clarification regarding multiple authors
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | """
18 | modified from lhoste.dataset.speech_synthesis.py
19 | """
20 |
21 | import torch
22 | import math
23 | import h5py
24 | from tokenizers import Tokenizer
25 | from typing import Union, List
26 | import numpy as np
27 | from tqdm import tqdm
28 |
29 | _pad = '_'
30 | _punctuation = ',.!?-~…'
31 | _letters = 'NQabdefghijklmnopstuvwxyzɑæʃʑçɯɪɔɛɹðəɫɥɸʊɾʒθβŋɦ⁼ʰ`^#*=ˈˌ→↓↑ '
32 | symbols = [_pad] + list(_punctuation) + list(_letters)
33 |
34 | language_dict = {
35 | 'en': 0,
36 | 'zh': 1,
37 | 'ja': 2,
38 | }
39 | def seq2phone(tokens: Union[List, np.ndarray]):
40 | """
41 | Convert tokenized phoneme ID sequence back to phoneme string
42 | :param tokens: phoneme tokens
43 | :return: recovered phoneme sequence
44 | """
45 | phones = "".join([symbols[i] for i in tokens])
46 | return phones
47 |
48 | class DynamicBatchSampler(torch.utils.data.Sampler):
49 | def __init__(self, sampler, num_tokens_fn, num_buckets=100, min_size=0, max_size=1000,
50 | max_tokens=None, max_sentences=None, drop_last=False):
51 | """
52 |
53 | :param sampler:
54 | :param num_tokens_fn: 根据idx返回样本的长度的函数
55 | :param num_buckets: 利用桶原理将相似长度的样本放在一个batchsize中,桶的数量
56 | :param min_size: 最小长度的样本, 小于这个值的样本会被过滤掉。 依据这个值来创建样桶
57 | :param max_size: 最大长度的样本
58 | :param max_sentences: batch_size, 但是这里可以通过max_sentences 和 max_tokens 共同控制最终的大小
59 | """
60 | super(DynamicBatchSampler, self).__init__(sampler)
61 | self.sampler = sampler
62 | self.num_tokens_fn = num_tokens_fn
63 | self.num_buckets = num_buckets
64 |
65 | self.min_size = min_size
66 | self.max_size = max_size
67 |
68 | assert max_size <= max_tokens, "max_size should be smaller than max tokens"
69 | assert max_tokens is not None or max_sentences is not None, \
70 | "max_tokens and max_sentences should not be null at the same time, please specify one parameter at least"
71 | self.max_tokens = max_tokens if max_tokens is not None else float('Inf')
72 | self.max_sentences = max_sentences if max_sentences is not None else float('Inf')
73 | self.drop_last = drop_last
74 |
75 | def set_epoch(self, epoch):
76 | self.sampler.set_epoch(epoch)
77 | def is_batch_full(self, num_tokens, batch):
78 | if len(batch) == 0:
79 | return False
80 | if len(batch) == self.max_sentences:
81 | return True
82 | if num_tokens > self.max_tokens:
83 | return True
84 | return False
85 |
86 | def __iter__(self):
87 | buckets = [[] for _ in range(self.num_buckets)]
88 | sample_len = [0] * self.num_buckets
89 |
90 | for idx in self.sampler:
91 | idx_length = self.num_tokens_fn(idx)
92 | if not (self.min_size <= idx_length <= self.max_size):
93 | print("sentence at index {} of size {} exceeds max_tokens, the sentence is ignored".format(idx, idx_length))
94 | continue
95 |
96 | index_buckets = math.floor((idx_length - self.min_size) / (self.max_size - self.min_size + 1)
97 | * self.num_buckets)
98 | sample_len[index_buckets] = max(sample_len[index_buckets], idx_length)
99 |
100 | num_tokens = (len(buckets[index_buckets]) + 1) * sample_len[index_buckets]
101 | if self.is_batch_full(num_tokens, buckets[index_buckets]):
102 | # yield this batch
103 | yield buckets[index_buckets]
104 | buckets[index_buckets] = []
105 | sample_len[index_buckets] = 0
106 |
107 | buckets[index_buckets].append(idx)
108 |
109 | # process left-over
110 | leftover_batch = []
111 | leftover_sample_len = 0
112 | leftover = [idx for bucket in buckets for idx in bucket]
113 | for idx in leftover:
114 | idx_length = self.num_tokens_fn(idx)
115 | leftover_sample_len = max(leftover_sample_len, idx_length)
116 | num_tokens = (len(leftover_batch) + 1) * leftover_sample_len
117 | if self.is_batch_full(num_tokens, leftover_batch):
118 | yield leftover_batch
119 | leftover_batch = []
120 | leftover_sample_len = 0
121 | leftover_batch.append(idx)
122 |
123 | if len(leftover_batch) > 0 and not self.drop_last:
124 | yield leftover_batch
125 |
126 | def __len__(self):
127 | # we do not know the exactly batch size, so do not call len(dataloader)
128 | pass
129 |
130 |
131 | class AudioDataset(torch.utils.data.Dataset):
132 | def __init__(self, h5_path, ann_path, tokenizer_path):
133 | self.h5_path = h5_path
134 | with open(ann_path, 'r', encoding='utf-8') as f:
135 | lines = f.readlines()
136 | ls = [l.split("|") for l in lines]
137 | ls_T = list(zip(*ls))
138 | del ls_T[-1]
139 | self.h5_paths, self.durations, self.langs, self.texts = \
140 | list(ls_T[0]), list(ls_T[1]), list(ls_T[2]), list(ls_T[3])
141 | self.durations = [float(dur) for dur in self.durations]
142 | self.tokenizer = Tokenizer.from_file(tokenizer_path)
143 |
144 | self._archive = None
145 |
146 | def __len__(self):
147 | return len(self.h5_paths)
148 |
149 | def get_dur(self, idx):
150 | return self.durations[idx]
151 |
152 | @property
153 | def archive(self):
154 | if self._archive is None: # lazy loading here!
155 | self._archive = h5py.File(self.h5_path, "r")
156 | return self._archive
157 | def __getitem__(self, idx):
158 | archive = self.archive
159 | h5_path = self.h5_paths[idx]
160 | sub = archive[h5_path]
161 | audio_tokens = sub['audio'][()]
162 | phone_tokens = sub['text'][()]
163 | dur = self.durations[idx]
164 | lang = self.langs[idx]
165 | text = self.texts[idx]
166 | # tokenization should be done within dataloader
167 | phones = seq2phone(phone_tokens)
168 | phones = phones.replace(" ", "_")
169 | if not len(phones):
170 | cptpho_tokens = self.tokenizer.encode(text).ids
171 | else:
172 | cptpho_tokens = self.tokenizer.encode(phones).ids
173 | assert len(cptpho_tokens)
174 | return {
175 | 'utt_id': h5_path,
176 | 'text': text,
177 | 'audio': None,
178 | 'audio_lens': None,
179 | 'audio_features': audio_tokens,
180 | 'audio_features_lens': len(audio_tokens.T),
181 | 'text_tokens': np.array(cptpho_tokens),
182 | 'text_tokens_lens': len(cptpho_tokens),
183 | 'language': language_dict[lang],
184 | }
185 |
186 | def collate(batch):
187 | utt_id_s = [b['utt_id'] for b in batch]
188 | text_s = [b['text'] for b in batch]
189 |
190 | audio_s = [b['audio'] for b in batch]
191 | audio_lens_s = [b['audio_lens'] for b in batch]
192 |
193 | audio_features_lens_s = [b['audio_features_lens'] for b in batch]
194 | # create an empty tensor with maximum audio feature length
195 | audio_features_s = torch.zeros([len(batch), max(audio_features_lens_s), 8], dtype=torch.int64) - 1 # audio pad with -1
196 |
197 | text_tokens_lens_s = [b['text_tokens_lens'] for b in batch]
198 | # create an empty tensor with maximum text tokens length
199 | text_tokens_s = torch.zeros([len(batch), max(text_tokens_lens_s)], dtype=torch.int64) + 3 # [PAD] token id 3
200 |
201 | language_s = [b['language'] for b in batch]
202 |
203 | for i, b in enumerate(batch):
204 | audio_features = b['audio_features']
205 | audio_features_lens = b['audio_features_lens']
206 | audio_features_s[i, :audio_features_lens, :] = torch.LongTensor(audio_features.T)
207 |
208 | text_tokens = b['text_tokens']
209 | text_tokens_lens = b['text_tokens_lens']
210 | text_tokens_s[i, :text_tokens_lens] = torch.LongTensor(text_tokens)
211 |
212 | batch = {
213 | 'utt_id': utt_id_s,
214 | 'text': text_s,
215 | 'audio': audio_s,
216 | 'audio_lens': audio_lens_s,
217 | 'audio_features': audio_features_s,
218 | 'audio_features_lens': torch.LongTensor(np.array(audio_features_lens_s)),
219 | 'text_tokens': text_tokens_s,
220 | 'text_tokens_lens': torch.LongTensor(np.array(text_tokens_lens_s)),
221 | 'languages': torch.LongTensor(np.array(language_s)),
222 | }
223 | return batch
224 |
225 | def create_dataloader(data_dir="/root/valle/egs/mix", n_gpus=1, rank=0, num_workers=0, num_buckets=10, max_duration=120):
226 | train_dataset = AudioDataset(h5_path=f"{data_dir}/audio_sum.hdf5",
227 | ann_path=f"{data_dir}/audio_ann_sum.txt",
228 | tokenizer_path=f"{data_dir}/bpe_69.json")
229 | ran_sampler = torch.utils.data.distributed.DistributedSampler(
230 | train_dataset,
231 | num_replicas=n_gpus,
232 | rank=rank,
233 | shuffle=True,
234 | )
235 | dynamic_sampler = DynamicBatchSampler(ran_sampler, train_dataset.get_dur, num_buckets=num_buckets, max_size=20,
236 | max_tokens=max_duration)
237 |
238 |
239 | train_loader = torch.utils.data.DataLoader(train_dataset, num_workers=num_workers, collate_fn=collate,
240 | batch_sampler=dynamic_sampler)
241 |
242 | return train_loader
243 |
--------------------------------------------------------------------------------
/data/fbank.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 (authors: Feiteng Li)
2 | #
3 | # See ../../../../LICENSE for clarification regarding multiple authors
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 |
18 | from dataclasses import asdict, dataclass
19 | from typing import Any, Dict, Optional, Union
20 |
21 | import numpy as np
22 | import torch
23 | # from lhotse.features.base import FeatureExtractor
24 | # from lhotse.utils import EPSILON, Seconds, compute_num_frames
25 | from librosa.filters import mel as librosa_mel_fn
26 |
27 |
28 | @dataclass
29 | class BigVGANFbankConfig:
30 | # Spectogram-related part
31 | # Note that frame_length and frame_shift will be converted to milliseconds before torchaudio/Kaldi sees them
32 | frame_length: Seconds = 1024 / 24000.0
33 | frame_shift: Seconds = 256 / 24000.0
34 | remove_dc_offset: bool = True
35 | round_to_power_of_two: bool = True
36 |
37 | # Fbank-related part
38 | low_freq: float = 0.0
39 | high_freq: float = 12000.0
40 | num_mel_bins: int = 100
41 | use_energy: bool = False
42 |
43 | def to_dict(self) -> Dict[str, Any]:
44 | return asdict(self)
45 |
46 | @staticmethod
47 | def from_dict(data: Dict[str, Any]) -> "BigVGANFbankConfig":
48 | return BigVGANFbankConfig(**data)
49 |
50 |
51 | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
52 | return torch.log(torch.clamp(x, min=clip_val) * C)
53 |
54 |
55 | def spectral_normalize_torch(magnitudes):
56 | output = dynamic_range_compression_torch(magnitudes)
57 | return output
58 |
59 |
60 | # https://github.com/NVIDIA/BigVGAN
61 | # bigvgan_24khz_100band https://drive.google.com/drive/folders/1EpxX6AsxjCbbk0mmAhE0td6eYiABr8Oz
62 | class BigVGANFbank(FeatureExtractor):
63 | name = "fbank"
64 | config_type = BigVGANFbankConfig
65 |
66 | def __init__(self, config: Optional[Any] = None):
67 | super(BigVGANFbank, self).__init__(config)
68 | sampling_rate = 24000
69 | self.mel_basis = torch.from_numpy(
70 | librosa_mel_fn(
71 | sampling_rate,
72 | 1024,
73 | self.config.num_mel_bins,
74 | self.config.low_freq,
75 | self.config.high_freq,
76 | ).astype(np.float32)
77 | )
78 | self.hann_window = torch.hann_window(1024)
79 |
80 | def _feature_fn(self, samples, **kwargs):
81 | win_length, n_fft = 1024, 1024
82 | hop_size = 256
83 | if True:
84 | sampling_rate = 24000
85 | duration = round(samples.shape[-1] / sampling_rate, ndigits=12)
86 | expected_num_frames = compute_num_frames(
87 | duration=duration,
88 | frame_shift=self.frame_shift,
89 | sampling_rate=sampling_rate,
90 | )
91 | pad_size = (
92 | (expected_num_frames - 1) * hop_size
93 | + win_length
94 | - samples.shape[-1]
95 | )
96 | assert pad_size >= 0
97 |
98 | y = torch.nn.functional.pad(
99 | samples,
100 | (0, pad_size),
101 | mode="constant",
102 | )
103 | else:
104 | y = torch.nn.functional.pad(
105 | samples,
106 | (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
107 | mode="reflect",
108 | )
109 |
110 | y = y.squeeze(1)
111 |
112 | # complex tensor as default, then use view_as_real for future pytorch compatibility
113 | spec = torch.stft(
114 | y,
115 | n_fft,
116 | hop_length=hop_size,
117 | win_length=win_length,
118 | window=self.hann_window,
119 | center=False,
120 | pad_mode="reflect",
121 | normalized=False,
122 | onesided=True,
123 | return_complex=True,
124 | )
125 | spec = torch.view_as_real(spec)
126 | spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
127 |
128 | spec = torch.matmul(self.mel_basis, spec)
129 | spec = spectral_normalize_torch(spec)
130 |
131 | return spec.transpose(2, 1).squeeze(0)
132 |
133 | def extract(
134 | self, samples: Union[np.ndarray, torch.Tensor], sampling_rate: int
135 | ) -> np.ndarray:
136 | assert sampling_rate == 24000
137 | params = asdict(self.config)
138 | params.update({"sample_frequency": sampling_rate, "snip_edges": False})
139 | params["frame_shift"] *= 1000.0
140 | params["frame_length"] *= 1000.0
141 | if not isinstance(samples, torch.Tensor):
142 | samples = torch.from_numpy(samples)
143 | # Torchaudio Kaldi feature extractors expect the channel dimension to be first.
144 | if len(samples.shape) == 1:
145 | samples = samples.unsqueeze(0)
146 | features = self._feature_fn(samples, **params).to(torch.float32)
147 | return features.numpy()
148 |
149 | @property
150 | def frame_shift(self) -> Seconds:
151 | return self.config.frame_shift
152 |
153 | def feature_dim(self, sampling_rate: int) -> int:
154 | return self.config.num_mel_bins
155 |
156 | @staticmethod
157 | def mix(
158 | features_a: np.ndarray,
159 | features_b: np.ndarray,
160 | energy_scaling_factor_b: float,
161 | ) -> np.ndarray:
162 | return np.log(
163 | np.maximum(
164 | # protection against log(0); max with EPSILON is adequate since these are energies (always >= 0)
165 | EPSILON,
166 | np.exp(features_a)
167 | + energy_scaling_factor_b * np.exp(features_b),
168 | )
169 | )
170 |
171 | @staticmethod
172 | def compute_energy(features: np.ndarray) -> float:
173 | return float(np.sum(np.exp(features)))
174 |
175 |
176 | def get_fbank_extractor() -> BigVGANFbank:
177 | return BigVGANFbank(BigVGANFbankConfig())
178 |
179 |
180 | if __name__ == "__main__":
181 | extractor = BigVGANFbank(BigVGANFbankConfig())
182 |
183 | samples = torch.from_numpy(np.random.random([1000]).astype(np.float32))
184 | samples = torch.clip(samples, -1.0, 1.0)
185 | fbank = extractor.extract(samples, 24000.0)
186 | print(f"fbank {fbank.shape}")
187 |
188 | from scipy.io.wavfile import read
189 |
190 | MAX_WAV_VALUE = 32768.0
191 |
192 | sampling_rate, samples = read(
193 | "egs/libritts/prompts/5639_40744_000000_000002.wav"
194 | )
195 | print(f"samples: [{samples.min()}, {samples.max()}]")
196 | fbank = extractor.extract(samples.astype(np.float32) / MAX_WAV_VALUE, 24000)
197 | print(f"fbank {fbank.shape}")
198 |
199 | import matplotlib.pyplot as plt
200 |
201 | _ = plt.figure(figsize=(18, 10))
202 | plt.imshow(
203 | X=fbank.transpose(1, 0),
204 | cmap=plt.get_cmap("jet"),
205 | aspect="auto",
206 | interpolation="nearest",
207 | )
208 | plt.gca().invert_yaxis()
209 | plt.savefig("egs/libritts/prompts/5639_40744_000000_000002.png")
210 | plt.close()
211 |
212 | print("fbank test PASS!")
213 |
--------------------------------------------------------------------------------
/data/input_strategies.py:
--------------------------------------------------------------------------------
1 | import random
2 | from collections import defaultdict
3 | from concurrent.futures import ThreadPoolExecutor
4 | from typing import Tuple, Type
5 |
6 | # from lhotse import CutSet
7 | # from lhotse.dataset.collation import collate_features
8 | # from lhotse.dataset.input_strategies import (
9 | # ExecutorType,
10 | # PrecomputedFeatures,
11 | # _get_executor,
12 | # )
13 | # from lhotse.utils import fastcopy
14 |
15 |
16 | class PromptedFeatures:
17 | def __init__(self, prompts, features):
18 | self.prompts = prompts
19 | self.features = features
20 |
21 | def to(self, device):
22 | return PromptedFeatures(
23 | self.prompts.to(device), self.features.to(device)
24 | )
25 |
26 | def sum(self):
27 | return self.features.sum()
28 |
29 | @property
30 | def ndim(self):
31 | return self.features.ndim
32 |
33 | @property
34 | def data(self):
35 | return (self.prompts, self.features)
36 |
37 |
38 | # class PromptedPrecomputedFeatures(PrecomputedFeatures):
39 | # """
40 | # :class:`InputStrategy` that reads pre-computed features, whose manifests
41 | # are attached to cuts, from disk.
42 | #
43 | # It automatically pads the feature matrices with pre or post feature.
44 | #
45 | # .. automethod:: __call__
46 | # """
47 | #
48 | # def __init__(
49 | # self,
50 | # dataset: str,
51 | # cuts: CutSet,
52 | # num_workers: int = 0,
53 | # executor_type: Type[ExecutorType] = ThreadPoolExecutor,
54 | # ) -> None:
55 | # super(PromptedPrecomputedFeatures, self).__init__(
56 | # num_workers, executor_type
57 | # )
58 | #
59 | # self.utt2neighbors = defaultdict(lambda: [])
60 | #
61 | # if dataset.lower() == "libritts":
62 | # # 909_131041_000013_000002
63 | # # 909_131041_000013_000003
64 | # speaker2utts = defaultdict(lambda: [])
65 | #
66 | # utt2cut = {}
67 | # for cut in cuts:
68 | # speaker = cut.supervisions[0].speaker
69 | # speaker2utts[speaker].append(cut.id)
70 | # utt2cut[cut.id] = cut
71 | #
72 | # for spk in speaker2utts:
73 | # uttids = sorted(speaker2utts[spk])
74 | # # Using the property of sorted keys to find previous utterance
75 | # # The keys has structure speaker_book_x_y e.g. 1089_134691_000004_000001
76 | # if len(uttids) == 1:
77 | # self.utt2neighbors[uttids[0]].append(utt2cut[uttids[0]])
78 | # continue
79 | #
80 | # utt2prevutt = dict(zip(uttids, [uttids[1]] + uttids[:-1]))
81 | # utt2postutt = dict(zip(uttids[:-1], uttids[1:]))
82 | #
83 | # for utt in utt2prevutt:
84 | # self.utt2neighbors[utt].append(utt2cut[utt2prevutt[utt]])
85 | #
86 | # for utt in utt2postutt:
87 | # self.utt2neighbors[utt].append(utt2cut[utt2postutt[utt]])
88 | # elif dataset.lower() == "ljspeech":
89 | # utt2cut = {}
90 | # uttids = []
91 | # for cut in cuts:
92 | # uttids.append(cut.id)
93 | # utt2cut[cut.id] = cut
94 | #
95 | # if len(uttids) == 1:
96 | # self.utt2neighbors[uttids[0]].append(utt2cut[uttids[0]])
97 | # else:
98 | # # Using the property of sorted keys to find previous utterance
99 | # # The keys has structure: LJ001-0010
100 | # utt2prevutt = dict(zip(uttids, [uttids[1]] + uttids[:-1]))
101 | # utt2postutt = dict(zip(uttids[:-1], uttids[1:]))
102 | #
103 | # for utt in utt2postutt:
104 | # postutt = utt2postutt[utt]
105 | # if utt[:5] == postutt[:5]:
106 | # self.utt2neighbors[utt].append(utt2cut[postutt])
107 | #
108 | # for utt in utt2prevutt:
109 | # prevutt = utt2prevutt[utt]
110 | # if utt[:5] == prevutt[:5] or not self.utt2neighbors[utt]:
111 | # self.utt2neighbors[utt].append(utt2cut[prevutt])
112 | # else:
113 | # raise ValueError
114 | #
115 | # def __call__(
116 | # self, cuts: CutSet
117 | # ) -> Tuple[PromptedFeatures, PromptedFeatures]:
118 | # """
119 | # Reads the pre-computed features from disk/other storage.
120 | # The returned shape is``(B, T, F) => (batch_size, num_frames, num_features)``.
121 | #
122 | # :return: a tensor with collated features, and a tensor of ``num_frames`` of each cut before padding.
123 | # """
124 | # features, features_lens = collate_features(
125 | # cuts,
126 | # executor=_get_executor(
127 | # self.num_workers, executor_type=self._executor_type
128 | # ),
129 | # )
130 | #
131 | # prompts_cuts = []
132 | # for k, cut in enumerate(cuts):
133 | # prompts_cut = random.choice(self.utt2neighbors[cut.id])
134 | # prompts_cuts.append(fastcopy(prompts_cut, id=f"{cut.id}-{str(k)}"))
135 | #
136 | # mini_duration = min([cut.duration for cut in prompts_cuts] + [3.0])
137 | # # prompts_cuts = CutSet.from_cuts(prompts_cuts).truncate(
138 | # # max_duration=mini_duration,
139 | # # offset_type="random",
140 | # # preserve_id=True,
141 | # # )
142 | # prompts_cuts = CutSet(
143 | # cuts={k: cut for k, cut in enumerate(prompts_cuts)}
144 | # ).truncate(
145 | # max_duration=mini_duration,
146 | # offset_type="random",
147 | # preserve_id=False,
148 | # )
149 | #
150 | # prompts, prompts_lens = collate_features(
151 | # prompts_cuts,
152 | # executor=_get_executor(
153 | # self.num_workers, executor_type=self._executor_type
154 | # ),
155 | # )
156 | #
157 | # return PromptedFeatures(prompts, features), PromptedFeatures(
158 | # prompts_lens, features_lens
159 | # )
160 |
--------------------------------------------------------------------------------
/data/tokenizer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright 2023 (authors: Feiteng Li)
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import re
17 | from dataclasses import asdict, dataclass
18 | from typing import Any, Dict, List, Optional, Pattern, Union
19 |
20 | import numpy as np
21 | import torch
22 | import torchaudio
23 | from encodec import EncodecModel
24 | from encodec.utils import convert_audio
25 |
26 | try:
27 | from pypinyin import Style, pinyin
28 | from pypinyin.style._utils import get_finals, get_initials
29 | except Exception:
30 | pass
31 |
32 |
33 | def remove_encodec_weight_norm(model):
34 | from encodec.modules import SConv1d
35 | from encodec.modules.seanet import SConvTranspose1d, SEANetResnetBlock
36 | from torch.nn.utils import remove_weight_norm
37 |
38 | encoder = model.encoder.model
39 | for key in encoder._modules:
40 | if isinstance(encoder._modules[key], SEANetResnetBlock):
41 | remove_weight_norm(encoder._modules[key].shortcut.conv.conv)
42 | block_modules = encoder._modules[key].block._modules
43 | for skey in block_modules:
44 | if isinstance(block_modules[skey], SConv1d):
45 | remove_weight_norm(block_modules[skey].conv.conv)
46 | elif isinstance(encoder._modules[key], SConv1d):
47 | remove_weight_norm(encoder._modules[key].conv.conv)
48 |
49 | decoder = model.decoder.model
50 | for key in decoder._modules:
51 | if isinstance(decoder._modules[key], SEANetResnetBlock):
52 | remove_weight_norm(decoder._modules[key].shortcut.conv.conv)
53 | block_modules = decoder._modules[key].block._modules
54 | for skey in block_modules:
55 | if isinstance(block_modules[skey], SConv1d):
56 | remove_weight_norm(block_modules[skey].conv.conv)
57 | elif isinstance(decoder._modules[key], SConvTranspose1d):
58 | remove_weight_norm(decoder._modules[key].convtr.convtr)
59 | elif isinstance(decoder._modules[key], SConv1d):
60 | remove_weight_norm(decoder._modules[key].conv.conv)
61 |
62 |
63 | class AudioTokenizer:
64 | """EnCodec audio."""
65 |
66 | def __init__(
67 | self,
68 | device: Any = None,
69 | ) -> None:
70 | # Instantiate a pretrained EnCodec model
71 | model = EncodecModel.encodec_model_24khz()
72 | model.set_target_bandwidth(6.0)
73 | remove_encodec_weight_norm(model)
74 |
75 | if not device:
76 | device = torch.device("cpu")
77 | if torch.cuda.is_available():
78 | device = torch.device("cuda:0")
79 | if torch.backends.mps.is_available():
80 | device = torch.device("mps")
81 |
82 | self._device = device
83 |
84 | self.codec = model.to(device)
85 | self.sample_rate = model.sample_rate
86 | self.channels = model.channels
87 |
88 | @property
89 | def device(self):
90 | return self._device
91 |
92 | def encode(self, wav: torch.Tensor) -> torch.Tensor:
93 | return self.codec.encode(wav.to(self.device))
94 |
95 | def decode(self, frames: torch.Tensor) -> torch.Tensor:
96 | return self.codec.decode(frames)
97 |
98 |
99 | def tokenize_audio(tokenizer: AudioTokenizer, audio):
100 | # Load and pre-process the audio waveform
101 | if isinstance(audio, str):
102 | wav, sr = torchaudio.load(audio)
103 | else:
104 | wav, sr = audio
105 | wav = convert_audio(wav, sr, tokenizer.sample_rate, tokenizer.channels)
106 | wav = wav.unsqueeze(0)
107 |
108 | # Extract discrete codes from EnCodec
109 | with torch.no_grad():
110 | encoded_frames = tokenizer.encode(wav)
111 | return encoded_frames
112 |
113 |
114 | if __name__ == "__main__":
115 | model = EncodecModel.encodec_model_24khz()
116 | model.set_target_bandwidth(6.0)
117 |
118 | samples = torch.from_numpy(np.random.random([4, 1, 1600])).type(
119 | torch.float32
120 | )
121 | codes_raw = model.encode(samples)
122 |
123 | remove_encodec_weight_norm(model)
124 | codes_norm = model.encode(samples)
125 |
126 | assert torch.allclose(codes_raw[0][0], codes_norm[0][0])
127 |
--------------------------------------------------------------------------------
/descriptions.py:
--------------------------------------------------------------------------------
1 | top_md = """
2 | # VALL-E X
3 | VALL-E X can synthesize high-quality personalized speech with only a 3-second enrolled recording of
4 | an unseen speaker as an acoustic prompt, even in another language for a monolingual speaker.
5 | This implementation supports zero-shot, mono-lingual/cross-lingual text-to-speech functionality of three languages (English, Chinese, Japanese)
6 | See this [demo](https://plachtaa.github.io/) page for more details.
7 | """
8 |
9 | infer_from_audio_md = """
10 | Upload a speech of 3~10 seconds as the audio prompt and type in the text you'd like to synthesize.
11 | The model will synthesize speech of given text with the same voice of your audio prompt.
12 | The model also tends to preserve the emotion & acoustic environment of your given speech.
13 | For faster inference, please use **"Make prompt"** to get a `.npz` file as the encoded audio prompt, and use it by **"Infer from prompt"**
14 | """
15 |
16 | make_prompt_md = """
17 | Upload a speech of 3~10 seconds as the audio prompt.
18 | Get a `.npz` file as the encoded audio prompt. Use it by **"Infer with prompt"**
19 | """
20 |
21 | infer_from_prompt_md = """
22 | Faster than **"Infer from audio"**.
23 | You need to **"Make prompt"** first, and upload the encoded prompt (a `.npz` file)
24 | """
25 |
26 | long_text_md = """
27 | Very long text is chunked into several sentences, and each sentence is synthesized separately.
28 | Please make a prompt or use a preset prompt to infer long text.
29 | """
30 |
31 | long_text_example = "Just a few years ago, there were no legions of deep learning scientists developing intelligent products and services at major companies and startups. When we entered the field, machine learning did not command headlines in daily newspapers. Our parents had no idea what machine learning was, let alone why we might prefer it to a career in medicine or law. Machine learning was a blue skies academic discipline whose industrial significance was limited to a narrow set of real-world applications, including speech recognition and computer vision. Moreover, many of these applications required so much domain knowledge that they were often regarded as entirely separate areas for which machine learning was one small component. At that time, neural networks—the predecessors of the deep learning methods that we focus on in this book—were generally regarded as outmoded."
--------------------------------------------------------------------------------
/examples.py:
--------------------------------------------------------------------------------
1 | infer_from_audio_examples = [
2 | ["This is how this machine has taken my voice.", 'English', 'no-accent', "prompts/en-2.wav", None, "Wow, look at that! That's no ordinary Teddy bear!"],
3 | ["我喜欢抽电子烟,尤其是锐刻五代。", '中文', 'no-accent', "prompts/zh-1.wav", None, "今天我很荣幸,"],
4 | ["私の声を真似するのはそんなに面白いですか?", '日本語', 'no-accent', "prompts/ja-2.ogg", None, "初めまして、朝武よしのです。"],
5 | ["你可以听得出来我有多困。", '中文', 'no-accent', "prompts/en-1.wav", None, ""],
6 | ["この文は、クロスリンガル合成の例です。", '日本語', 'no-accent', "prompts/zh-2.wav", None, ""],
7 | ["Actually, I can't speak English, but this machine helped me do it.", 'English', 'no-accent', "prompts/ja-1.wav", None, ""],
8 | ]
9 |
10 | make_npz_prompt_examples = [
11 | ["Gem-trader", "prompts/en-2.wav", None, "Wow, look at that! That's no ordinary Teddy bear!"],
12 | ["Ding Zhen", "prompts/zh-1.wav", None, "今天我很荣幸,"],
13 | ["Yoshino", "prompts/ja-2.ogg", None, "初めまして、朝武よしのです。"],
14 | ["Sleepy-woman", "prompts/en-1.wav", None, ""],
15 | ["Yae", "prompts/zh-2.wav", None, ""],
16 | ["Cafe", "prompts/ja-1.wav", None, ""],
17 | ]
18 |
19 | infer_from_prompt_examples = [
20 | ["A prompt contains voice, prosody and emotion information of a certain speaker.", "English", "no-accent", "vctk_1", None],
21 | ["This prompt is made with an audio of three seconds.", "English", "no-accent", "librispeech_1", None],
22 | ["This prompt is made with Chinese speech", "English", "no-accent", "seel", None],
23 | ]
24 |
25 |
--------------------------------------------------------------------------------
/images/vallex_framework.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Plachtaa/VALL-E-X/3faaf8ccadb154d63b38070caf518ce9309ea0f4/images/vallex_framework.jpg
--------------------------------------------------------------------------------
/macros.py:
--------------------------------------------------------------------------------
1 | NUM_LAYERS = 12
2 | NUM_HEAD = 16
3 | N_DIM = 1024
4 | PREFIX_MODE = 1
5 | NUM_QUANTIZERS = 8
6 | SAMPLE_RATE = 24000
7 |
8 | lang2token = {
9 | 'zh': "[ZH]",
10 | 'ja': "[JA]",
11 | "en": "[EN]",
12 | 'mix': "",
13 | }
14 |
15 | lang2code = {
16 | 'zh': 0,
17 | 'ja': 1,
18 | "en": 2,
19 | }
20 |
21 | token2lang = {
22 | '[ZH]': "zh",
23 | '[JA]': "ja",
24 | "[EN]": "en",
25 | "": "mix"
26 | }
27 |
28 | code2lang = {
29 | 0: 'zh',
30 | 1: 'ja',
31 | 2: "en",
32 | }
33 |
34 | langdropdown2token = {
35 | 'English': "[EN]",
36 | '中文': "[ZH]",
37 | '日本語': "[JA]",
38 | 'Mix': "",
39 | }
--------------------------------------------------------------------------------
/model-card.md:
--------------------------------------------------------------------------------
1 | # Model Card: VALL-E X
2 |
3 | **Author**: [Songting](https://github.com/Plachtaa).
4 |
5 | This is the official codebase for running open-sourced VALL-E X.
6 |
7 | The following is additional information about the models released here.
8 |
9 | ## Model Details
10 |
11 | VALL-E X is a series of two transformer models that turn text into audio.
12 |
13 | ### Phoneme to acoustic tokens
14 | - Input: IPAs converted from input text by a rule-based G2P tool.
15 | - Output: tokens from the first codebook of the [EnCodec Codec](https://github.com/facebookresearch/encodec) from facebook
16 |
17 | ### Coarse to fine tokens
18 | - Input: IPAs converted from input text by a rule-based G2P tool & the first codebook from EnCodec
19 | - Output: 8 codebooks from EnCodec
20 |
21 | ### Architecture
22 | | Model | Parameters | Attention | Output Vocab size |
23 | |:------------------------:|:----------:|------------|:-----------------:|
24 | | G2P tool | - | - | 69 |
25 | | Phoneme to coarse tokens | 150 M | Causal | 1x 1,024 |
26 | | Coarse to fine tokens | 150 M | Non-causal | 7x 1,024 |
27 |
28 | ### Release date
29 | August 2023
30 |
31 | ## Broader Implications
32 | We anticipate that this model's text to audio capabilities can be used to improve accessbility tools in a variety of languages.
33 | Straightforward improvements will allow models to run faster than realtime, rendering them useful for applications such as virtual assistants.
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import torch.nn as nn
4 | # from icefall.utils import AttributeDict, str2bool
5 |
6 | from .macros import (
7 | NUM_AUDIO_TOKENS,
8 | NUM_MEL_BINS,
9 | NUM_SPEAKER_CLASSES,
10 | NUM_TEXT_TOKENS,
11 | SPEAKER_EMBEDDING_DIM,
12 | )
13 | from .transformer import Transformer
14 | from .vallex import VALLE, VALLF
15 | from .visualizer import visualize
16 |
17 |
18 | def add_model_arguments(parser: argparse.ArgumentParser):
19 | parser.add_argument(
20 | "--model-name",
21 | type=str,
22 | default="VALL-E",
23 | help="VALL-E, VALL-F, Transformer.",
24 | )
25 | parser.add_argument(
26 | "--decoder-dim",
27 | type=int,
28 | default=1024,
29 | help="Embedding dimension in the decoder model.",
30 | )
31 | parser.add_argument(
32 | "--nhead",
33 | type=int,
34 | default=16,
35 | help="Number of attention heads in the Decoder layers.",
36 | )
37 | parser.add_argument(
38 | "--num-decoder-layers",
39 | type=int,
40 | default=12,
41 | help="Number of Decoder layers.",
42 | )
43 | parser.add_argument(
44 | "--scale-factor",
45 | type=float,
46 | default=1.0,
47 | help="Model scale factor which will be assigned different meanings in different models.",
48 | )
49 | parser.add_argument(
50 | "--norm-first",
51 | type=bool,
52 | default=True,
53 | help="Pre or Post Normalization.",
54 | )
55 | parser.add_argument(
56 | "--add-prenet",
57 | type=bool,
58 | default=False,
59 | help="Whether add PreNet after Inputs.",
60 | )
61 |
62 | # VALL-E & F
63 | parser.add_argument(
64 | "--prefix-mode",
65 | type=int,
66 | default=1,
67 | help="The mode for how to prefix VALL-E NAR Decoder, "
68 | "0: no prefix, 1: 0 to random, 2: random to random, 4: chunk of pre or post utterance.",
69 | )
70 | parser.add_argument(
71 | "--share-embedding",
72 | type=bool,
73 | default=True,
74 | help="Share the parameters of the output projection layer with the parameters of the acoustic embedding.",
75 | )
76 | parser.add_argument(
77 | "--prepend-bos",
78 | type=bool,
79 | default=False,
80 | help="Whether prepend to the acoustic tokens -> AR Decoder inputs.",
81 | )
82 | parser.add_argument(
83 | "--num-quantizers",
84 | type=int,
85 | default=8,
86 | help="Number of Audio/Semantic quantization layers.",
87 | )
88 |
89 | # Transformer
90 | parser.add_argument(
91 | "--scaling-xformers",
92 | type=bool,
93 | default=False,
94 | help="Apply Reworked Conformer scaling on Transformers.",
95 | )
96 |
97 |
98 | def get_model(params) -> nn.Module:
99 | if params.model_name.lower() in ["vall-f", "vallf"]:
100 | model = VALLF(
101 | params.decoder_dim,
102 | params.nhead,
103 | params.num_decoder_layers,
104 | norm_first=params.norm_first,
105 | add_prenet=params.add_prenet,
106 | prefix_mode=params.prefix_mode,
107 | share_embedding=params.share_embedding,
108 | nar_scale_factor=params.scale_factor,
109 | prepend_bos=params.prepend_bos,
110 | num_quantizers=params.num_quantizers,
111 | )
112 | elif params.model_name.lower() in ["vall-e", "valle"]:
113 | model = VALLE(
114 | params.decoder_dim,
115 | params.nhead,
116 | params.num_decoder_layers,
117 | norm_first=params.norm_first,
118 | add_prenet=params.add_prenet,
119 | prefix_mode=params.prefix_mode,
120 | share_embedding=params.share_embedding,
121 | nar_scale_factor=params.scale_factor,
122 | prepend_bos=params.prepend_bos,
123 | num_quantizers=params.num_quantizers,
124 | )
125 | else:
126 | assert params.model_name in ["Transformer"]
127 | model = Transformer(
128 | params.decoder_dim,
129 | params.nhead,
130 | params.num_decoder_layers,
131 | norm_first=params.norm_first,
132 | add_prenet=params.add_prenet,
133 | scaling_xformers=params.scaling_xformers,
134 | )
135 |
136 | return model
137 |
--------------------------------------------------------------------------------
/models/macros.py:
--------------------------------------------------------------------------------
1 | # Text
2 | NUM_TEXT_TOKENS = 2048
3 |
4 | # Audio
5 | NUM_AUDIO_TOKENS = 1024 # EnCodec RVQ bins
6 | NUM_MEL_BINS = 100 # BigVGAN bigvgan_24khz_100band
7 |
8 |
9 | # Speaker
10 | NUM_SPEAKER_CLASSES = 4096
11 | SPEAKER_EMBEDDING_DIM = 64
12 |
--------------------------------------------------------------------------------
/models/transformer.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 (authors: Feiteng Li)
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from functools import partial
16 | from typing import Any, Dict, List, Tuple, Union
17 |
18 | import torch
19 | import torch.nn as nn
20 | import torch.nn.functional as F
21 | # from icefall.utils import make_pad_mask
22 | # from torchmetrics.classification import BinaryAccuracy
23 |
24 | from models.vallex import Transpose
25 | from modules.embedding import SinePositionalEmbedding, TokenEmbedding
26 | from modules.scaling import BalancedDoubleSwish, ScaledLinear
27 | from modules.transformer import (
28 | BalancedBasicNorm,
29 | IdentityNorm,
30 | TransformerDecoderLayer,
31 | TransformerEncoder,
32 | TransformerEncoderLayer,
33 | )
34 |
35 | from .macros import NUM_MEL_BINS, NUM_TEXT_TOKENS
36 | from .visualizer import visualize
37 |
38 | IdentityNorm = IdentityNorm
39 |
40 |
41 | class Transformer(nn.Module):
42 | """It implements seq2seq Transformer TTS for debug(No StopPredictor and SpeakerEmbeding)
43 | Neural Speech Synthesis with Transformer Network
44 | https://arxiv.org/abs/1809.08895
45 | """
46 |
47 | def __init__(
48 | self,
49 | d_model: int,
50 | nhead: int,
51 | num_layers: int,
52 | norm_first: bool = True,
53 | add_prenet: bool = False,
54 | scaling_xformers: bool = False,
55 | ):
56 | """
57 | Args:
58 | d_model:
59 | The number of expected features in the input (required).
60 | nhead:
61 | The number of heads in the multiheadattention models (required).
62 | num_layers:
63 | The number of sub-decoder-layers in the decoder (required).
64 | """
65 | super().__init__()
66 | self.text_embedding = TokenEmbedding(d_model, NUM_TEXT_TOKENS) # W_x
67 |
68 | if add_prenet:
69 | self.encoder_prenet = nn.Sequential(
70 | Transpose(),
71 | nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"),
72 | nn.BatchNorm1d(d_model),
73 | nn.ReLU(),
74 | nn.Dropout(0.5),
75 | nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"),
76 | nn.BatchNorm1d(d_model),
77 | nn.ReLU(),
78 | nn.Dropout(0.5),
79 | nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"),
80 | nn.BatchNorm1d(d_model),
81 | nn.ReLU(),
82 | nn.Dropout(0.5),
83 | Transpose(),
84 | nn.Linear(d_model, d_model),
85 | )
86 |
87 | self.decoder_prenet = nn.Sequential(
88 | nn.Linear(NUM_MEL_BINS, 256),
89 | nn.ReLU(),
90 | nn.Dropout(0.5),
91 | nn.Linear(256, 256),
92 | nn.ReLU(),
93 | nn.Dropout(0.5),
94 | nn.Linear(256, d_model),
95 | )
96 |
97 | assert scaling_xformers is False # TODO: update this block
98 | else:
99 | self.encoder_prenet = nn.Identity()
100 | if scaling_xformers:
101 | self.decoder_prenet = ScaledLinear(NUM_MEL_BINS, d_model)
102 | else:
103 | self.decoder_prenet = nn.Linear(NUM_MEL_BINS, d_model)
104 |
105 | self.encoder_position = SinePositionalEmbedding(
106 | d_model,
107 | dropout=0.1,
108 | scale=False,
109 | )
110 | self.decoder_position = SinePositionalEmbedding(
111 | d_model, dropout=0.1, scale=False
112 | )
113 |
114 | if scaling_xformers:
115 | self.encoder = TransformerEncoder(
116 | TransformerEncoderLayer(
117 | d_model,
118 | nhead,
119 | dim_feedforward=d_model * 4,
120 | dropout=0.1,
121 | batch_first=True,
122 | norm_first=norm_first,
123 | linear1_self_attention_cls=ScaledLinear,
124 | linear2_self_attention_cls=partial(
125 | ScaledLinear, initial_scale=0.01
126 | ),
127 | linear1_feedforward_cls=ScaledLinear,
128 | linear2_feedforward_cls=partial(
129 | ScaledLinear, initial_scale=0.01
130 | ),
131 | activation=partial(
132 | BalancedDoubleSwish,
133 | channel_dim=-1,
134 | max_abs=10.0,
135 | min_prob=0.25,
136 | ),
137 | layer_norm_cls=IdentityNorm,
138 | ),
139 | num_layers=num_layers,
140 | norm=BalancedBasicNorm(d_model) if norm_first else None,
141 | )
142 |
143 | self.decoder = nn.TransformerDecoder(
144 | TransformerDecoderLayer(
145 | d_model,
146 | nhead,
147 | dim_feedforward=d_model * 4,
148 | dropout=0.1,
149 | batch_first=True,
150 | norm_first=norm_first,
151 | linear1_self_attention_cls=ScaledLinear,
152 | linear2_self_attention_cls=partial(
153 | ScaledLinear, initial_scale=0.01
154 | ),
155 | linear1_feedforward_cls=ScaledLinear,
156 | linear2_feedforward_cls=partial(
157 | ScaledLinear, initial_scale=0.01
158 | ),
159 | activation=partial(
160 | BalancedDoubleSwish,
161 | channel_dim=-1,
162 | max_abs=10.0,
163 | min_prob=0.25,
164 | ),
165 | layer_norm_cls=IdentityNorm,
166 | ),
167 | num_layers=num_layers,
168 | norm=BalancedBasicNorm(d_model) if norm_first else None,
169 | )
170 |
171 | self.predict_layer = ScaledLinear(d_model, NUM_MEL_BINS)
172 | self.stop_layer = nn.Linear(d_model, 1)
173 | else:
174 | self.encoder = nn.TransformerEncoder(
175 | nn.TransformerEncoderLayer(
176 | d_model,
177 | nhead,
178 | dim_feedforward=d_model * 4,
179 | activation=F.relu,
180 | dropout=0.1,
181 | batch_first=True,
182 | norm_first=norm_first,
183 | ),
184 | num_layers=num_layers,
185 | norm=nn.LayerNorm(d_model) if norm_first else None,
186 | )
187 |
188 | self.decoder = nn.TransformerDecoder(
189 | nn.TransformerDecoderLayer(
190 | d_model,
191 | nhead,
192 | dim_feedforward=d_model * 4,
193 | activation=F.relu,
194 | dropout=0.1,
195 | batch_first=True,
196 | norm_first=norm_first,
197 | ),
198 | num_layers=num_layers,
199 | norm=nn.LayerNorm(d_model) if norm_first else None,
200 | )
201 |
202 | self.predict_layer = nn.Linear(d_model, NUM_MEL_BINS)
203 | self.stop_layer = nn.Linear(d_model, 1)
204 |
205 | self.stop_accuracy_metric = BinaryAccuracy(
206 | threshold=0.5, multidim_average="global"
207 | )
208 |
209 | # self.apply(self._init_weights)
210 |
211 | # def _init_weights(self, module):
212 | # if isinstance(module, (nn.Linear)):
213 | # module.weight.data.normal_(mean=0.0, std=0.02)
214 | # if isinstance(module, nn.Linear) and module.bias is not None:
215 | # module.bias.data.zero_()
216 | # elif isinstance(module, nn.LayerNorm):
217 | # module.bias.data.zero_()
218 | # module.weight.data.fill_(1.0)
219 | # elif isinstance(module, nn.Embedding):
220 | # module.weight.data.normal_(mean=0.0, std=0.02)
221 |
222 | def forward(
223 | self,
224 | x: torch.Tensor,
225 | x_lens: torch.Tensor,
226 | y: torch.Tensor,
227 | y_lens: torch.Tensor,
228 | reduction: str = "sum",
229 | train_stage: int = 0,
230 | **kwargs,
231 | ) -> Tuple[torch.Tensor, Union[torch.Tensor, None]]:
232 | """
233 | Args:
234 | x:
235 | A 2-D tensor of shape (N, S).
236 | x_lens:
237 | A 1-D tensor of shape (N,). It contains the number of tokens in `x`
238 | before padding.
239 | y:
240 | A 3-D tensor of shape (N, T, 8).
241 | y_lens:
242 | A 1-D tensor of shape (N,). It contains the number of tokens in `x`
243 | before padding.
244 | train_stage:
245 | Not used in this model.
246 | Returns:
247 | Return the predicted audio code matrix, cross-entropy loss and Top-10 accuracy.
248 | """
249 | del train_stage
250 |
251 | assert x.ndim == 2, x.shape
252 | assert x_lens.ndim == 1, x_lens.shape
253 | assert y.ndim == 3, y.shape
254 | assert y_lens.ndim == 1, y_lens.shape
255 |
256 | assert torch.all(x_lens > 0)
257 |
258 | # NOTE: x has been padded in TextTokenCollater
259 | x_mask = make_pad_mask(x_lens).to(x.device)
260 |
261 | x = self.text_embedding(x)
262 | x = self.encoder_prenet(x)
263 | x = self.encoder_position(x)
264 | x = self.encoder(x, src_key_padding_mask=x_mask)
265 |
266 | total_loss, metrics = 0.0, {}
267 |
268 | y_mask = make_pad_mask(y_lens).to(y.device)
269 | y_mask_float = y_mask.type(torch.float32)
270 | data_mask = 1.0 - y_mask_float.unsqueeze(-1)
271 |
272 | # Training
273 | # AR Decoder
274 | def pad_y(y):
275 | y = F.pad(y, (0, 0, 1, 0, 0, 0), value=0).detach()
276 | # inputs, targets
277 | return y[:, :-1], y[:, 1:]
278 |
279 | y, targets = pad_y(y * data_mask) # mask padding as zeros
280 |
281 | y_emb = self.decoder_prenet(y)
282 | y_pos = self.decoder_position(y_emb)
283 |
284 | y_len = y_lens.max()
285 | tgt_mask = torch.triu(
286 | torch.ones(y_len, y_len, device=y.device, dtype=torch.bool),
287 | diagonal=1,
288 | )
289 | y_dec = self.decoder(
290 | y_pos,
291 | x,
292 | tgt_mask=tgt_mask,
293 | memory_key_padding_mask=x_mask,
294 | )
295 |
296 | predict = self.predict_layer(y_dec)
297 | # loss
298 | total_loss = F.mse_loss(predict, targets, reduction=reduction)
299 |
300 | logits = self.stop_layer(y_dec).squeeze(-1)
301 | stop_loss = F.binary_cross_entropy_with_logits(
302 | logits,
303 | y_mask_float.detach(),
304 | weight=1.0 + y_mask_float.detach() * 4.0,
305 | reduction=reduction,
306 | )
307 | metrics["stop_loss"] = stop_loss.detach()
308 |
309 | stop_accuracy = self.stop_accuracy_metric(
310 | (torch.sigmoid(logits) >= 0.5).type(torch.int64),
311 | y_mask.type(torch.int64),
312 | )
313 | # icefall MetricsTracker.norm_items()
314 | metrics["stop_accuracy"] = stop_accuracy.item() * y_lens.sum().type(
315 | torch.float32
316 | )
317 |
318 | return ((x, predict), total_loss + 100.0 * stop_loss, metrics)
319 |
320 | def inference(
321 | self,
322 | x: torch.Tensor,
323 | x_lens: torch.Tensor,
324 | y: Any = None,
325 | **kwargs,
326 | ) -> torch.Tensor:
327 | """
328 | Args:
329 | x:
330 | A 2-D tensor of shape (1, S).
331 | x_lens:
332 | A 1-D tensor of shape (1,). It contains the number of tokens in `x`
333 | before padding.
334 | Returns:
335 | Return the predicted audio code matrix and cross-entropy loss.
336 | """
337 | assert x.ndim == 2, x.shape
338 | assert x_lens.ndim == 1, x_lens.shape
339 |
340 | assert torch.all(x_lens > 0)
341 |
342 | x_mask = make_pad_mask(x_lens).to(x.device)
343 |
344 | x = self.text_embedding(x)
345 | x = self.encoder_prenet(x)
346 | x = self.encoder_position(x)
347 | x = self.encoder(x, src_key_padding_mask=x_mask)
348 |
349 | x_mask = make_pad_mask(x_lens).to(x.device)
350 |
351 | # AR Decoder
352 | # TODO: Managing decoder steps avoid repetitive computation
353 | y = torch.zeros(
354 | [x.shape[0], 1, NUM_MEL_BINS], dtype=torch.float32, device=x.device
355 | )
356 | while True:
357 | y_emb = self.decoder_prenet(y)
358 | y_pos = self.decoder_position(y_emb)
359 |
360 | tgt_mask = torch.triu(
361 | torch.ones(
362 | y.shape[1], y.shape[1], device=y.device, dtype=torch.bool
363 | ),
364 | diagonal=1,
365 | )
366 |
367 | y_dec = self.decoder(
368 | y_pos,
369 | x,
370 | tgt_mask=tgt_mask,
371 | memory_mask=None,
372 | memory_key_padding_mask=x_mask,
373 | )
374 | predict = self.predict_layer(y_dec[:, -1:])
375 |
376 | logits = self.stop_layer(y_dec[:, -1:]) > 0 # sigmoid(0.0) = 0.5
377 | if y.shape[1] > x_lens.max() * 10 or all(logits.cpu().numpy()):
378 | print(
379 | f"TransformerTTS EOS [Text {x_lens[0]} -> Audio {y.shape[1]}]"
380 | )
381 | break
382 |
383 | y = torch.concat([y, predict], dim=1)
384 |
385 | return y[:, 1:]
386 |
387 | def visualize(
388 | self,
389 | predicts: Tuple[torch.Tensor],
390 | batch: Dict[str, Union[List, torch.Tensor]],
391 | output_dir: str,
392 | limit: int = 4,
393 | ) -> None:
394 | visualize(predicts, batch, output_dir, limit=limit)
395 |
--------------------------------------------------------------------------------
/models/visualizer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright 2023 (authors: Feiteng Li)
3 | #
4 | # See ../../../../LICENSE for clarification regarding multiple authors
5 | #
6 | # Licensed under the Apache License, Version 2.0 (the "License");
7 | # you may not use this file except in compliance with the License.
8 | # You may obtain a copy of the License at
9 | #
10 | # http://www.apache.org/licenses/LICENSE-2.0
11 | #
12 | # Unless required by applicable law or agreed to in writing, software
13 | # distributed under the License is distributed on an "AS IS" BASIS,
14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | # See the License for the specific language governing permissions and
16 | # limitations under the License.
17 |
18 |
19 | from typing import Dict, List, Tuple, Union
20 |
21 | import matplotlib.pyplot as plt
22 | import numpy as np
23 | import torch
24 |
25 |
26 | def visualize(
27 | predicts: Tuple[torch.Tensor],
28 | batch: Dict[str, Union[List, torch.Tensor]],
29 | output_dir: str,
30 | limit: int = 4,
31 | ) -> None:
32 | text_tokens = batch["text_tokens"].to("cpu").detach().numpy()
33 | text_tokens_lens = batch["text_tokens_lens"].to("cpu").detach().numpy()
34 | audio_features = batch["audio_features"].to("cpu").detach().numpy()
35 | audio_features_lens = (
36 | batch["audio_features_lens"].to("cpu").detach().numpy()
37 | )
38 | assert text_tokens.ndim == 2
39 |
40 | utt_ids, texts = batch["utt_id"], batch["text"]
41 |
42 | encoder_outputs = predicts[0].to("cpu").type(torch.float32).detach().numpy()
43 | decoder_outputs = predicts[1]
44 | if isinstance(decoder_outputs, list):
45 | decoder_outputs = decoder_outputs[-1]
46 | decoder_outputs = (
47 | decoder_outputs.to("cpu").type(torch.float32).detach().numpy()
48 | )
49 |
50 | vmin, vmax = 0, 1024 # Encodec
51 | if decoder_outputs.dtype == np.float32:
52 | vmin, vmax = -6, 0 # Fbank
53 |
54 | num_figures = 3
55 | for b, (utt_id, text) in enumerate(zip(utt_ids[:limit], texts[:limit])):
56 | _ = plt.figure(figsize=(14, 8 * num_figures))
57 |
58 | S = text_tokens_lens[b]
59 | T = audio_features_lens[b]
60 |
61 | # encoder
62 | plt.subplot(num_figures, 1, 1)
63 | plt.title(f"Text: {text}")
64 | plt.imshow(
65 | X=np.transpose(encoder_outputs[b]),
66 | cmap=plt.get_cmap("jet"),
67 | aspect="auto",
68 | interpolation="nearest",
69 | )
70 | plt.gca().invert_yaxis()
71 | plt.axvline(x=S - 0.4, linewidth=2, color="r")
72 | plt.xlabel("Encoder Output")
73 | plt.colorbar()
74 |
75 | # decoder
76 | plt.subplot(num_figures, 1, 2)
77 | plt.imshow(
78 | X=np.transpose(decoder_outputs[b]),
79 | cmap=plt.get_cmap("jet"),
80 | aspect="auto",
81 | interpolation="nearest",
82 | vmin=vmin,
83 | vmax=vmax,
84 | )
85 | plt.gca().invert_yaxis()
86 | plt.axvline(x=T - 0.4, linewidth=2, color="r")
87 | plt.xlabel("Decoder Output")
88 | plt.colorbar()
89 |
90 | # target
91 | plt.subplot(num_figures, 1, 3)
92 | plt.imshow(
93 | X=np.transpose(audio_features[b]),
94 | cmap=plt.get_cmap("jet"),
95 | aspect="auto",
96 | interpolation="nearest",
97 | vmin=vmin,
98 | vmax=vmax,
99 | )
100 | plt.gca().invert_yaxis()
101 | plt.axvline(x=T - 0.4, linewidth=2, color="r")
102 | plt.xlabel("Decoder Target")
103 | plt.colorbar()
104 |
105 | plt.savefig(f"{output_dir}/{utt_id}.png")
106 | plt.close()
107 |
--------------------------------------------------------------------------------
/modules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Plachtaa/VALL-E-X/3faaf8ccadb154d63b38070caf518ce9309ea0f4/modules/__init__.py
--------------------------------------------------------------------------------
/modules/embedding.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 (authors: Feiteng Li)
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import math
16 |
17 | import torch
18 | import torch.nn as nn
19 |
20 |
21 | class TokenEmbedding(nn.Module):
22 | def __init__(
23 | self,
24 | dim_model: int,
25 | vocab_size: int,
26 | dropout: float = 0.0,
27 | ):
28 | super().__init__()
29 |
30 | self.vocab_size = vocab_size
31 | self.dim_model = dim_model
32 |
33 | self.dropout = torch.nn.Dropout(p=dropout)
34 | self.word_embeddings = nn.Embedding(self.vocab_size, self.dim_model)
35 |
36 | @property
37 | def weight(self) -> torch.Tensor:
38 | return self.word_embeddings.weight
39 |
40 | def embedding(self, index: int) -> torch.Tensor:
41 | return self.word_embeddings.weight[index : index + 1]
42 |
43 | def forward(self, x: torch.Tensor):
44 | X = self.word_embeddings(x)
45 | X = self.dropout(X)
46 |
47 | return X
48 |
49 |
50 | class SinePositionalEmbedding(nn.Module):
51 | def __init__(
52 | self,
53 | dim_model: int,
54 | dropout: float = 0.0,
55 | scale: bool = False,
56 | alpha: bool = False,
57 | ):
58 | super().__init__()
59 | self.dim_model = dim_model
60 | self.x_scale = math.sqrt(dim_model) if scale else 1.0
61 | self.alpha = nn.Parameter(torch.ones(1), requires_grad=alpha)
62 | self.dropout = torch.nn.Dropout(p=dropout)
63 |
64 | self.reverse = False
65 | self.pe = None
66 | self.extend_pe(torch.tensor(0.0).expand(1, 4000))
67 |
68 | def extend_pe(self, x):
69 | """Reset the positional encodings."""
70 | if self.pe is not None:
71 | if self.pe.size(1) >= x.size(1):
72 | if self.pe.dtype != x.dtype or self.pe.device != x.device:
73 | self.pe = self.pe.to(dtype=x.dtype, device=x.device)
74 | return
75 | pe = torch.zeros(x.size(1), self.dim_model)
76 | if self.reverse:
77 | position = torch.arange(
78 | x.size(1) - 1, -1, -1.0, dtype=torch.float32
79 | ).unsqueeze(1)
80 | else:
81 | position = torch.arange(
82 | 0, x.size(1), dtype=torch.float32
83 | ).unsqueeze(1)
84 | div_term = torch.exp(
85 | torch.arange(0, self.dim_model, 2, dtype=torch.float32)
86 | * -(math.log(10000.0) / self.dim_model)
87 | )
88 | pe[:, 0::2] = torch.sin(position * div_term)
89 | pe[:, 1::2] = torch.cos(position * div_term)
90 | pe = pe.unsqueeze(0)
91 | self.pe = pe.to(device=x.device, dtype=x.dtype).detach()
92 |
93 | def forward(self, x: torch.Tensor) -> torch.Tensor:
94 | self.extend_pe(x)
95 | output = x.unsqueeze(-1) if x.ndim == 2 else x
96 | output = output * self.x_scale + self.alpha * self.pe[:, : x.size(1)]
97 | return self.dropout(output)
98 |
--------------------------------------------------------------------------------
/modules/scheduler.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright 2023 (authors: Feiteng Li)
3 | #
4 | # See ../../../../LICENSE for clarification regarding multiple authors
5 | #
6 | # Licensed under the Apache License, Version 2.0 (the "License");
7 | # you may not use this file except in compliance with the License.
8 | # You may obtain a copy of the License at
9 | #
10 | # http://www.apache.org/licenses/LICENSE-2.0
11 | #
12 | # Unless required by applicable law or agreed to in writing, software
13 | # distributed under the License is distributed on an "AS IS" BASIS,
14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | # See the License for the specific language governing permissions and
16 | # limitations under the License.
17 |
18 |
19 | import torch
20 |
21 | from modules.optim import Eden
22 |
23 |
24 | def calc_lr(step, dim_embed, warmup_steps):
25 | return dim_embed ** (-0.5) * min(
26 | step ** (-0.5), step * warmup_steps ** (-1.5)
27 | )
28 |
29 |
30 | class NoamScheduler(torch.optim.lr_scheduler._LRScheduler):
31 | def __init__(
32 | self,
33 | base_lr: float,
34 | optimizer: torch.optim.Optimizer,
35 | dim_embed: int,
36 | warmup_steps: int,
37 | last_epoch: int = -1,
38 | verbose: bool = False,
39 | ) -> None:
40 |
41 | self.dim_embed = dim_embed
42 | self.base_lr = base_lr
43 | self.warmup_steps = warmup_steps
44 | self.num_param_groups = len(optimizer.param_groups)
45 |
46 | super().__init__(optimizer, last_epoch, verbose)
47 |
48 | def get_lr(self) -> float:
49 | lr = self.base_lr * calc_lr(
50 | self._step_count, self.dim_embed, self.warmup_steps
51 | )
52 | return [lr] * self.num_param_groups
53 |
54 | def set_step(self, step: int):
55 | self._step_count = step
56 |
57 |
58 | def get_scheduler(params, optimizer):
59 | if params.scheduler_name.lower() == "eden":
60 | scheduler = Eden(optimizer, 5000, 4, warmup_batches=params.warmup_steps)
61 | elif params.scheduler_name.lower() == "noam":
62 | scheduler = NoamScheduler(
63 | params.base_lr,
64 | optimizer,
65 | params.decoder_dim,
66 | warmup_steps=params.warmup_steps,
67 | )
68 | # scheduler.set_step(params.start_batch or params.batch_idx_train)
69 | elif params.scheduler_name.lower() == "cosine":
70 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
71 | params.warmup_steps,
72 | optimizer,
73 | eta_min=params.base_lr,
74 | )
75 | else:
76 | raise NotImplementedError(f"{params.scheduler_name}")
77 |
78 | return scheduler
79 |
--------------------------------------------------------------------------------
/nltk_data/tokenizers/punkt/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Plachtaa/VALL-E-X/3faaf8ccadb154d63b38070caf518ce9309ea0f4/nltk_data/tokenizers/punkt/.DS_Store
--------------------------------------------------------------------------------
/nltk_data/tokenizers/punkt/PY3/README:
--------------------------------------------------------------------------------
1 | Pretrained Punkt Models -- Jan Strunk (New version trained after issues 313 and 514 had been corrected)
2 |
3 | Most models were prepared using the test corpora from Kiss and Strunk (2006). Additional models have
4 | been contributed by various people using NLTK for sentence boundary detection.
5 |
6 | For information about how to use these models, please confer the tokenization HOWTO:
7 | http://nltk.googlecode.com/svn/trunk/doc/howto/tokenize.html
8 | and chapter 3.8 of the NLTK book:
9 | http://nltk.googlecode.com/svn/trunk/doc/book/ch03.html#sec-segmentation
10 |
11 | There are pretrained tokenizers for the following languages:
12 |
13 | File Language Source Contents Size of training corpus(in tokens) Model contributed by
14 | =======================================================================================================================================================================
15 | czech.pickle Czech Multilingual Corpus 1 (ECI) Lidove Noviny ~345,000 Jan Strunk / Tibor Kiss
16 | Literarni Noviny
17 | -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
18 | danish.pickle Danish Avisdata CD-Rom Ver. 1.1. 1995 Berlingske Tidende ~550,000 Jan Strunk / Tibor Kiss
19 | (Berlingske Avisdata, Copenhagen) Weekend Avisen
20 | -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
21 | dutch.pickle Dutch Multilingual Corpus 1 (ECI) De Limburger ~340,000 Jan Strunk / Tibor Kiss
22 | -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
23 | english.pickle English Penn Treebank (LDC) Wall Street Journal ~469,000 Jan Strunk / Tibor Kiss
24 | (American)
25 | -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
26 | estonian.pickle Estonian University of Tartu, Estonia Eesti Ekspress ~359,000 Jan Strunk / Tibor Kiss
27 | -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
28 | finnish.pickle Finnish Finnish Parole Corpus, Finnish Books and major national ~364,000 Jan Strunk / Tibor Kiss
29 | Text Bank (Suomen Kielen newspapers
30 | Tekstipankki)
31 | Finnish Center for IT Science
32 | (CSC)
33 | -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
34 | french.pickle French Multilingual Corpus 1 (ECI) Le Monde ~370,000 Jan Strunk / Tibor Kiss
35 | (European)
36 | -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
37 | german.pickle German Neue Zürcher Zeitung AG Neue Zürcher Zeitung ~847,000 Jan Strunk / Tibor Kiss
38 | (Switzerland) CD-ROM
39 | (Uses "ss"
40 | instead of "ß")
41 | -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
42 | greek.pickle Greek Efstathios Stamatatos To Vima (TO BHMA) ~227,000 Jan Strunk / Tibor Kiss
43 | -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
44 | italian.pickle Italian Multilingual Corpus 1 (ECI) La Stampa, Il Mattino ~312,000 Jan Strunk / Tibor Kiss
45 | -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
46 | norwegian.pickle Norwegian Centre for Humanities Bergens Tidende ~479,000 Jan Strunk / Tibor Kiss
47 | (Bokmål and Information Technologies,
48 | Nynorsk) Bergen
49 | -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
50 | polish.pickle Polish Polish National Corpus Literature, newspapers, etc. ~1,000,000 Krzysztof Langner
51 | (http://www.nkjp.pl/)
52 | -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
53 | portuguese.pickle Portuguese CETENFolha Corpus Folha de São Paulo ~321,000 Jan Strunk / Tibor Kiss
54 | (Brazilian) (Linguateca)
55 | -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
56 | slovene.pickle Slovene TRACTOR Delo ~354,000 Jan Strunk / Tibor Kiss
57 | Slovene Academy for Arts
58 | and Sciences
59 | -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
60 | spanish.pickle Spanish Multilingual Corpus 1 (ECI) Sur ~353,000 Jan Strunk / Tibor Kiss
61 | (European)
62 | -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
63 | swedish.pickle Swedish Multilingual Corpus 1 (ECI) Dagens Nyheter ~339,000 Jan Strunk / Tibor Kiss
64 | (and some other texts)
65 | -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
66 | turkish.pickle Turkish METU Turkish Corpus Milliyet ~333,000 Jan Strunk / Tibor Kiss
67 | (Türkçe Derlem Projesi)
68 | University of Ankara
69 | -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
70 |
71 | The corpora contained about 400,000 tokens on average and mostly consisted of newspaper text converted to
72 | Unicode using the codecs module.
73 |
74 | Kiss, Tibor and Strunk, Jan (2006): Unsupervised Multilingual Sentence Boundary Detection.
75 | Computational Linguistics 32: 485-525.
76 |
77 | ---- Training Code ----
78 |
79 | # import punkt
80 | import nltk.tokenize.punkt
81 |
82 | # Make a new Tokenizer
83 | tokenizer = nltk.tokenize.punkt.PunktSentenceTokenizer()
84 |
85 | # Read in training corpus (one example: Slovene)
86 | import codecs
87 | text = codecs.open("slovene.plain","Ur","iso-8859-2").read()
88 |
89 | # Train tokenizer
90 | tokenizer.train(text)
91 |
92 | # Dump pickled tokenizer
93 | import pickle
94 | out = open("slovene.pickle","wb")
95 | pickle.dump(tokenizer, out)
96 | out.close()
97 |
98 | ---------
99 |
--------------------------------------------------------------------------------
/nltk_data/tokenizers/punkt/PY3/czech.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Plachtaa/VALL-E-X/3faaf8ccadb154d63b38070caf518ce9309ea0f4/nltk_data/tokenizers/punkt/PY3/czech.pickle
--------------------------------------------------------------------------------
/nltk_data/tokenizers/punkt/PY3/danish.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Plachtaa/VALL-E-X/3faaf8ccadb154d63b38070caf518ce9309ea0f4/nltk_data/tokenizers/punkt/PY3/danish.pickle
--------------------------------------------------------------------------------
/nltk_data/tokenizers/punkt/PY3/dutch.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Plachtaa/VALL-E-X/3faaf8ccadb154d63b38070caf518ce9309ea0f4/nltk_data/tokenizers/punkt/PY3/dutch.pickle
--------------------------------------------------------------------------------
/nltk_data/tokenizers/punkt/PY3/english.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Plachtaa/VALL-E-X/3faaf8ccadb154d63b38070caf518ce9309ea0f4/nltk_data/tokenizers/punkt/PY3/english.pickle
--------------------------------------------------------------------------------
/nltk_data/tokenizers/punkt/PY3/estonian.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Plachtaa/VALL-E-X/3faaf8ccadb154d63b38070caf518ce9309ea0f4/nltk_data/tokenizers/punkt/PY3/estonian.pickle
--------------------------------------------------------------------------------
/nltk_data/tokenizers/punkt/PY3/finnish.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Plachtaa/VALL-E-X/3faaf8ccadb154d63b38070caf518ce9309ea0f4/nltk_data/tokenizers/punkt/PY3/finnish.pickle
--------------------------------------------------------------------------------
/nltk_data/tokenizers/punkt/PY3/french.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Plachtaa/VALL-E-X/3faaf8ccadb154d63b38070caf518ce9309ea0f4/nltk_data/tokenizers/punkt/PY3/french.pickle
--------------------------------------------------------------------------------
/nltk_data/tokenizers/punkt/PY3/german.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Plachtaa/VALL-E-X/3faaf8ccadb154d63b38070caf518ce9309ea0f4/nltk_data/tokenizers/punkt/PY3/german.pickle
--------------------------------------------------------------------------------
/nltk_data/tokenizers/punkt/PY3/greek.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Plachtaa/VALL-E-X/3faaf8ccadb154d63b38070caf518ce9309ea0f4/nltk_data/tokenizers/punkt/PY3/greek.pickle
--------------------------------------------------------------------------------
/nltk_data/tokenizers/punkt/PY3/italian.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Plachtaa/VALL-E-X/3faaf8ccadb154d63b38070caf518ce9309ea0f4/nltk_data/tokenizers/punkt/PY3/italian.pickle
--------------------------------------------------------------------------------
/nltk_data/tokenizers/punkt/PY3/malayalam.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Plachtaa/VALL-E-X/3faaf8ccadb154d63b38070caf518ce9309ea0f4/nltk_data/tokenizers/punkt/PY3/malayalam.pickle
--------------------------------------------------------------------------------
/nltk_data/tokenizers/punkt/PY3/norwegian.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Plachtaa/VALL-E-X/3faaf8ccadb154d63b38070caf518ce9309ea0f4/nltk_data/tokenizers/punkt/PY3/norwegian.pickle
--------------------------------------------------------------------------------
/nltk_data/tokenizers/punkt/PY3/polish.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Plachtaa/VALL-E-X/3faaf8ccadb154d63b38070caf518ce9309ea0f4/nltk_data/tokenizers/punkt/PY3/polish.pickle
--------------------------------------------------------------------------------
/nltk_data/tokenizers/punkt/PY3/portuguese.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Plachtaa/VALL-E-X/3faaf8ccadb154d63b38070caf518ce9309ea0f4/nltk_data/tokenizers/punkt/PY3/portuguese.pickle
--------------------------------------------------------------------------------
/nltk_data/tokenizers/punkt/PY3/russian.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Plachtaa/VALL-E-X/3faaf8ccadb154d63b38070caf518ce9309ea0f4/nltk_data/tokenizers/punkt/PY3/russian.pickle
--------------------------------------------------------------------------------
/nltk_data/tokenizers/punkt/PY3/slovene.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Plachtaa/VALL-E-X/3faaf8ccadb154d63b38070caf518ce9309ea0f4/nltk_data/tokenizers/punkt/PY3/slovene.pickle
--------------------------------------------------------------------------------
/nltk_data/tokenizers/punkt/PY3/spanish.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Plachtaa/VALL-E-X/3faaf8ccadb154d63b38070caf518ce9309ea0f4/nltk_data/tokenizers/punkt/PY3/spanish.pickle
--------------------------------------------------------------------------------
/nltk_data/tokenizers/punkt/PY3/swedish.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Plachtaa/VALL-E-X/3faaf8ccadb154d63b38070caf518ce9309ea0f4/nltk_data/tokenizers/punkt/PY3/swedish.pickle
--------------------------------------------------------------------------------
/nltk_data/tokenizers/punkt/PY3/turkish.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Plachtaa/VALL-E-X/3faaf8ccadb154d63b38070caf518ce9309ea0f4/nltk_data/tokenizers/punkt/PY3/turkish.pickle
--------------------------------------------------------------------------------
/nltk_data/tokenizers/punkt/README:
--------------------------------------------------------------------------------
1 | Pretrained Punkt Models -- Jan Strunk (New version trained after issues 313 and 514 had been corrected)
2 |
3 | Most models were prepared using the test corpora from Kiss and Strunk (2006). Additional models have
4 | been contributed by various people using NLTK for sentence boundary detection.
5 |
6 | For information about how to use these models, please confer the tokenization HOWTO:
7 | http://nltk.googlecode.com/svn/trunk/doc/howto/tokenize.html
8 | and chapter 3.8 of the NLTK book:
9 | http://nltk.googlecode.com/svn/trunk/doc/book/ch03.html#sec-segmentation
10 |
11 | There are pretrained tokenizers for the following languages:
12 |
13 | File Language Source Contents Size of training corpus(in tokens) Model contributed by
14 | =======================================================================================================================================================================
15 | czech.pickle Czech Multilingual Corpus 1 (ECI) Lidove Noviny ~345,000 Jan Strunk / Tibor Kiss
16 | Literarni Noviny
17 | -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
18 | danish.pickle Danish Avisdata CD-Rom Ver. 1.1. 1995 Berlingske Tidende ~550,000 Jan Strunk / Tibor Kiss
19 | (Berlingske Avisdata, Copenhagen) Weekend Avisen
20 | -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
21 | dutch.pickle Dutch Multilingual Corpus 1 (ECI) De Limburger ~340,000 Jan Strunk / Tibor Kiss
22 | -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
23 | english.pickle English Penn Treebank (LDC) Wall Street Journal ~469,000 Jan Strunk / Tibor Kiss
24 | (American)
25 | -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
26 | estonian.pickle Estonian University of Tartu, Estonia Eesti Ekspress ~359,000 Jan Strunk / Tibor Kiss
27 | -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
28 | finnish.pickle Finnish Finnish Parole Corpus, Finnish Books and major national ~364,000 Jan Strunk / Tibor Kiss
29 | Text Bank (Suomen Kielen newspapers
30 | Tekstipankki)
31 | Finnish Center for IT Science
32 | (CSC)
33 | -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
34 | french.pickle French Multilingual Corpus 1 (ECI) Le Monde ~370,000 Jan Strunk / Tibor Kiss
35 | (European)
36 | -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
37 | german.pickle German Neue Zürcher Zeitung AG Neue Zürcher Zeitung ~847,000 Jan Strunk / Tibor Kiss
38 | (Switzerland) CD-ROM
39 | (Uses "ss"
40 | instead of "ß")
41 | -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
42 | greek.pickle Greek Efstathios Stamatatos To Vima (TO BHMA) ~227,000 Jan Strunk / Tibor Kiss
43 | -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
44 | italian.pickle Italian Multilingual Corpus 1 (ECI) La Stampa, Il Mattino ~312,000 Jan Strunk / Tibor Kiss
45 | -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
46 | norwegian.pickle Norwegian Centre for Humanities Bergens Tidende ~479,000 Jan Strunk / Tibor Kiss
47 | (Bokmål and Information Technologies,
48 | Nynorsk) Bergen
49 | -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
50 | polish.pickle Polish Polish National Corpus Literature, newspapers, etc. ~1,000,000 Krzysztof Langner
51 | (http://www.nkjp.pl/)
52 | -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
53 | portuguese.pickle Portuguese CETENFolha Corpus Folha de São Paulo ~321,000 Jan Strunk / Tibor Kiss
54 | (Brazilian) (Linguateca)
55 | -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
56 | slovene.pickle Slovene TRACTOR Delo ~354,000 Jan Strunk / Tibor Kiss
57 | Slovene Academy for Arts
58 | and Sciences
59 | -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
60 | spanish.pickle Spanish Multilingual Corpus 1 (ECI) Sur ~353,000 Jan Strunk / Tibor Kiss
61 | (European)
62 | -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
63 | swedish.pickle Swedish Multilingual Corpus 1 (ECI) Dagens Nyheter ~339,000 Jan Strunk / Tibor Kiss
64 | (and some other texts)
65 | -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
66 | turkish.pickle Turkish METU Turkish Corpus Milliyet ~333,000 Jan Strunk / Tibor Kiss
67 | (Türkçe Derlem Projesi)
68 | University of Ankara
69 | -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
70 |
71 | The corpora contained about 400,000 tokens on average and mostly consisted of newspaper text converted to
72 | Unicode using the codecs module.
73 |
74 | Kiss, Tibor and Strunk, Jan (2006): Unsupervised Multilingual Sentence Boundary Detection.
75 | Computational Linguistics 32: 485-525.
76 |
77 | ---- Training Code ----
78 |
79 | # import punkt
80 | import nltk.tokenize.punkt
81 |
82 | # Make a new Tokenizer
83 | tokenizer = nltk.tokenize.punkt.PunktSentenceTokenizer()
84 |
85 | # Read in training corpus (one example: Slovene)
86 | import codecs
87 | text = codecs.open("slovene.plain","Ur","iso-8859-2").read()
88 |
89 | # Train tokenizer
90 | tokenizer.train(text)
91 |
92 | # Dump pickled tokenizer
93 | import pickle
94 | out = open("slovene.pickle","wb")
95 | pickle.dump(tokenizer, out)
96 | out.close()
97 |
98 | ---------
99 |
--------------------------------------------------------------------------------
/nltk_data/tokenizers/punkt/czech.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Plachtaa/VALL-E-X/3faaf8ccadb154d63b38070caf518ce9309ea0f4/nltk_data/tokenizers/punkt/czech.pickle
--------------------------------------------------------------------------------
/nltk_data/tokenizers/punkt/danish.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Plachtaa/VALL-E-X/3faaf8ccadb154d63b38070caf518ce9309ea0f4/nltk_data/tokenizers/punkt/danish.pickle
--------------------------------------------------------------------------------
/nltk_data/tokenizers/punkt/dutch.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Plachtaa/VALL-E-X/3faaf8ccadb154d63b38070caf518ce9309ea0f4/nltk_data/tokenizers/punkt/dutch.pickle
--------------------------------------------------------------------------------
/nltk_data/tokenizers/punkt/estonian.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Plachtaa/VALL-E-X/3faaf8ccadb154d63b38070caf518ce9309ea0f4/nltk_data/tokenizers/punkt/estonian.pickle
--------------------------------------------------------------------------------
/nltk_data/tokenizers/punkt/finnish.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Plachtaa/VALL-E-X/3faaf8ccadb154d63b38070caf518ce9309ea0f4/nltk_data/tokenizers/punkt/finnish.pickle
--------------------------------------------------------------------------------
/nltk_data/tokenizers/punkt/french.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Plachtaa/VALL-E-X/3faaf8ccadb154d63b38070caf518ce9309ea0f4/nltk_data/tokenizers/punkt/french.pickle
--------------------------------------------------------------------------------
/nltk_data/tokenizers/punkt/german.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Plachtaa/VALL-E-X/3faaf8ccadb154d63b38070caf518ce9309ea0f4/nltk_data/tokenizers/punkt/german.pickle
--------------------------------------------------------------------------------
/nltk_data/tokenizers/punkt/italian.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Plachtaa/VALL-E-X/3faaf8ccadb154d63b38070caf518ce9309ea0f4/nltk_data/tokenizers/punkt/italian.pickle
--------------------------------------------------------------------------------
/nltk_data/tokenizers/punkt/malayalam.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Plachtaa/VALL-E-X/3faaf8ccadb154d63b38070caf518ce9309ea0f4/nltk_data/tokenizers/punkt/malayalam.pickle
--------------------------------------------------------------------------------
/nltk_data/tokenizers/punkt/norwegian.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Plachtaa/VALL-E-X/3faaf8ccadb154d63b38070caf518ce9309ea0f4/nltk_data/tokenizers/punkt/norwegian.pickle
--------------------------------------------------------------------------------
/nltk_data/tokenizers/punkt/polish.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Plachtaa/VALL-E-X/3faaf8ccadb154d63b38070caf518ce9309ea0f4/nltk_data/tokenizers/punkt/polish.pickle
--------------------------------------------------------------------------------
/nltk_data/tokenizers/punkt/portuguese.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Plachtaa/VALL-E-X/3faaf8ccadb154d63b38070caf518ce9309ea0f4/nltk_data/tokenizers/punkt/portuguese.pickle
--------------------------------------------------------------------------------
/nltk_data/tokenizers/punkt/russian.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Plachtaa/VALL-E-X/3faaf8ccadb154d63b38070caf518ce9309ea0f4/nltk_data/tokenizers/punkt/russian.pickle
--------------------------------------------------------------------------------
/nltk_data/tokenizers/punkt/slovene.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Plachtaa/VALL-E-X/3faaf8ccadb154d63b38070caf518ce9309ea0f4/nltk_data/tokenizers/punkt/slovene.pickle
--------------------------------------------------------------------------------
/nltk_data/tokenizers/punkt/spanish.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Plachtaa/VALL-E-X/3faaf8ccadb154d63b38070caf518ce9309ea0f4/nltk_data/tokenizers/punkt/spanish.pickle
--------------------------------------------------------------------------------
/nltk_data/tokenizers/punkt/swedish.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Plachtaa/VALL-E-X/3faaf8ccadb154d63b38070caf518ce9309ea0f4/nltk_data/tokenizers/punkt/swedish.pickle
--------------------------------------------------------------------------------
/nltk_data/tokenizers/punkt/turkish.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Plachtaa/VALL-E-X/3faaf8ccadb154d63b38070caf518ce9309ea0f4/nltk_data/tokenizers/punkt/turkish.pickle
--------------------------------------------------------------------------------
/presets/acou_1.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Plachtaa/VALL-E-X/3faaf8ccadb154d63b38070caf518ce9309ea0f4/presets/acou_1.npz
--------------------------------------------------------------------------------
/presets/acou_2.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Plachtaa/VALL-E-X/3faaf8ccadb154d63b38070caf518ce9309ea0f4/presets/acou_2.npz
--------------------------------------------------------------------------------
/presets/acou_3.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Plachtaa/VALL-E-X/3faaf8ccadb154d63b38070caf518ce9309ea0f4/presets/acou_3.npz
--------------------------------------------------------------------------------
/presets/acou_4.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Plachtaa/VALL-E-X/3faaf8ccadb154d63b38070caf518ce9309ea0f4/presets/acou_4.npz
--------------------------------------------------------------------------------
/presets/alan.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Plachtaa/VALL-E-X/3faaf8ccadb154d63b38070caf518ce9309ea0f4/presets/alan.npz
--------------------------------------------------------------------------------
/presets/amused.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Plachtaa/VALL-E-X/3faaf8ccadb154d63b38070caf518ce9309ea0f4/presets/amused.npz
--------------------------------------------------------------------------------
/presets/anger.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Plachtaa/VALL-E-X/3faaf8ccadb154d63b38070caf518ce9309ea0f4/presets/anger.npz
--------------------------------------------------------------------------------
/presets/babara.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Plachtaa/VALL-E-X/3faaf8ccadb154d63b38070caf518ce9309ea0f4/presets/babara.npz
--------------------------------------------------------------------------------
/presets/bronya.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Plachtaa/VALL-E-X/3faaf8ccadb154d63b38070caf518ce9309ea0f4/presets/bronya.npz
--------------------------------------------------------------------------------
/presets/cafe.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Plachtaa/VALL-E-X/3faaf8ccadb154d63b38070caf518ce9309ea0f4/presets/cafe.npz
--------------------------------------------------------------------------------
/presets/dingzhen.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Plachtaa/VALL-E-X/3faaf8ccadb154d63b38070caf518ce9309ea0f4/presets/dingzhen.npz
--------------------------------------------------------------------------------
/presets/disgust.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Plachtaa/VALL-E-X/3faaf8ccadb154d63b38070caf518ce9309ea0f4/presets/disgust.npz
--------------------------------------------------------------------------------
/presets/emo_amused.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Plachtaa/VALL-E-X/3faaf8ccadb154d63b38070caf518ce9309ea0f4/presets/emo_amused.npz
--------------------------------------------------------------------------------
/presets/emo_anger.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Plachtaa/VALL-E-X/3faaf8ccadb154d63b38070caf518ce9309ea0f4/presets/emo_anger.npz
--------------------------------------------------------------------------------
/presets/emo_neutral.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Plachtaa/VALL-E-X/3faaf8ccadb154d63b38070caf518ce9309ea0f4/presets/emo_neutral.npz
--------------------------------------------------------------------------------
/presets/emo_sleepy.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Plachtaa/VALL-E-X/3faaf8ccadb154d63b38070caf518ce9309ea0f4/presets/emo_sleepy.npz
--------------------------------------------------------------------------------
/presets/emotion_sleepiness.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Plachtaa/VALL-E-X/3faaf8ccadb154d63b38070caf518ce9309ea0f4/presets/emotion_sleepiness.npz
--------------------------------------------------------------------------------
/presets/en2zh_tts_1.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Plachtaa/VALL-E-X/3faaf8ccadb154d63b38070caf518ce9309ea0f4/presets/en2zh_tts_1.npz
--------------------------------------------------------------------------------
/presets/en2zh_tts_2.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Plachtaa/VALL-E-X/3faaf8ccadb154d63b38070caf518ce9309ea0f4/presets/en2zh_tts_2.npz
--------------------------------------------------------------------------------
/presets/en2zh_tts_3.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Plachtaa/VALL-E-X/3faaf8ccadb154d63b38070caf518ce9309ea0f4/presets/en2zh_tts_3.npz
--------------------------------------------------------------------------------
/presets/en2zh_tts_4.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Plachtaa/VALL-E-X/3faaf8ccadb154d63b38070caf518ce9309ea0f4/presets/en2zh_tts_4.npz
--------------------------------------------------------------------------------
/presets/esta.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Plachtaa/VALL-E-X/3faaf8ccadb154d63b38070caf518ce9309ea0f4/presets/esta.npz
--------------------------------------------------------------------------------
/presets/fuxuan.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Plachtaa/VALL-E-X/3faaf8ccadb154d63b38070caf518ce9309ea0f4/presets/fuxuan.npz
--------------------------------------------------------------------------------
/presets/librispeech_1.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Plachtaa/VALL-E-X/3faaf8ccadb154d63b38070caf518ce9309ea0f4/presets/librispeech_1.npz
--------------------------------------------------------------------------------
/presets/librispeech_2.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Plachtaa/VALL-E-X/3faaf8ccadb154d63b38070caf518ce9309ea0f4/presets/librispeech_2.npz
--------------------------------------------------------------------------------
/presets/librispeech_3.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Plachtaa/VALL-E-X/3faaf8ccadb154d63b38070caf518ce9309ea0f4/presets/librispeech_3.npz
--------------------------------------------------------------------------------
/presets/librispeech_4.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Plachtaa/VALL-E-X/3faaf8ccadb154d63b38070caf518ce9309ea0f4/presets/librispeech_4.npz
--------------------------------------------------------------------------------
/presets/neutral.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Plachtaa/VALL-E-X/3faaf8ccadb154d63b38070caf518ce9309ea0f4/presets/neutral.npz
--------------------------------------------------------------------------------
/presets/paimon.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Plachtaa/VALL-E-X/3faaf8ccadb154d63b38070caf518ce9309ea0f4/presets/paimon.npz
--------------------------------------------------------------------------------
/presets/rosalia.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Plachtaa/VALL-E-X/3faaf8ccadb154d63b38070caf518ce9309ea0f4/presets/rosalia.npz
--------------------------------------------------------------------------------
/presets/seel.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Plachtaa/VALL-E-X/3faaf8ccadb154d63b38070caf518ce9309ea0f4/presets/seel.npz
--------------------------------------------------------------------------------
/presets/sleepiness.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Plachtaa/VALL-E-X/3faaf8ccadb154d63b38070caf518ce9309ea0f4/presets/sleepiness.npz
--------------------------------------------------------------------------------
/presets/vctk_1.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Plachtaa/VALL-E-X/3faaf8ccadb154d63b38070caf518ce9309ea0f4/presets/vctk_1.npz
--------------------------------------------------------------------------------
/presets/vctk_2.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Plachtaa/VALL-E-X/3faaf8ccadb154d63b38070caf518ce9309ea0f4/presets/vctk_2.npz
--------------------------------------------------------------------------------
/presets/vctk_3.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Plachtaa/VALL-E-X/3faaf8ccadb154d63b38070caf518ce9309ea0f4/presets/vctk_3.npz
--------------------------------------------------------------------------------
/presets/vctk_4.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Plachtaa/VALL-E-X/3faaf8ccadb154d63b38070caf518ce9309ea0f4/presets/vctk_4.npz
--------------------------------------------------------------------------------
/presets/yaesakura.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Plachtaa/VALL-E-X/3faaf8ccadb154d63b38070caf518ce9309ea0f4/presets/yaesakura.npz
--------------------------------------------------------------------------------
/presets/zh2en_tts_1.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Plachtaa/VALL-E-X/3faaf8ccadb154d63b38070caf518ce9309ea0f4/presets/zh2en_tts_1.npz
--------------------------------------------------------------------------------
/presets/zh2en_tts_2.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Plachtaa/VALL-E-X/3faaf8ccadb154d63b38070caf518ce9309ea0f4/presets/zh2en_tts_2.npz
--------------------------------------------------------------------------------
/presets/zh2en_tts_3.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Plachtaa/VALL-E-X/3faaf8ccadb154d63b38070caf518ce9309ea0f4/presets/zh2en_tts_3.npz
--------------------------------------------------------------------------------
/presets/zh2en_tts_4.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Plachtaa/VALL-E-X/3faaf8ccadb154d63b38070caf518ce9309ea0f4/presets/zh2en_tts_4.npz
--------------------------------------------------------------------------------
/prompts/en-1.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Plachtaa/VALL-E-X/3faaf8ccadb154d63b38070caf518ce9309ea0f4/prompts/en-1.wav
--------------------------------------------------------------------------------
/prompts/en-2.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Plachtaa/VALL-E-X/3faaf8ccadb154d63b38070caf518ce9309ea0f4/prompts/en-2.wav
--------------------------------------------------------------------------------
/prompts/ja-1.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Plachtaa/VALL-E-X/3faaf8ccadb154d63b38070caf518ce9309ea0f4/prompts/ja-1.wav
--------------------------------------------------------------------------------
/prompts/ja-2.ogg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Plachtaa/VALL-E-X/3faaf8ccadb154d63b38070caf518ce9309ea0f4/prompts/ja-2.ogg
--------------------------------------------------------------------------------
/prompts/ph.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Plachtaa/VALL-E-X/3faaf8ccadb154d63b38070caf518ce9309ea0f4/prompts/ph.txt
--------------------------------------------------------------------------------
/prompts/zh-1.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Plachtaa/VALL-E-X/3faaf8ccadb154d63b38070caf518ce9309ea0f4/prompts/zh-1.wav
--------------------------------------------------------------------------------
/prompts/zh-2.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Plachtaa/VALL-E-X/3faaf8ccadb154d63b38070caf518ce9309ea0f4/prompts/zh-2.wav
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | soundfile
2 | numpy
3 | torch
4 | torchvision
5 | torchaudio
6 | tokenizers
7 | encodec
8 | langid
9 | wget
10 | unidecode
11 | pyopenjtalk-prebuilt
12 | pypinyin
13 | inflect
14 | cn2an
15 | jieba
16 | eng_to_ipa
17 | openai-whisper
18 | matplotlib
19 | gradio==3.41.2
20 | nltk
21 | sudachipy
22 | sudachidict_core
23 | vocos
24 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | # from icefall.utils import make_pad_mask
4 |
5 | from .symbol_table import SymbolTable
6 |
7 | # make_pad_mask = make_pad_mask
8 | SymbolTable = SymbolTable
9 |
10 |
11 | class Transpose(nn.Identity):
12 | """(N, T, D) -> (N, D, T)"""
13 |
14 | def forward(self, input: torch.Tensor) -> torch.Tensor:
15 | return input.transpose(1, 2)
16 |
--------------------------------------------------------------------------------
/utils/download.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import requests
3 |
4 |
5 | def download_file_from_google_drive(id, destination):
6 | URL = "https://docs.google.com/uc?export=download&confirm=1"
7 |
8 | session = requests.Session()
9 |
10 | response = session.get(URL, params={"id": id}, stream=True)
11 | token = get_confirm_token(response)
12 |
13 | if token:
14 | params = {"id": id, "confirm": token}
15 | response = session.get(URL, params=params, stream=True)
16 |
17 | save_response_content(response, destination)
18 |
19 |
20 | def get_confirm_token(response):
21 | for key, value in response.cookies.items():
22 | if key.startswith("download_warning"):
23 | return value
24 |
25 | return None
26 |
27 |
28 | def save_response_content(response, destination):
29 | CHUNK_SIZE = 32768
30 |
31 | with open(destination, "wb", encoding='utf-8') as f:
32 | for chunk in response.iter_content(CHUNK_SIZE):
33 | if chunk: # filter out keep-alive new chunks
34 | f.write(chunk)
35 |
36 |
37 | def main():
38 | if len(sys.argv) >= 3:
39 | file_id = sys.argv[1]
40 | destination = sys.argv[2]
41 | else:
42 | file_id = "TAKE_ID_FROM_SHAREABLE_LINK"
43 | destination = "DESTINATION_FILE_ON_YOUR_DISK"
44 | print(f"dowload {file_id} to {destination}")
45 | download_file_from_google_drive(file_id, destination)
46 |
47 |
48 | if __name__ == "__main__":
49 | main()
--------------------------------------------------------------------------------
/utils/g2p/__init__.py:
--------------------------------------------------------------------------------
1 | """ from https://github.com/keithito/tacotron """
2 | import utils.g2p.cleaners
3 | from utils.g2p.symbols import symbols
4 | from tokenizers import Tokenizer
5 |
6 | # Mappings from symbol to numeric ID and vice versa:
7 | _symbol_to_id = {s: i for i, s in enumerate(symbols)}
8 | _id_to_symbol = {i: s for i, s in enumerate(symbols)}
9 |
10 |
11 | class PhonemeBpeTokenizer:
12 | def __init__(self, tokenizer_path = "./utils/g2p/bpe_1024.json"):
13 | self.tokenizer = Tokenizer.from_file(tokenizer_path)
14 |
15 | def tokenize(self, text):
16 | # 1. convert text to phoneme
17 | phonemes, langs = _clean_text(text, ['cje_cleaners'])
18 | # 2. replace blank space " " with "_"
19 | phonemes = phonemes.replace(" ", "_")
20 | # 3. tokenize phonemes
21 | phoneme_tokens = self.tokenizer.encode(phonemes).ids
22 | assert(len(phoneme_tokens) == len(langs))
23 | if not len(phoneme_tokens):
24 | raise ValueError("Empty text is given")
25 | return phoneme_tokens, langs
26 |
27 | def text_to_sequence(text, cleaner_names):
28 | '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
29 | Args:
30 | text: string to convert to a sequence
31 | cleaner_names: names of the cleaner functions to run the text through
32 | Returns:
33 | List of integers corresponding to the symbols in the text
34 | '''
35 | sequence = []
36 | symbol_to_id = {s: i for i, s in enumerate(symbols)}
37 | clean_text = _clean_text(text, cleaner_names)
38 | for symbol in clean_text:
39 | if symbol not in symbol_to_id.keys():
40 | continue
41 | symbol_id = symbol_to_id[symbol]
42 | sequence += [symbol_id]
43 | return sequence
44 |
45 |
46 | def cleaned_text_to_sequence(cleaned_text):
47 | '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
48 | Args:
49 | text: string to convert to a sequence
50 | Returns:
51 | List of integers corresponding to the symbols in the text
52 | '''
53 | sequence = [_symbol_to_id[symbol] for symbol in cleaned_text if symbol in _symbol_to_id.keys()]
54 | return sequence
55 |
56 |
57 | def sequence_to_text(sequence):
58 | '''Converts a sequence of IDs back to a string'''
59 | result = ''
60 | for symbol_id in sequence:
61 | s = _id_to_symbol[symbol_id]
62 | result += s
63 | return result
64 |
65 |
66 | def _clean_text(text, cleaner_names):
67 | for name in cleaner_names:
68 | cleaner = getattr(utils.g2p.cleaners, name)
69 | if not cleaner:
70 | raise Exception('Unknown cleaner: %s' % name)
71 | text, langs = cleaner(text)
72 | return text, langs
73 |
--------------------------------------------------------------------------------
/utils/g2p/bpe_69.json:
--------------------------------------------------------------------------------
1 | {
2 | "version": "1.0",
3 | "truncation": null,
4 | "padding": null,
5 | "added_tokens": [
6 | {
7 | "id": 0,
8 | "content": "[UNK]",
9 | "single_word": false,
10 | "lstrip": false,
11 | "rstrip": false,
12 | "normalized": false,
13 | "special": true
14 | },
15 | {
16 | "id": 1,
17 | "content": "[CLS]",
18 | "single_word": false,
19 | "lstrip": false,
20 | "rstrip": false,
21 | "normalized": false,
22 | "special": true
23 | },
24 | {
25 | "id": 2,
26 | "content": "[SEP]",
27 | "single_word": false,
28 | "lstrip": false,
29 | "rstrip": false,
30 | "normalized": false,
31 | "special": true
32 | },
33 | {
34 | "id": 3,
35 | "content": "[PAD]",
36 | "single_word": false,
37 | "lstrip": false,
38 | "rstrip": false,
39 | "normalized": false,
40 | "special": true
41 | },
42 | {
43 | "id": 4,
44 | "content": "[MASK]",
45 | "single_word": false,
46 | "lstrip": false,
47 | "rstrip": false,
48 | "normalized": false,
49 | "special": true
50 | }
51 | ],
52 | "normalizer": null,
53 | "pre_tokenizer": {
54 | "type": "Whitespace"
55 | },
56 | "post_processor": null,
57 | "decoder": null,
58 | "model": {
59 | "type": "BPE",
60 | "dropout": null,
61 | "unk_token": "[UNK]",
62 | "continuing_subword_prefix": null,
63 | "end_of_word_suffix": null,
64 | "fuse_unk": false,
65 | "byte_fallback": false,
66 | "vocab": {
67 | "[UNK]": 0,
68 | "[CLS]": 1,
69 | "[SEP]": 2,
70 | "[PAD]": 3,
71 | "[MASK]": 4,
72 | "!": 5,
73 | "#": 6,
74 | "*": 7,
75 | ",": 8,
76 | "-": 9,
77 | ".": 10,
78 | "=": 11,
79 | "?": 12,
80 | "N": 13,
81 | "Q": 14,
82 | "^": 15,
83 | "_": 16,
84 | "`": 17,
85 | "a": 18,
86 | "b": 19,
87 | "d": 20,
88 | "e": 21,
89 | "f": 22,
90 | "g": 23,
91 | "h": 24,
92 | "i": 25,
93 | "j": 26,
94 | "k": 27,
95 | "l": 28,
96 | "m": 29,
97 | "n": 30,
98 | "o": 31,
99 | "p": 32,
100 | "s": 33,
101 | "t": 34,
102 | "u": 35,
103 | "v": 36,
104 | "w": 37,
105 | "x": 38,
106 | "y": 39,
107 | "z": 40,
108 | "~": 41,
109 | "æ": 42,
110 | "ç": 43,
111 | "ð": 44,
112 | "ŋ": 45,
113 | "ɑ": 46,
114 | "ɔ": 47,
115 | "ə": 48,
116 | "ɛ": 49,
117 | "ɥ": 50,
118 | "ɪ": 51,
119 | "ɫ": 52,
120 | "ɯ": 53,
121 | "ɸ": 54,
122 | "ɹ": 55,
123 | "ɾ": 56,
124 | "ʃ": 57,
125 | "ʊ": 58,
126 | "ʑ": 59,
127 | "ʒ": 60,
128 | "ʰ": 61,
129 | "ˈ": 62,
130 | "ˌ": 63,
131 | "θ": 64,
132 | "…": 65,
133 | "⁼": 66,
134 | "↑": 67,
135 | "→": 68,
136 | "↓": 69
137 | },
138 | "merges": [
139 | ]
140 | }
141 | }
--------------------------------------------------------------------------------
/utils/g2p/cleaners.py:
--------------------------------------------------------------------------------
1 | import re
2 | from utils.g2p.japanese import japanese_to_romaji_with_accent, japanese_to_ipa, japanese_to_ipa2, japanese_to_ipa3
3 | from utils.g2p.mandarin import number_to_chinese, chinese_to_bopomofo, latin_to_bopomofo, chinese_to_romaji, chinese_to_lazy_ipa, chinese_to_ipa, chinese_to_ipa2
4 | from utils.g2p.english import english_to_lazy_ipa, english_to_ipa2, english_to_lazy_ipa2
5 | patterns = [r'\[EN\](.*?)\[EN\]', r'\[ZH\](.*?)\[ZH\]', r'\[JA\](.*?)\[JA\]']
6 | def japanese_cleaners(text):
7 | text = japanese_to_romaji_with_accent(text)
8 | text = re.sub(r'([A-Za-z])$', r'\1.', text)
9 | return text
10 |
11 | def japanese_cleaners2(text):
12 | return japanese_cleaners(text).replace('ts', 'ʦ').replace('...', '…')
13 |
14 | def chinese_cleaners(text):
15 | '''Pipeline for Chinese text'''
16 | text = number_to_chinese(text)
17 | text = chinese_to_bopomofo(text)
18 | text = latin_to_bopomofo(text)
19 | text = re.sub(r'([ˉˊˇˋ˙])$', r'\1。', text)
20 | return text
21 |
22 | def cje_cleaners(text):
23 | matches = []
24 | for pattern in patterns:
25 | matches.extend(re.finditer(pattern, text))
26 |
27 | matches.sort(key=lambda x: x.start()) # Sort matches by their start positions
28 |
29 | outputs = ""
30 | output_langs = []
31 |
32 | for match in matches:
33 | text_segment = text[match.start():match.end()]
34 | phon = clean_one(text_segment)
35 | if "[EN]" in text_segment:
36 | lang = 'en'
37 | elif "[ZH]" in text_segment:
38 | lang = 'zh'
39 | elif "[JA]" in text_segment:
40 | lang = 'ja'
41 | else:
42 | raise ValueError("If you see this error, please report this bug to issues.")
43 | outputs += phon
44 | output_langs += [lang] * len(phon)
45 | assert len(outputs) == len(output_langs)
46 | return outputs, output_langs
47 |
48 |
49 | def clean_one(text):
50 | if text.find('[ZH]') != -1:
51 | text = re.sub(r'\[ZH\](.*?)\[ZH\]',
52 | lambda x: chinese_to_ipa(x.group(1))+' ', text)
53 | if text.find('[JA]') != -1:
54 | text = re.sub(r'\[JA\](.*?)\[JA\]',
55 | lambda x: japanese_to_ipa2(x.group(1))+' ', text)
56 | if text.find('[EN]') != -1:
57 | text = re.sub(r'\[EN\](.*?)\[EN\]',
58 | lambda x: english_to_ipa2(x.group(1))+' ', text)
59 | text = re.sub(r'\s+$', '', text)
60 | text = re.sub(r'([^\.,!\?\-…~])$', r'\1.', text)
61 | return text
62 |
--------------------------------------------------------------------------------
/utils/g2p/english.py:
--------------------------------------------------------------------------------
1 | """ from https://github.com/keithito/tacotron """
2 |
3 | '''
4 | Cleaners are transformations that run over the input text at both training and eval time.
5 |
6 | Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners"
7 | hyperparameter. Some cleaners are English-specific. You'll typically want to use:
8 | 1. "english_cleaners" for English text
9 | 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using
10 | the Unidecode library (https://pypi.python.org/pypi/Unidecode)
11 | 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update
12 | the symbols in symbols.py to match your data).
13 | '''
14 |
15 |
16 | # Regular expression matching whitespace:
17 |
18 |
19 | import re
20 | from unidecode import unidecode
21 | import inflect
22 | _inflect = inflect.engine()
23 | _comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])')
24 | _decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)')
25 | _pounds_re = re.compile(r'£([0-9\,]*[0-9]+)')
26 | _dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)')
27 | _ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)')
28 | _number_re = re.compile(r'[0-9]+')
29 |
30 | # List of (regular expression, replacement) pairs for abbreviations:
31 | _abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [
32 | ('mrs', 'misess'),
33 | ('mr', 'mister'),
34 | ('dr', 'doctor'),
35 | ('st', 'saint'),
36 | ('co', 'company'),
37 | ('jr', 'junior'),
38 | ('maj', 'major'),
39 | ('gen', 'general'),
40 | ('drs', 'doctors'),
41 | ('rev', 'reverend'),
42 | ('lt', 'lieutenant'),
43 | ('hon', 'honorable'),
44 | ('sgt', 'sergeant'),
45 | ('capt', 'captain'),
46 | ('esq', 'esquire'),
47 | ('ltd', 'limited'),
48 | ('col', 'colonel'),
49 | ('ft', 'fort'),
50 | ]]
51 |
52 |
53 | # List of (ipa, lazy ipa) pairs:
54 | _lazy_ipa = [(re.compile('%s' % x[0]), x[1]) for x in [
55 | ('r', 'ɹ'),
56 | ('æ', 'e'),
57 | ('ɑ', 'a'),
58 | ('ɔ', 'o'),
59 | ('ð', 'z'),
60 | ('θ', 's'),
61 | ('ɛ', 'e'),
62 | ('ɪ', 'i'),
63 | ('ʊ', 'u'),
64 | ('ʒ', 'ʥ'),
65 | ('ʤ', 'ʥ'),
66 | ('ˈ', '↓'),
67 | ]]
68 |
69 | # List of (ipa, lazy ipa2) pairs:
70 | _lazy_ipa2 = [(re.compile('%s' % x[0]), x[1]) for x in [
71 | ('r', 'ɹ'),
72 | ('ð', 'z'),
73 | ('θ', 's'),
74 | ('ʒ', 'ʑ'),
75 | ('ʤ', 'dʑ'),
76 | ('ˈ', '↓'),
77 | ]]
78 |
79 | # List of (ipa, ipa2) pairs
80 | _ipa_to_ipa2 = [(re.compile('%s' % x[0]), x[1]) for x in [
81 | ('r', 'ɹ'),
82 | ('ʤ', 'dʒ'),
83 | ('ʧ', 'tʃ')
84 | ]]
85 |
86 |
87 | def expand_abbreviations(text):
88 | for regex, replacement in _abbreviations:
89 | text = re.sub(regex, replacement, text)
90 | return text
91 |
92 |
93 | def collapse_whitespace(text):
94 | return re.sub(r'\s+', ' ', text)
95 |
96 |
97 | def _remove_commas(m):
98 | return m.group(1).replace(',', '')
99 |
100 |
101 | def _expand_decimal_point(m):
102 | return m.group(1).replace('.', ' point ')
103 |
104 |
105 | def _expand_dollars(m):
106 | match = m.group(1)
107 | parts = match.split('.')
108 | if len(parts) > 2:
109 | return match + ' dollars' # Unexpected format
110 | dollars = int(parts[0]) if parts[0] else 0
111 | cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
112 | if dollars and cents:
113 | dollar_unit = 'dollar' if dollars == 1 else 'dollars'
114 | cent_unit = 'cent' if cents == 1 else 'cents'
115 | return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit)
116 | elif dollars:
117 | dollar_unit = 'dollar' if dollars == 1 else 'dollars'
118 | return '%s %s' % (dollars, dollar_unit)
119 | elif cents:
120 | cent_unit = 'cent' if cents == 1 else 'cents'
121 | return '%s %s' % (cents, cent_unit)
122 | else:
123 | return 'zero dollars'
124 |
125 |
126 | def _expand_ordinal(m):
127 | return _inflect.number_to_words(m.group(0))
128 |
129 |
130 | def _expand_number(m):
131 | num = int(m.group(0))
132 | if num > 1000 and num < 3000:
133 | if num == 2000:
134 | return 'two thousand'
135 | elif num > 2000 and num < 2010:
136 | return 'two thousand ' + _inflect.number_to_words(num % 100)
137 | elif num % 100 == 0:
138 | return _inflect.number_to_words(num // 100) + ' hundred'
139 | else:
140 | return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ')
141 | else:
142 | return _inflect.number_to_words(num, andword='')
143 |
144 |
145 | def normalize_numbers(text):
146 | text = re.sub(_comma_number_re, _remove_commas, text)
147 | text = re.sub(_pounds_re, r'\1 pounds', text)
148 | text = re.sub(_dollars_re, _expand_dollars, text)
149 | text = re.sub(_decimal_number_re, _expand_decimal_point, text)
150 | text = re.sub(_ordinal_re, _expand_ordinal, text)
151 | text = re.sub(_number_re, _expand_number, text)
152 | return text
153 |
154 |
155 | def mark_dark_l(text):
156 | return re.sub(r'l([^aeiouæɑɔəɛɪʊ ]*(?: |$))', lambda x: 'ɫ'+x.group(1), text)
157 |
158 |
159 | def english_to_ipa(text):
160 | import eng_to_ipa as ipa
161 | text = unidecode(text).lower()
162 | text = expand_abbreviations(text)
163 | text = normalize_numbers(text)
164 | phonemes = ipa.convert(text)
165 | phonemes = collapse_whitespace(phonemes)
166 | return phonemes
167 |
168 |
169 | def english_to_lazy_ipa(text):
170 | text = english_to_ipa(text)
171 | for regex, replacement in _lazy_ipa:
172 | text = re.sub(regex, replacement, text)
173 | return text
174 |
175 |
176 | def english_to_ipa2(text):
177 | text = english_to_ipa(text)
178 | text = mark_dark_l(text)
179 | for regex, replacement in _ipa_to_ipa2:
180 | text = re.sub(regex, replacement, text)
181 | return text.replace('...', '…')
182 |
183 |
184 | def english_to_lazy_ipa2(text):
185 | text = english_to_ipa(text)
186 | for regex, replacement in _lazy_ipa2:
187 | text = re.sub(regex, replacement, text)
188 | return text
189 |
--------------------------------------------------------------------------------
/utils/g2p/japanese.py:
--------------------------------------------------------------------------------
1 | import re
2 | from unidecode import unidecode
3 |
4 |
5 |
6 | # Regular expression matching Japanese without punctuation marks:
7 | _japanese_characters = re.compile(
8 | r'[A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]')
9 |
10 | # Regular expression matching non-Japanese characters or punctuation marks:
11 | _japanese_marks = re.compile(
12 | r'[^A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]')
13 |
14 | # List of (symbol, Japanese) pairs for marks:
15 | _symbols_to_japanese = [(re.compile('%s' % x[0]), x[1]) for x in [
16 | ('%', 'パーセント')
17 | ]]
18 |
19 | # List of (romaji, ipa) pairs for marks:
20 | _romaji_to_ipa = [(re.compile('%s' % x[0]), x[1]) for x in [
21 | ('ts', 'ʦ'),
22 | ('u', 'ɯ'),
23 | ('j', 'ʥ'),
24 | ('y', 'j'),
25 | ('ni', 'n^i'),
26 | ('nj', 'n^'),
27 | ('hi', 'çi'),
28 | ('hj', 'ç'),
29 | ('f', 'ɸ'),
30 | ('I', 'i*'),
31 | ('U', 'ɯ*'),
32 | ('r', 'ɾ')
33 | ]]
34 |
35 | # List of (romaji, ipa2) pairs for marks:
36 | _romaji_to_ipa2 = [(re.compile('%s' % x[0]), x[1]) for x in [
37 | ('u', 'ɯ'),
38 | ('ʧ', 'tʃ'),
39 | ('j', 'dʑ'),
40 | ('y', 'j'),
41 | ('ni', 'n^i'),
42 | ('nj', 'n^'),
43 | ('hi', 'çi'),
44 | ('hj', 'ç'),
45 | ('f', 'ɸ'),
46 | ('I', 'i*'),
47 | ('U', 'ɯ*'),
48 | ('r', 'ɾ')
49 | ]]
50 |
51 | # List of (consonant, sokuon) pairs:
52 | _real_sokuon = [(re.compile('%s' % x[0]), x[1]) for x in [
53 | (r'Q([↑↓]*[kg])', r'k#\1'),
54 | (r'Q([↑↓]*[tdjʧ])', r't#\1'),
55 | (r'Q([↑↓]*[sʃ])', r's\1'),
56 | (r'Q([↑↓]*[pb])', r'p#\1')
57 | ]]
58 |
59 | # List of (consonant, hatsuon) pairs:
60 | _real_hatsuon = [(re.compile('%s' % x[0]), x[1]) for x in [
61 | (r'N([↑↓]*[pbm])', r'm\1'),
62 | (r'N([↑↓]*[ʧʥj])', r'n^\1'),
63 | (r'N([↑↓]*[tdn])', r'n\1'),
64 | (r'N([↑↓]*[kg])', r'ŋ\1')
65 | ]]
66 |
67 |
68 | def symbols_to_japanese(text):
69 | for regex, replacement in _symbols_to_japanese:
70 | text = re.sub(regex, replacement, text)
71 | return text
72 |
73 |
74 | def japanese_to_romaji_with_accent(text):
75 | '''Reference https://r9y9.github.io/ttslearn/latest/notebooks/ch10_Recipe-Tacotron.html'''
76 | import pyopenjtalk
77 | text = symbols_to_japanese(text)
78 | sentences = re.split(_japanese_marks, text)
79 | marks = re.findall(_japanese_marks, text)
80 | text = ''
81 | for i, sentence in enumerate(sentences):
82 | if re.match(_japanese_characters, sentence):
83 | if text != '':
84 | text += ' '
85 | labels = pyopenjtalk.extract_fullcontext(sentence)
86 | for n, label in enumerate(labels):
87 | phoneme = re.search(r'\-([^\+]*)\+', label).group(1)
88 | if phoneme not in ['sil', 'pau']:
89 | text += phoneme.replace('ch', 'ʧ').replace('sh',
90 | 'ʃ').replace('cl', 'Q')
91 | else:
92 | continue
93 | # n_moras = int(re.search(r'/F:(\d+)_', label).group(1))
94 | a1 = int(re.search(r"/A:(\-?[0-9]+)\+", label).group(1))
95 | a2 = int(re.search(r"\+(\d+)\+", label).group(1))
96 | a3 = int(re.search(r"\+(\d+)/", label).group(1))
97 | if re.search(r'\-([^\+]*)\+', labels[n + 1]).group(1) in ['sil', 'pau']:
98 | a2_next = -1
99 | else:
100 | a2_next = int(
101 | re.search(r"\+(\d+)\+", labels[n + 1]).group(1))
102 | # Accent phrase boundary
103 | if a3 == 1 and a2_next == 1:
104 | text += ' '
105 | # Falling
106 | elif a1 == 0 and a2_next == a2 + 1:
107 | text += '↓'
108 | # Rising
109 | elif a2 == 1 and a2_next == 2:
110 | text += '↑'
111 | if i < len(marks):
112 | text += unidecode(marks[i]).replace(' ', '')
113 | return text
114 |
115 |
116 | def get_real_sokuon(text):
117 | for regex, replacement in _real_sokuon:
118 | text = re.sub(regex, replacement, text)
119 | return text
120 |
121 |
122 | def get_real_hatsuon(text):
123 | for regex, replacement in _real_hatsuon:
124 | text = re.sub(regex, replacement, text)
125 | return text
126 |
127 |
128 | def japanese_to_ipa(text):
129 | text = japanese_to_romaji_with_accent(text).replace('...', '…')
130 | text = re.sub(
131 | r'([aiueo])\1+', lambda x: x.group(0)[0]+'ː'*(len(x.group(0))-1), text)
132 | text = get_real_sokuon(text)
133 | text = get_real_hatsuon(text)
134 | for regex, replacement in _romaji_to_ipa:
135 | text = re.sub(regex, replacement, text)
136 | return text
137 |
138 |
139 | def japanese_to_ipa2(text):
140 | text = japanese_to_romaji_with_accent(text).replace('...', '…')
141 | text = get_real_sokuon(text)
142 | text = get_real_hatsuon(text)
143 | for regex, replacement in _romaji_to_ipa2:
144 | text = re.sub(regex, replacement, text)
145 | return text
146 |
147 |
148 | def japanese_to_ipa3(text):
149 | text = japanese_to_ipa2(text).replace('n^', 'ȵ').replace(
150 | 'ʃ', 'ɕ').replace('*', '\u0325').replace('#', '\u031a')
151 | text = re.sub(
152 | r'([aiɯeo])\1+', lambda x: x.group(0)[0]+'ː'*(len(x.group(0))-1), text)
153 | text = re.sub(r'((?:^|\s)(?:ts|tɕ|[kpt]))', r'\1ʰ', text)
154 | return text
155 |
--------------------------------------------------------------------------------
/utils/g2p/mandarin.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import re
4 | import jieba
5 | import cn2an
6 | import logging
7 |
8 |
9 | # List of (Latin alphabet, bopomofo) pairs:
10 | _latin_to_bopomofo = [(re.compile('%s' % x[0], re.IGNORECASE), x[1]) for x in [
11 | ('a', 'ㄟˉ'),
12 | ('b', 'ㄅㄧˋ'),
13 | ('c', 'ㄙㄧˉ'),
14 | ('d', 'ㄉㄧˋ'),
15 | ('e', 'ㄧˋ'),
16 | ('f', 'ㄝˊㄈㄨˋ'),
17 | ('g', 'ㄐㄧˋ'),
18 | ('h', 'ㄝˇㄑㄩˋ'),
19 | ('i', 'ㄞˋ'),
20 | ('j', 'ㄐㄟˋ'),
21 | ('k', 'ㄎㄟˋ'),
22 | ('l', 'ㄝˊㄛˋ'),
23 | ('m', 'ㄝˊㄇㄨˋ'),
24 | ('n', 'ㄣˉ'),
25 | ('o', 'ㄡˉ'),
26 | ('p', 'ㄆㄧˉ'),
27 | ('q', 'ㄎㄧㄡˉ'),
28 | ('r', 'ㄚˋ'),
29 | ('s', 'ㄝˊㄙˋ'),
30 | ('t', 'ㄊㄧˋ'),
31 | ('u', 'ㄧㄡˉ'),
32 | ('v', 'ㄨㄧˉ'),
33 | ('w', 'ㄉㄚˋㄅㄨˋㄌㄧㄡˋ'),
34 | ('x', 'ㄝˉㄎㄨˋㄙˋ'),
35 | ('y', 'ㄨㄞˋ'),
36 | ('z', 'ㄗㄟˋ')
37 | ]]
38 |
39 | # List of (bopomofo, romaji) pairs:
40 | _bopomofo_to_romaji = [(re.compile('%s' % x[0]), x[1]) for x in [
41 | ('ㄅㄛ', 'p⁼wo'),
42 | ('ㄆㄛ', 'pʰwo'),
43 | ('ㄇㄛ', 'mwo'),
44 | ('ㄈㄛ', 'fwo'),
45 | ('ㄅ', 'p⁼'),
46 | ('ㄆ', 'pʰ'),
47 | ('ㄇ', 'm'),
48 | ('ㄈ', 'f'),
49 | ('ㄉ', 't⁼'),
50 | ('ㄊ', 'tʰ'),
51 | ('ㄋ', 'n'),
52 | ('ㄌ', 'l'),
53 | ('ㄍ', 'k⁼'),
54 | ('ㄎ', 'kʰ'),
55 | ('ㄏ', 'h'),
56 | ('ㄐ', 'ʧ⁼'),
57 | ('ㄑ', 'ʧʰ'),
58 | ('ㄒ', 'ʃ'),
59 | ('ㄓ', 'ʦ`⁼'),
60 | ('ㄔ', 'ʦ`ʰ'),
61 | ('ㄕ', 's`'),
62 | ('ㄖ', 'ɹ`'),
63 | ('ㄗ', 'ʦ⁼'),
64 | ('ㄘ', 'ʦʰ'),
65 | ('ㄙ', 's'),
66 | ('ㄚ', 'a'),
67 | ('ㄛ', 'o'),
68 | ('ㄜ', 'ə'),
69 | ('ㄝ', 'e'),
70 | ('ㄞ', 'ai'),
71 | ('ㄟ', 'ei'),
72 | ('ㄠ', 'au'),
73 | ('ㄡ', 'ou'),
74 | ('ㄧㄢ', 'yeNN'),
75 | ('ㄢ', 'aNN'),
76 | ('ㄧㄣ', 'iNN'),
77 | ('ㄣ', 'əNN'),
78 | ('ㄤ', 'aNg'),
79 | ('ㄧㄥ', 'iNg'),
80 | ('ㄨㄥ', 'uNg'),
81 | ('ㄩㄥ', 'yuNg'),
82 | ('ㄥ', 'əNg'),
83 | ('ㄦ', 'əɻ'),
84 | ('ㄧ', 'i'),
85 | ('ㄨ', 'u'),
86 | ('ㄩ', 'ɥ'),
87 | ('ˉ', '→'),
88 | ('ˊ', '↑'),
89 | ('ˇ', '↓↑'),
90 | ('ˋ', '↓'),
91 | ('˙', ''),
92 | (',', ','),
93 | ('。', '.'),
94 | ('!', '!'),
95 | ('?', '?'),
96 | ('—', '-')
97 | ]]
98 |
99 | # List of (romaji, ipa) pairs:
100 | _romaji_to_ipa = [(re.compile('%s' % x[0], re.IGNORECASE), x[1]) for x in [
101 | ('ʃy', 'ʃ'),
102 | ('ʧʰy', 'ʧʰ'),
103 | ('ʧ⁼y', 'ʧ⁼'),
104 | ('NN', 'n'),
105 | ('Ng', 'ŋ'),
106 | ('y', 'j'),
107 | ('h', 'x')
108 | ]]
109 |
110 | # List of (bopomofo, ipa) pairs:
111 | _bopomofo_to_ipa = [(re.compile('%s' % x[0]), x[1]) for x in [
112 | ('ㄅㄛ', 'p⁼wo'),
113 | ('ㄆㄛ', 'pʰwo'),
114 | ('ㄇㄛ', 'mwo'),
115 | ('ㄈㄛ', 'fwo'),
116 | ('ㄅ', 'p⁼'),
117 | ('ㄆ', 'pʰ'),
118 | ('ㄇ', 'm'),
119 | ('ㄈ', 'f'),
120 | ('ㄉ', 't⁼'),
121 | ('ㄊ', 'tʰ'),
122 | ('ㄋ', 'n'),
123 | ('ㄌ', 'l'),
124 | ('ㄍ', 'k⁼'),
125 | ('ㄎ', 'kʰ'),
126 | ('ㄏ', 'x'),
127 | ('ㄐ', 'tʃ⁼'),
128 | ('ㄑ', 'tʃʰ'),
129 | ('ㄒ', 'ʃ'),
130 | ('ㄓ', 'ts`⁼'),
131 | ('ㄔ', 'ts`ʰ'),
132 | ('ㄕ', 's`'),
133 | ('ㄖ', 'ɹ`'),
134 | ('ㄗ', 'ts⁼'),
135 | ('ㄘ', 'tsʰ'),
136 | ('ㄙ', 's'),
137 | ('ㄚ', 'a'),
138 | ('ㄛ', 'o'),
139 | ('ㄜ', 'ə'),
140 | ('ㄝ', 'ɛ'),
141 | ('ㄞ', 'aɪ'),
142 | ('ㄟ', 'eɪ'),
143 | ('ㄠ', 'ɑʊ'),
144 | ('ㄡ', 'oʊ'),
145 | ('ㄧㄢ', 'jɛn'),
146 | ('ㄩㄢ', 'ɥæn'),
147 | ('ㄢ', 'an'),
148 | ('ㄧㄣ', 'in'),
149 | ('ㄩㄣ', 'ɥn'),
150 | ('ㄣ', 'ən'),
151 | ('ㄤ', 'ɑŋ'),
152 | ('ㄧㄥ', 'iŋ'),
153 | ('ㄨㄥ', 'ʊŋ'),
154 | ('ㄩㄥ', 'jʊŋ'),
155 | ('ㄥ', 'əŋ'),
156 | ('ㄦ', 'əɻ'),
157 | ('ㄧ', 'i'),
158 | ('ㄨ', 'u'),
159 | ('ㄩ', 'ɥ'),
160 | ('ˉ', '→'),
161 | ('ˊ', '↑'),
162 | ('ˇ', '↓↑'),
163 | ('ˋ', '↓'),
164 | ('˙', ''),
165 | (',', ','),
166 | ('。', '.'),
167 | ('!', '!'),
168 | ('?', '?'),
169 | ('—', '-')
170 | ]]
171 |
172 | # List of (bopomofo, ipa2) pairs:
173 | _bopomofo_to_ipa2 = [(re.compile('%s' % x[0]), x[1]) for x in [
174 | ('ㄅㄛ', 'pwo'),
175 | ('ㄆㄛ', 'pʰwo'),
176 | ('ㄇㄛ', 'mwo'),
177 | ('ㄈㄛ', 'fwo'),
178 | ('ㄅ', 'p'),
179 | ('ㄆ', 'pʰ'),
180 | ('ㄇ', 'm'),
181 | ('ㄈ', 'f'),
182 | ('ㄉ', 't'),
183 | ('ㄊ', 'tʰ'),
184 | ('ㄋ', 'n'),
185 | ('ㄌ', 'l'),
186 | ('ㄍ', 'k'),
187 | ('ㄎ', 'kʰ'),
188 | ('ㄏ', 'h'),
189 | ('ㄐ', 'tɕ'),
190 | ('ㄑ', 'tɕʰ'),
191 | ('ㄒ', 'ɕ'),
192 | ('ㄓ', 'tʂ'),
193 | ('ㄔ', 'tʂʰ'),
194 | ('ㄕ', 'ʂ'),
195 | ('ㄖ', 'ɻ'),
196 | ('ㄗ', 'ts'),
197 | ('ㄘ', 'tsʰ'),
198 | ('ㄙ', 's'),
199 | ('ㄚ', 'a'),
200 | ('ㄛ', 'o'),
201 | ('ㄜ', 'ɤ'),
202 | ('ㄝ', 'ɛ'),
203 | ('ㄞ', 'aɪ'),
204 | ('ㄟ', 'eɪ'),
205 | ('ㄠ', 'ɑʊ'),
206 | ('ㄡ', 'oʊ'),
207 | ('ㄧㄢ', 'jɛn'),
208 | ('ㄩㄢ', 'yæn'),
209 | ('ㄢ', 'an'),
210 | ('ㄧㄣ', 'in'),
211 | ('ㄩㄣ', 'yn'),
212 | ('ㄣ', 'ən'),
213 | ('ㄤ', 'ɑŋ'),
214 | ('ㄧㄥ', 'iŋ'),
215 | ('ㄨㄥ', 'ʊŋ'),
216 | ('ㄩㄥ', 'jʊŋ'),
217 | ('ㄥ', 'ɤŋ'),
218 | ('ㄦ', 'əɻ'),
219 | ('ㄧ', 'i'),
220 | ('ㄨ', 'u'),
221 | ('ㄩ', 'y'),
222 | ('ˉ', '˥'),
223 | ('ˊ', '˧˥'),
224 | ('ˇ', '˨˩˦'),
225 | ('ˋ', '˥˩'),
226 | ('˙', ''),
227 | (',', ','),
228 | ('。', '.'),
229 | ('!', '!'),
230 | ('?', '?'),
231 | ('—', '-')
232 | ]]
233 |
234 |
235 | def number_to_chinese(text):
236 | numbers = re.findall(r'\d+(?:\.?\d+)?', text)
237 | for number in numbers:
238 | text = text.replace(number, cn2an.an2cn(number), 1)
239 | return text
240 |
241 |
242 | def chinese_to_bopomofo(text):
243 | from pypinyin import lazy_pinyin, BOPOMOFO
244 | text = text.replace('、', ',').replace(';', ',').replace(':', ',')
245 | words = jieba.lcut(text, cut_all=False)
246 | text = ''
247 | for word in words:
248 | bopomofos = lazy_pinyin(word, BOPOMOFO)
249 | if not re.search('[\u4e00-\u9fff]', word):
250 | text += word
251 | continue
252 | for i in range(len(bopomofos)):
253 | bopomofos[i] = re.sub(r'([\u3105-\u3129])$', r'\1ˉ', bopomofos[i])
254 | if text != '':
255 | text += ' '
256 | text += ''.join(bopomofos)
257 | return text
258 |
259 |
260 | def latin_to_bopomofo(text):
261 | for regex, replacement in _latin_to_bopomofo:
262 | text = re.sub(regex, replacement, text)
263 | return text
264 |
265 |
266 | def bopomofo_to_romaji(text):
267 | for regex, replacement in _bopomofo_to_romaji:
268 | text = re.sub(regex, replacement, text)
269 | return text
270 |
271 |
272 | def bopomofo_to_ipa(text):
273 | for regex, replacement in _bopomofo_to_ipa:
274 | text = re.sub(regex, replacement, text)
275 | return text
276 |
277 |
278 | def bopomofo_to_ipa2(text):
279 | for regex, replacement in _bopomofo_to_ipa2:
280 | text = re.sub(regex, replacement, text)
281 | return text
282 |
283 |
284 | def chinese_to_romaji(text):
285 | text = number_to_chinese(text)
286 | text = chinese_to_bopomofo(text)
287 | text = latin_to_bopomofo(text)
288 | text = bopomofo_to_romaji(text)
289 | text = re.sub('i([aoe])', r'y\1', text)
290 | text = re.sub('u([aoəe])', r'w\1', text)
291 | text = re.sub('([ʦsɹ]`[⁼ʰ]?)([→↓↑ ]+|$)',
292 | r'\1ɹ`\2', text).replace('ɻ', 'ɹ`')
293 | text = re.sub('([ʦs][⁼ʰ]?)([→↓↑ ]+|$)', r'\1ɹ\2', text)
294 | return text
295 |
296 |
297 | def chinese_to_lazy_ipa(text):
298 | text = chinese_to_romaji(text)
299 | for regex, replacement in _romaji_to_ipa:
300 | text = re.sub(regex, replacement, text)
301 | return text
302 |
303 |
304 | def chinese_to_ipa(text):
305 | text = number_to_chinese(text)
306 | text = chinese_to_bopomofo(text)
307 | text = latin_to_bopomofo(text)
308 | text = bopomofo_to_ipa(text)
309 | text = re.sub('i([aoe])', r'j\1', text)
310 | text = re.sub('u([aoəe])', r'w\1', text)
311 | text = re.sub('([sɹ]`[⁼ʰ]?)([→↓↑ ]+|$)',
312 | r'\1ɹ`\2', text).replace('ɻ', 'ɹ`')
313 | text = re.sub('([s][⁼ʰ]?)([→↓↑ ]+|$)', r'\1ɹ\2', text)
314 | return text
315 |
316 |
317 | def chinese_to_ipa2(text):
318 | text = number_to_chinese(text)
319 | text = chinese_to_bopomofo(text)
320 | text = latin_to_bopomofo(text)
321 | text = bopomofo_to_ipa2(text)
322 | text = re.sub(r'i([aoe])', r'j\1', text)
323 | text = re.sub(r'u([aoəe])', r'w\1', text)
324 | text = re.sub(r'([ʂɹ]ʰ?)([˩˨˧˦˥ ]+|$)', r'\1ʅ\2', text)
325 | text = re.sub(r'(sʰ?)([˩˨˧˦˥ ]+|$)', r'\1ɿ\2', text)
326 | return text
327 |
--------------------------------------------------------------------------------
/utils/g2p/symbols.py:
--------------------------------------------------------------------------------
1 | '''
2 | Defines the set of symbols used in text input to the model.
3 | '''
4 |
5 | # japanese_cleaners
6 | # _pad = '_'
7 | # _punctuation = ',.!?-'
8 | # _letters = 'AEINOQUabdefghijkmnoprstuvwyzʃʧ↓↑ '
9 |
10 |
11 | '''# japanese_cleaners2
12 | _pad = '_'
13 | _punctuation = ',.!?-~…'
14 | _letters = 'AEINOQUabdefghijkmnoprstuvwyzʃʧʦ↓↑ '
15 | '''
16 |
17 |
18 | '''# korean_cleaners
19 | _pad = '_'
20 | _punctuation = ',.!?…~'
21 | _letters = 'ㄱㄴㄷㄹㅁㅂㅅㅇㅈㅊㅋㅌㅍㅎㄲㄸㅃㅆㅉㅏㅓㅗㅜㅡㅣㅐㅔ '
22 | '''
23 |
24 | '''# chinese_cleaners
25 | _pad = '_'
26 | _punctuation = ',。!?—…'
27 | _letters = 'ㄅㄆㄇㄈㄉㄊㄋㄌㄍㄎㄏㄐㄑㄒㄓㄔㄕㄖㄗㄘㄙㄚㄛㄜㄝㄞㄟㄠㄡㄢㄣㄤㄥㄦㄧㄨㄩˉˊˇˋ˙ '
28 | '''
29 |
30 | # # zh_ja_mixture_cleaners
31 | # _pad = '_'
32 | # _punctuation = ',.!?-~…'
33 | # _letters = 'AEINOQUabdefghijklmnoprstuvwyzʃʧʦɯɹəɥ⁼ʰ`→↓↑ '
34 |
35 |
36 | '''# sanskrit_cleaners
37 | _pad = '_'
38 | _punctuation = '।'
39 | _letters = 'ँंःअआइईउऊऋएऐओऔकखगघङचछजझञटठडढणतथदधनपफबभमयरलळवशषसहऽािीुूृॄेैोौ्ॠॢ '
40 | '''
41 |
42 | '''# cjks_cleaners
43 | _pad = '_'
44 | _punctuation = ',.!?-~…'
45 | _letters = 'NQabdefghijklmnopstuvwxyzʃʧʥʦɯɹəɥçɸɾβŋɦː⁼ʰ`^#*=→↓↑ '
46 | '''
47 |
48 | '''# thai_cleaners
49 | _pad = '_'
50 | _punctuation = '.!? '
51 | _letters = 'กขฃคฆงจฉชซฌญฎฏฐฑฒณดตถทธนบปผฝพฟภมยรฤลวศษสหฬอฮฯะัาำิีึืุูเแโใไๅๆ็่้๊๋์'
52 | '''
53 |
54 | # # cjke_cleaners2
55 | _pad = '_'
56 | _punctuation = ',.!?-~…'
57 | _letters = 'NQabdefghijklmnopstuvwxyzɑæʃʑçɯɪɔɛɹðəɫɥɸʊɾʒθβŋɦ⁼ʰ`^#*=ˈˌ→↓↑ '
58 |
59 |
60 | '''# shanghainese_cleaners
61 | _pad = '_'
62 | _punctuation = ',.!?…'
63 | _letters = 'abdfghiklmnopstuvyzøŋȵɑɔɕəɤɦɪɿʑʔʰ̩̃ᴀᴇ15678 '
64 | '''
65 |
66 | '''# chinese_dialect_cleaners
67 | _pad = '_'
68 | _punctuation = ',.!?~…─'
69 | _letters = '#Nabdefghijklmnoprstuvwxyzæçøŋœȵɐɑɒɓɔɕɗɘəɚɛɜɣɤɦɪɭɯɵɷɸɻɾɿʂʅʊʋʌʏʑʔʦʮʰʷˀː˥˦˧˨˩̥̩̃̚ᴀᴇ↑↓∅ⱼ '
70 | '''
71 |
72 | # Export all symbols:
73 | symbols = [_pad] + list(_punctuation) + list(_letters)
74 |
75 | # Special symbol ids
76 | SPACE_ID = symbols.index(" ")
77 |
--------------------------------------------------------------------------------
/utils/generation.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 | import os
3 | import torch
4 | from vocos import Vocos
5 | import logging
6 | import langid
7 | langid.set_languages(['en', 'zh', 'ja'])
8 |
9 | import pathlib
10 | import platform
11 | if platform.system().lower() == 'windows':
12 | temp = pathlib.PosixPath
13 | pathlib.PosixPath = pathlib.WindowsPath
14 | else:
15 | temp = pathlib.WindowsPath
16 | pathlib.WindowsPath = pathlib.PosixPath
17 |
18 | import numpy as np
19 | from data.tokenizer import (
20 | AudioTokenizer,
21 | tokenize_audio,
22 | )
23 | from data.collation import get_text_token_collater
24 | from models.vallex import VALLE
25 | from utils.g2p import PhonemeBpeTokenizer
26 | from utils.sentence_cutter import split_text_into_sentences
27 |
28 | from macros import *
29 |
30 | device = torch.device("cpu")
31 | if torch.cuda.is_available():
32 | device = torch.device("cuda", 0)
33 | if torch.backends.mps.is_available():
34 | device = torch.device("mps")
35 | url = 'https://huggingface.co/Plachta/VALL-E-X/resolve/main/vallex-checkpoint.pt'
36 |
37 | checkpoints_dir = "./checkpoints/"
38 |
39 | model_checkpoint_name = "vallex-checkpoint.pt"
40 |
41 | model = None
42 |
43 | codec = None
44 |
45 | vocos = None
46 |
47 | text_tokenizer = PhonemeBpeTokenizer(tokenizer_path="./utils/g2p/bpe_69.json")
48 | text_collater = get_text_token_collater()
49 |
50 | def preload_models():
51 | global model, codec, vocos
52 | if not os.path.exists(checkpoints_dir): os.mkdir(checkpoints_dir)
53 | if not os.path.exists(os.path.join(checkpoints_dir, model_checkpoint_name)):
54 | import wget
55 | try:
56 | logging.info(
57 | "Downloading model from https://huggingface.co/Plachta/VALL-E-X/resolve/main/vallex-checkpoint.pt ...")
58 | # download from https://huggingface.co/Plachta/VALL-E-X/resolve/main/vallex-checkpoint.pt to ./checkpoints/vallex-checkpoint.pt
59 | wget.download("https://huggingface.co/Plachta/VALL-E-X/resolve/main/vallex-checkpoint.pt",
60 | out="./checkpoints/vallex-checkpoint.pt", bar=wget.bar_adaptive)
61 | except Exception as e:
62 | logging.info(e)
63 | raise Exception(
64 | "\n Model weights download failed, please go to 'https://huggingface.co/Plachta/VALL-E-X/resolve/main/vallex-checkpoint.pt'"
65 | "\n manually download model weights and put it to {} .".format(os.getcwd() + "\checkpoints"))
66 | # VALL-E
67 | model = VALLE(
68 | N_DIM,
69 | NUM_HEAD,
70 | NUM_LAYERS,
71 | norm_first=True,
72 | add_prenet=False,
73 | prefix_mode=PREFIX_MODE,
74 | share_embedding=True,
75 | nar_scale_factor=1.0,
76 | prepend_bos=True,
77 | num_quantizers=NUM_QUANTIZERS,
78 | ).to(device)
79 | checkpoint = torch.load(os.path.join(checkpoints_dir, model_checkpoint_name), map_location='cpu')
80 | missing_keys, unexpected_keys = model.load_state_dict(
81 | checkpoint["model"], strict=True
82 | )
83 | assert not missing_keys
84 | model.eval()
85 |
86 | # Encodec
87 | codec = AudioTokenizer(device)
88 |
89 | vocos = Vocos.from_pretrained('charactr/vocos-encodec-24khz').to(device)
90 |
91 | @torch.no_grad()
92 | def generate_audio(text, prompt=None, language='auto', accent='no-accent'):
93 | global model, codec, vocos, text_tokenizer, text_collater
94 | text = text.replace("\n", "").strip(" ")
95 | # detect language
96 | if language == "auto":
97 | language = langid.classify(text)[0]
98 | lang_token = lang2token[language]
99 | lang = token2lang[lang_token]
100 | text = lang_token + text + lang_token
101 |
102 | # load prompt
103 | if prompt is not None:
104 | prompt_path = prompt
105 | if not os.path.exists(prompt_path):
106 | prompt_path = "./presets/" + prompt + ".npz"
107 | if not os.path.exists(prompt_path):
108 | prompt_path = "./customs/" + prompt + ".npz"
109 | if not os.path.exists(prompt_path):
110 | raise ValueError(f"Cannot find prompt {prompt}")
111 | prompt_data = np.load(prompt_path)
112 | audio_prompts = prompt_data['audio_tokens']
113 | text_prompts = prompt_data['text_tokens']
114 | lang_pr = prompt_data['lang_code']
115 | lang_pr = code2lang[int(lang_pr)]
116 |
117 | # numpy to tensor
118 | audio_prompts = torch.tensor(audio_prompts).type(torch.int32).to(device)
119 | text_prompts = torch.tensor(text_prompts).type(torch.int32)
120 | else:
121 | audio_prompts = torch.zeros([1, 0, NUM_QUANTIZERS]).type(torch.int32).to(device)
122 | text_prompts = torch.zeros([1, 0]).type(torch.int32)
123 | lang_pr = lang if lang != 'mix' else 'en'
124 |
125 | enroll_x_lens = text_prompts.shape[-1]
126 | logging.info(f"synthesize text: {text}")
127 | phone_tokens, langs = text_tokenizer.tokenize(text=f"_{text}".strip())
128 | text_tokens, text_tokens_lens = text_collater(
129 | [
130 | phone_tokens
131 | ]
132 | )
133 | text_tokens = torch.cat([text_prompts, text_tokens], dim=-1)
134 | text_tokens_lens += enroll_x_lens
135 | # accent control
136 | lang = lang if accent == "no-accent" else token2lang[langdropdown2token[accent]]
137 | encoded_frames = model.inference(
138 | text_tokens.to(device),
139 | text_tokens_lens.to(device),
140 | audio_prompts,
141 | enroll_x_lens=enroll_x_lens,
142 | top_k=-100,
143 | temperature=1,
144 | prompt_language=lang_pr,
145 | text_language=langs if accent == "no-accent" else lang,
146 | )
147 | # Decode with Vocos
148 | frames = encoded_frames.permute(2,0,1)
149 | features = vocos.codes_to_features(frames)
150 | samples = vocos.decode(features, bandwidth_id=torch.tensor([2], device=device))
151 |
152 | return samples.squeeze().cpu().numpy()
153 |
154 | @torch.no_grad()
155 | def generate_audio_from_long_text(text, prompt=None, language='auto', accent='no-accent', mode='sliding-window'):
156 | """
157 | For long audio generation, two modes are available.
158 | fixed-prompt: This mode will keep using the same prompt the user has provided, and generate audio sentence by sentence.
159 | sliding-window: This mode will use the last sentence as the prompt for the next sentence, but has some concern on speaker maintenance.
160 | """
161 | global model, codec, vocos, text_tokenizer, text_collater
162 | if prompt is None or prompt == "":
163 | mode = 'sliding-window' # If no prompt is given, use sliding-window mode
164 | sentences = split_text_into_sentences(text)
165 | # detect language
166 | if language == "auto":
167 | language = langid.classify(text)[0]
168 |
169 | # if initial prompt is given, encode it
170 | if prompt is not None and prompt != "":
171 | prompt_path = prompt
172 | if not os.path.exists(prompt_path):
173 | prompt_path = "./presets/" + prompt + ".npz"
174 | if not os.path.exists(prompt_path):
175 | prompt_path = "./customs/" + prompt + ".npz"
176 | if not os.path.exists(prompt_path):
177 | raise ValueError(f"Cannot find prompt {prompt}")
178 | prompt_data = np.load(prompt_path)
179 | audio_prompts = prompt_data['audio_tokens']
180 | text_prompts = prompt_data['text_tokens']
181 | lang_pr = prompt_data['lang_code']
182 | lang_pr = code2lang[int(lang_pr)]
183 |
184 | # numpy to tensor
185 | audio_prompts = torch.tensor(audio_prompts).type(torch.int32).to(device)
186 | text_prompts = torch.tensor(text_prompts).type(torch.int32)
187 | else:
188 | audio_prompts = torch.zeros([1, 0, NUM_QUANTIZERS]).type(torch.int32).to(device)
189 | text_prompts = torch.zeros([1, 0]).type(torch.int32)
190 | lang_pr = language if language != 'mix' else 'en'
191 | if mode == 'fixed-prompt':
192 | complete_tokens = torch.zeros([1, NUM_QUANTIZERS, 0]).type(torch.LongTensor).to(device)
193 | for text in sentences:
194 | text = text.replace("\n", "").strip(" ")
195 | if text == "":
196 | continue
197 | lang_token = lang2token[language]
198 | lang = token2lang[lang_token]
199 | text = lang_token + text + lang_token
200 |
201 | enroll_x_lens = text_prompts.shape[-1]
202 | logging.info(f"synthesize text: {text}")
203 | phone_tokens, langs = text_tokenizer.tokenize(text=f"_{text}".strip())
204 | text_tokens, text_tokens_lens = text_collater(
205 | [
206 | phone_tokens
207 | ]
208 | )
209 | text_tokens = torch.cat([text_prompts, text_tokens], dim=-1)
210 | text_tokens_lens += enroll_x_lens
211 | # accent control
212 | lang = lang if accent == "no-accent" else token2lang[langdropdown2token[accent]]
213 | encoded_frames = model.inference(
214 | text_tokens.to(device),
215 | text_tokens_lens.to(device),
216 | audio_prompts,
217 | enroll_x_lens=enroll_x_lens,
218 | top_k=-100,
219 | temperature=1,
220 | prompt_language=lang_pr,
221 | text_language=langs if accent == "no-accent" else lang,
222 | )
223 | complete_tokens = torch.cat([complete_tokens, encoded_frames.transpose(2, 1)], dim=-1)
224 | # Decode with Vocos
225 | frames = complete_tokens.permute(1,0,2)
226 | features = vocos.codes_to_features(frames)
227 | samples = vocos.decode(features, bandwidth_id=torch.tensor([2], device=device))
228 | return samples.squeeze().cpu().numpy()
229 | elif mode == "sliding-window":
230 | complete_tokens = torch.zeros([1, NUM_QUANTIZERS, 0]).type(torch.LongTensor).to(device)
231 | original_audio_prompts = audio_prompts
232 | original_text_prompts = text_prompts
233 | for text in sentences:
234 | text = text.replace("\n", "").strip(" ")
235 | if text == "":
236 | continue
237 | lang_token = lang2token[language]
238 | lang = token2lang[lang_token]
239 | text = lang_token + text + lang_token
240 |
241 | enroll_x_lens = text_prompts.shape[-1]
242 | logging.info(f"synthesize text: {text}")
243 | phone_tokens, langs = text_tokenizer.tokenize(text=f"_{text}".strip())
244 | text_tokens, text_tokens_lens = text_collater(
245 | [
246 | phone_tokens
247 | ]
248 | )
249 | text_tokens = torch.cat([text_prompts, text_tokens], dim=-1)
250 | text_tokens_lens += enroll_x_lens
251 | # accent control
252 | lang = lang if accent == "no-accent" else token2lang[langdropdown2token[accent]]
253 | encoded_frames = model.inference(
254 | text_tokens.to(device),
255 | text_tokens_lens.to(device),
256 | audio_prompts,
257 | enroll_x_lens=enroll_x_lens,
258 | top_k=-100,
259 | temperature=1,
260 | prompt_language=lang_pr,
261 | text_language=langs if accent == "no-accent" else lang,
262 | )
263 | complete_tokens = torch.cat([complete_tokens, encoded_frames.transpose(2, 1)], dim=-1)
264 | if torch.rand(1) < 0.5:
265 | audio_prompts = encoded_frames[:, :, -NUM_QUANTIZERS:]
266 | text_prompts = text_tokens[:, enroll_x_lens:]
267 | else:
268 | audio_prompts = original_audio_prompts
269 | text_prompts = original_text_prompts
270 | # Decode with Vocos
271 | frames = complete_tokens.permute(1,0,2)
272 | features = vocos.codes_to_features(frames)
273 | samples = vocos.decode(features, bandwidth_id=torch.tensor([2], device=device))
274 | return samples.squeeze().cpu().numpy()
275 | else:
276 | raise ValueError(f"No such mode {mode}")
277 |
--------------------------------------------------------------------------------
/utils/prompt_making.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torchaudio
4 | import logging
5 | import langid
6 | import whisper
7 | langid.set_languages(['en', 'zh', 'ja'])
8 |
9 | import numpy as np
10 | from data.tokenizer import (
11 | AudioTokenizer,
12 | tokenize_audio,
13 | )
14 | from data.collation import get_text_token_collater
15 | from utils.g2p import PhonemeBpeTokenizer
16 |
17 | from macros import *
18 |
19 | text_tokenizer = PhonemeBpeTokenizer(tokenizer_path="./utils/g2p/bpe_69.json")
20 | text_collater = get_text_token_collater()
21 |
22 | device = torch.device("cpu")
23 | if torch.cuda.is_available():
24 | device = torch.device("cuda", 0)
25 | if torch.backends.mps.is_available():
26 | device = torch.device("mps")
27 | codec = AudioTokenizer(device)
28 |
29 | if not os.path.exists("./whisper/"): os.mkdir("./whisper/")
30 | whisper_model = None
31 |
32 | @torch.no_grad()
33 | def transcribe_one(model, audio_path):
34 | # load audio and pad/trim it to fit 30 seconds
35 | audio = whisper.load_audio(audio_path)
36 | audio = whisper.pad_or_trim(audio)
37 |
38 | # make log-Mel spectrogram and move to the same device as the model
39 | mel = whisper.log_mel_spectrogram(audio).to(model.device)
40 |
41 | # detect the spoken language
42 | _, probs = model.detect_language(mel)
43 | print(f"Detected language: {max(probs, key=probs.get)}")
44 | lang = max(probs, key=probs.get)
45 | # decode the audio
46 | options = whisper.DecodingOptions(temperature=1.0, best_of=5, fp16=False if device == torch.device("cpu") else True, sample_len=150)
47 | result = whisper.decode(model, mel, options)
48 |
49 | # print the recognized text
50 | print(result.text)
51 |
52 | text_pr = result.text
53 | if text_pr.strip(" ")[-1] not in "?!.,。,?!。、":
54 | text_pr += "."
55 | return lang, text_pr
56 |
57 | def make_prompt(name, audio_prompt_path, transcript=None):
58 | global model, text_collater, text_tokenizer, codec
59 | wav_pr, sr = torchaudio.load(audio_prompt_path)
60 | # check length
61 | if wav_pr.size(-1) / sr > 15:
62 | raise ValueError(f"Prompt too long, expect length below 15 seconds, got {wav_pr / sr} seconds.")
63 | if wav_pr.size(0) == 2:
64 | wav_pr = wav_pr.mean(0, keepdim=True)
65 | text_pr, lang_pr = make_transcript(name, wav_pr, sr, transcript)
66 |
67 | # tokenize audio
68 | encoded_frames = tokenize_audio(codec, (wav_pr, sr))
69 | audio_tokens = encoded_frames[0][0].transpose(2, 1).cpu().numpy()
70 |
71 | # tokenize text
72 | phonemes, langs = text_tokenizer.tokenize(text=f"{text_pr}".strip())
73 | text_tokens, enroll_x_lens = text_collater(
74 | [
75 | phonemes
76 | ]
77 | )
78 |
79 | message = f"Detected language: {lang_pr}\n Detected text {text_pr}\n"
80 |
81 | # save as npz file
82 | save_path = os.path.join("./customs/", f"{name}.npz")
83 | np.savez(save_path, audio_tokens=audio_tokens, text_tokens=text_tokens, lang_code=lang2code[lang_pr])
84 | logging.info(f"Successful. Prompt saved to {save_path}")
85 |
86 |
87 | def make_transcript(name, wav, sr, transcript=None):
88 |
89 | if not isinstance(wav, torch.FloatTensor):
90 | wav = torch.tensor(wav)
91 | if wav.abs().max() > 1:
92 | wav /= wav.abs().max()
93 | if wav.size(-1) == 2:
94 | wav = wav.mean(-1, keepdim=False)
95 | if wav.ndim == 1:
96 | wav = wav.unsqueeze(0)
97 | assert wav.ndim and wav.size(0) == 1
98 | if transcript is None or transcript == "":
99 | logging.info("Transcript not given, using Whisper...")
100 | global whisper_model
101 | if whisper_model is None:
102 | whisper_model = whisper.load_model("medium", download_root=os.path.join(os.getcwd(), "whisper"))
103 | whisper_model.to(device)
104 | torchaudio.save(f"./prompts/{name}.wav", wav, sr)
105 | lang, text = transcribe_one(whisper_model, f"./prompts/{name}.wav")
106 | lang_token = lang2token[lang]
107 | text = lang_token + text + lang_token
108 | os.remove(f"./prompts/{name}.wav")
109 | whisper_model.cpu()
110 | else:
111 | text = transcript
112 | lang, _ = langid.classify(text)
113 | lang_token = lang2token[lang]
114 | text = lang_token + text + lang_token
115 |
116 | torch.cuda.empty_cache()
117 | return text, lang
118 |
--------------------------------------------------------------------------------
/utils/sentence_cutter.py:
--------------------------------------------------------------------------------
1 | import nltk
2 | import jieba
3 | import sudachipy
4 | import langid
5 | langid.set_languages(['en', 'zh', 'ja'])
6 |
7 | def split_text_into_sentences(text):
8 | if langid.classify(text)[0] == "en":
9 | sentences = nltk.tokenize.sent_tokenize(text)
10 |
11 | return sentences
12 | elif langid.classify(text)[0] == "zh":
13 | sentences = []
14 | segs = jieba.cut(text, cut_all=False)
15 | segs = list(segs)
16 | start = 0
17 | for i, seg in enumerate(segs):
18 | if seg in ["。", "!", "?", "……"]:
19 | sentences.append("".join(segs[start:i + 1]))
20 | start = i + 1
21 | if start < len(segs):
22 | sentences.append("".join(segs[start:]))
23 |
24 | return sentences
25 | elif langid.classify(text)[0] == "ja":
26 | sentences = []
27 | tokenizer = sudachipy.Dictionary().create()
28 | tokens = tokenizer.tokenize(text)
29 | current_sentence = ""
30 |
31 | for token in tokens:
32 | current_sentence += token.surface()
33 | if token.part_of_speech()[0] == "補助記号" and token.part_of_speech()[1] == "句点":
34 | sentences.append(current_sentence)
35 | current_sentence = ""
36 |
37 | if current_sentence:
38 | sentences.append(current_sentence)
39 |
40 | return sentences
41 |
42 | raise RuntimeError("It is impossible to reach here.")
43 |
44 | long_text = """
45 | This is a very long paragraph, so most TTS model is unable to handle it. Hence, we have to split it into several sentences. With the help of NLTK, we can split it into sentences. However, the punctuation is not preserved, so we have to add it back. How are we going to do write this code? Let's see.
46 | """
47 |
48 | long_text = """
49 | 现在我们要来尝试一下中文分句。因为很不幸的是,NLTK不支持中文分句。幸运的是,我们可以使用jieba来分句。但是,jieba分句后,标点符号会丢失,所以我们要手动添加回去。我现在正在想办法把这个例句写的更长更复杂一点,来测试jieba分句的性能。嗯......省略号,感觉不太好,因为省略号不是句号,所以jieba不会把它当作句子的结尾。会这样吗?我们来试试看。
50 | """
51 |
52 | long_text = """
53 | これなら、英語と中国語の分句もできる。でも、日本語はどうする?まつわ、ChatGPTに僕と教えてください。ちょーと待ってください。あ、出来た!
54 | """
--------------------------------------------------------------------------------
/utils/symbol_table.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Mobvoi Inc. (authors: Fangjun Kuang)
2 | #
3 | # See ../../../LICENSE for clarification regarding multiple authors
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | from dataclasses import dataclass
18 | from dataclasses import field
19 | from typing import Dict
20 | from typing import Generic
21 | from typing import List
22 | from typing import Optional
23 | from typing import TypeVar
24 | from typing import Union
25 |
26 | Symbol = TypeVar('Symbol')
27 |
28 |
29 | # Disable __repr__ otherwise it could freeze e.g. Jupyter.
30 | @dataclass(repr=False)
31 | class SymbolTable(Generic[Symbol]):
32 | '''SymbolTable that maps symbol IDs, found on the FSA arcs to
33 | actual objects. These objects can be arbitrary Python objects
34 | that can serve as keys in a dictionary (i.e. they need to be
35 | hashable and immutable).
36 |
37 | The SymbolTable can only be read to/written from disk if the
38 | symbols are strings.
39 | '''
40 | _id2sym: Dict[int, Symbol] = field(default_factory=dict)
41 | '''Map an integer to a symbol.
42 | '''
43 |
44 | _sym2id: Dict[Symbol, int] = field(default_factory=dict)
45 | '''Map a symbol to an integer.
46 | '''
47 |
48 | _next_available_id: int = 1
49 | '''A helper internal field that helps adding new symbols
50 | to the table efficiently.
51 | '''
52 |
53 | eps: Symbol = ''
54 | '''Null symbol, always mapped to index 0.
55 | '''
56 |
57 | def __post_init__(self):
58 | for idx, sym in self._id2sym.items():
59 | assert self._sym2id[sym] == idx
60 | assert idx >= 0
61 |
62 | for sym, idx in self._sym2id.items():
63 | assert idx >= 0
64 | assert self._id2sym[idx] == sym
65 |
66 | if 0 not in self._id2sym:
67 | self._id2sym[0] = self.eps
68 | self._sym2id[self.eps] = 0
69 | else:
70 | assert self._id2sym[0] == self.eps
71 | assert self._sym2id[self.eps] == 0
72 |
73 | self._next_available_id = max(self._id2sym) + 1
74 |
75 | @staticmethod
76 | def from_str(s: str) -> 'SymbolTable':
77 | '''Build a symbol table from a string.
78 |
79 | The string consists of lines. Every line has two fields separated
80 | by space(s), tab(s) or both. The first field is the symbol and the
81 | second the integer id of the symbol.
82 |
83 | Args:
84 | s:
85 | The input string with the format described above.
86 | Returns:
87 | An instance of :class:`SymbolTable`.
88 | '''
89 | id2sym: Dict[int, str] = dict()
90 | sym2id: Dict[str, int] = dict()
91 |
92 | for line in s.split('\n'):
93 | fields = line.split()
94 | if len(fields) == 0:
95 | continue # skip empty lines
96 | assert len(fields) == 2, \
97 | f'Expect a line with 2 fields. Given: {len(fields)}'
98 | sym, idx = fields[0], int(fields[1])
99 | assert sym not in sym2id, f'Duplicated symbol {sym}'
100 | assert idx not in id2sym, f'Duplicated id {idx}'
101 | id2sym[idx] = sym
102 | sym2id[sym] = idx
103 |
104 | eps = id2sym.get(0, '')
105 |
106 | return SymbolTable(_id2sym=id2sym, _sym2id=sym2id, eps=eps)
107 |
108 | @staticmethod
109 | def from_file(filename: str) -> 'SymbolTable':
110 | '''Build a symbol table from file.
111 |
112 | Every line in the symbol table file has two fields separated by
113 | space(s), tab(s) or both. The following is an example file:
114 |
115 | .. code-block::
116 |
117 | 0
118 | a 1
119 | b 2
120 | c 3
121 |
122 | Args:
123 | filename:
124 | Name of the symbol table file. Its format is documented above.
125 |
126 | Returns:
127 | An instance of :class:`SymbolTable`.
128 |
129 | '''
130 | with open(filename, 'r', encoding='utf-8') as f:
131 | return SymbolTable.from_str(f.read().strip())
132 |
133 | def to_str(self) -> str:
134 | '''
135 | Returns:
136 | Return a string representation of this object. You can pass
137 | it to the method ``from_str`` to recreate an identical object.
138 | '''
139 | s = ''
140 | for idx, symbol in sorted(self._id2sym.items()):
141 | s += f'{symbol} {idx}\n'
142 | return s
143 |
144 | def to_file(self, filename: str):
145 | '''Serialize the SymbolTable to a file.
146 |
147 | Every line in the symbol table file has two fields separated by
148 | space(s), tab(s) or both. The following is an example file:
149 |
150 | .. code-block::
151 |
152 | 0
153 | a 1
154 | b 2
155 | c 3
156 |
157 | Args:
158 | filename:
159 | Name of the symbol table file. Its format is documented above.
160 | '''
161 | with open(filename, 'w', encoding='utf-8') as f:
162 | for idx, symbol in sorted(self._id2sym.items()):
163 | print(symbol, idx, file=f)
164 |
165 | def add(self, symbol: Symbol, index: Optional[int] = None) -> int:
166 | '''Add a new symbol to the SymbolTable.
167 |
168 | Args:
169 | symbol:
170 | The symbol to be added.
171 | index:
172 | Optional int id to which the symbol should be assigned.
173 | If it is not available, a ValueError will be raised.
174 |
175 | Returns:
176 | The int id to which the symbol has been assigned.
177 | '''
178 | # Already in the table? Return its ID.
179 | if symbol in self._sym2id:
180 | return self._sym2id[symbol]
181 | # Specific ID not provided - use next available.
182 | if index is None:
183 | index = self._next_available_id
184 | # Specific ID provided but not available.
185 | if index in self._id2sym:
186 | raise ValueError(f"Cannot assign id '{index}' to '{symbol}' - "
187 | f"already occupied by {self._id2sym[index]}")
188 | self._sym2id[symbol] = index
189 | self._id2sym[index] = symbol
190 |
191 | # Update next available ID if needed
192 | if self._next_available_id <= index:
193 | self._next_available_id = index + 1
194 |
195 | return index
196 |
197 | def get(self, k: Union[int, Symbol]) -> Union[Symbol, int]:
198 | '''Get a symbol for an id or get an id for a symbol
199 |
200 | Args:
201 | k:
202 | If it is an id, it tries to find the symbol corresponding
203 | to the id; if it is a symbol, it tries to find the id
204 | corresponding to the symbol.
205 |
206 | Returns:
207 | An id or a symbol depending on the given `k`.
208 | '''
209 | if isinstance(k, int):
210 | return self._id2sym[k]
211 | else:
212 | return self._sym2id[k]
213 |
214 | def merge(self, other: 'SymbolTable') -> 'SymbolTable':
215 | '''Create a union of two SymbolTables.
216 | Raises an AssertionError if the same IDs are occupied by
217 | different symbols.
218 |
219 | Args:
220 | other:
221 | A symbol table to merge with ``self``.
222 |
223 | Returns:
224 | A new symbol table.
225 | '''
226 | self._check_compatible(other)
227 |
228 | id2sym = {**self._id2sym, **other._id2sym}
229 | sym2id = {**self._sym2id, **other._sym2id}
230 |
231 | return SymbolTable(_id2sym=id2sym, _sym2id=sym2id, eps=self.eps)
232 |
233 | def _check_compatible(self, other: 'SymbolTable') -> None:
234 | # Epsilon compatibility
235 | assert self.eps == other.eps, f'Mismatched epsilon symbol: ' \
236 | f'{self.eps} != {other.eps}'
237 | # IDs compatibility
238 | common_ids = set(self._id2sym).intersection(other._id2sym)
239 | for idx in common_ids:
240 | assert self[idx] == other[idx], f'ID conflict for id: {idx}, ' \
241 | f'self[idx] = "{self[idx]}", ' \
242 | f'other[idx] = "{other[idx]}"'
243 | # Symbols compatibility
244 | common_symbols = set(self._sym2id).intersection(other._sym2id)
245 | for sym in common_symbols:
246 | assert self[sym] == other[sym], f'ID conflict for id: {sym}, ' \
247 | f'self[sym] = "{self[sym]}", ' \
248 | f'other[sym] = "{other[sym]}"'
249 |
250 | def __getitem__(self, item: Union[int, Symbol]) -> Union[Symbol, int]:
251 | return self.get(item)
252 |
253 | def __contains__(self, item: Union[int, Symbol]) -> bool:
254 | if isinstance(item, int):
255 | return item in self._id2sym
256 | else:
257 | return item in self._sym2id
258 |
259 | def __len__(self) -> int:
260 | return len(self._id2sym)
261 |
262 | def __eq__(self, other: 'SymbolTable') -> bool:
263 | if len(self) != len(other):
264 | return False
265 |
266 | for s in self.symbols:
267 | if self[s] != other[s]:
268 | return False
269 |
270 | return True
271 |
272 | @property
273 | def ids(self) -> List[int]:
274 | '''Returns a list of integer IDs corresponding to the symbols.
275 | '''
276 | ans = list(self._id2sym.keys())
277 | ans.sort()
278 | return ans
279 |
280 | @property
281 | def symbols(self) -> List[Symbol]:
282 | '''Returns a list of symbols (e.g., strings) corresponding to
283 | the integer IDs.
284 | '''
285 | ans = list(self._sym2id.keys())
286 | ans.sort()
287 | return ans
288 |
--------------------------------------------------------------------------------