├── LICENSE
├── README-ZH.md
├── README.md
├── customs
│   ├── make_custom_dataset.py
│   └── ph.txt
├── data
│   ├── __init__.py
│   ├── collation.py
│   ├── datamodule.py
│   ├── dataset.py
│   ├── fbank.py
│   ├── input_strategies.py
│   └── tokenizer.py
├── descriptions.py
├── examples.py
├── images
│   └── vallex_framework.jpg
├── launch-ui.py
├── macros.py
├── model-card.md
├── models
│   ├── __init__.py
│   ├── macros.py
│   ├── transformer.py
│   ├── vallex.py
│   └── visualizer.py
├── modules
│   ├── __init__.py
│   ├── activation.py
│   ├── embedding.py
│   ├── optim.py
│   ├── scaling.py
│   ├── scheduler.py
│   └── transformer.py
├── nltk_data
│   └── tokenizers
│       └── punkt
│           ├── .DS_Store
│           ├── PY3
│           │   ├── README
│           │   ├── czech.pickle
│           │   ├── danish.pickle
│           │   ├── dutch.pickle
│           │   ├── english.pickle
│           │   ├── estonian.pickle
│           │   ├── finnish.pickle
│           │   ├── french.pickle
│           │   ├── german.pickle
│           │   ├── greek.pickle
│           │   ├── italian.pickle
│           │   ├── malayalam.pickle
│           │   ├── norwegian.pickle
│           │   ├── polish.pickle
│           │   ├── portuguese.pickle
│           │   ├── russian.pickle
│           │   ├── slovene.pickle
│           │   ├── spanish.pickle
│           │   ├── swedish.pickle
│           │   └── turkish.pickle
│           ├── README
│           ├── czech.pickle
│           ├── danish.pickle
│           ├── dutch.pickle
│           ├── english.pickle
│           ├── estonian.pickle
│           ├── finnish.pickle
│           ├── french.pickle
│           ├── german.pickle
│           ├── greek.pickle
│           ├── italian.pickle
│           ├── malayalam.pickle
│           ├── norwegian.pickle
│           ├── polish.pickle
│           ├── portuguese.pickle
│           ├── russian.pickle
│           ├── slovene.pickle
│           ├── spanish.pickle
│           ├── swedish.pickle
│           └── turkish.pickle
├── presets
│   ├── acou_1.npz
│   ├── acou_2.npz
│   ├── acou_3.npz
│   ├── acou_4.npz
│   ├── alan.npz
│   ├── amused.npz
│   ├── anger.npz
│   ├── babara.npz
│   ├── bronya.npz
│   ├── cafe.npz
│   ├── dingzhen.npz
│   ├── disgust.npz
│   ├── emo_amused.npz
│   ├── emo_anger.npz
│   ├── emo_neutral.npz
│   ├── emo_sleepy.npz
│   ├── emotion_sleepiness.npz
│   ├── en2zh_tts_1.npz
│   ├── en2zh_tts_2.npz
│   ├── en2zh_tts_3.npz
│   ├── en2zh_tts_4.npz
│   ├── esta.npz
│   ├── fuxuan.npz
│   ├── librispeech_1.npz
│   ├── librispeech_2.npz
│   ├── librispeech_3.npz
│   ├── librispeech_4.npz
│   ├── neutral.npz
│   ├── paimon.npz
│   ├── rosalia.npz
│   ├── seel.npz
│   ├── sleepiness.npz
│   ├── vctk_1.npz
│   ├── vctk_2.npz
│   ├── vctk_3.npz
│   ├── vctk_4.npz
│   ├── yaesakura.npz
│   ├── zh2en_tts_1.npz
│   ├── zh2en_tts_2.npz
│   ├── zh2en_tts_3.npz
│   └── zh2en_tts_4.npz
├── prompts
│   ├── en-1.wav
│   ├── en-2.wav
│   ├── ja-1.wav
│   ├── ja-2.ogg
│   ├── ph.txt
│   ├── zh-1.wav
│   └── zh-2.wav
├── requirements.txt
├── test.py
├── train.py
├── train_utils
│   ├── __pycache__
│   │   └── utils.cpython-310.pyc
│   ├── icefall
│   │   └── utils.py
│   ├── lhotse
│   │   └── utils.py
│   ├── model.py
│   └── utils.py
└── utils
    ├── __init__.py
    ├── download.py
    ├── g2p
    │   ├── __init__.py
    │   ├── bpe_1024.json
    │   ├── bpe_69.json
    │   ├── cleaners.py
    │   ├── english.py
    │   ├── japanese.py
    │   ├── mandarin.py
    │   └── symbols.py
    ├── generation.py
    ├── prompt_making.py
    ├── sentence_cutter.py
    └── symbol_table.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Songting
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README-ZH.md:
--------------------------------------------------------------------------------
1 | # VALL-E X: 多语言文本到语音合成与语音克隆 🔊
2 | [](https://discord.gg/qCBRmAnTxg)
3 |
4 | [English](README.md) | 中文
5 |
6 | 微软[VALL-E X](https://arxiv.org/pdf/2303.03926) 零样本语音合成模型的开源实现.
7 | **预训练模型现已向公众开放,供研究或应用使用。**
8 | 
9 |
10 | VALL-E X 是一个强大而创新的多语言文本转语音(TTS)模型,最初由微软发布。虽然微软最初在他们的研究论文中提出了该概念,但并未发布任何代码或预训练模型。我们认识到了这项技术的潜力和价值,复现并训练了一个开源可用的VALL-E X模型。我们很乐意与社区分享我们的预训练模型,让每个人都能体验到次世代TTS的威力。 🎧
11 |
12 | 更多细节请查看 [model card](./model-card.md).
13 |
14 | ## 📖 目录
15 | * [🚀 更新日志](#-更新日志)
16 | * [📢 功能特点](#-功能特点)
17 | * [💻 本地安装](#-本地安装)
18 | * [🎧 在线Demo](#-在线Demo)
19 | * [🐍 使用方法](#-Python中的使用方法)
20 | * [❓ FAQ](#-faq)
21 | * [🧠 TODO](#-todo)
22 |
23 | ## 🚀 更新日志
24 | **2023.09.10**
25 | - 支持AR decoder的batch decoding以实现更稳定的生成结果
26 |
27 | **2023.08.30**
28 | - 将EnCodec解码器替换成了Vocos解码器,提升了音质。 (感谢[@v0xie](https://github.com/v0xie))
29 |
30 | **2023.08.23**
31 | - 加入了长文本生成功能
32 |
33 | **2023.08.20**
34 | - 加入了中文版README
35 |
36 | **2023.08.14**
37 | - 预训练模型权重已发布,从[这里](https://drive.google.com/file/d/10gdQWvP-K_e1undkvv0p2b7SU6I4Egyl/view?usp=sharing)下载。
38 |
39 | ## 💻 本地安装
40 | ### 使用pip安装,推荐使用Python 3.10,CUDA 11.7 ~ 12.0,PyTorch 2.0+
41 | ```commandline
42 | git clone https://github.com/Plachtaa/VALL-E-X.git
43 | cd VALL-E-X
44 | pip install -r requirements.txt
45 | ```
46 |
47 | > 注意:如果需要制作prompt,需要安装 ffmpeg 并将其所在文件夹加入到环境变量PATH中
48 |
49 | 第一次运行程序时,会自动下载相应的模型。如果下载失败并报错,请按照以下步骤手动下载模型。
50 |
51 | (请注意目录和文件夹的大小写)
52 |
53 | 1.检查安装目录下是否存在`checkpoints`文件夹,如果没有,在安装目录下手动创建`checkpoints`文件夹(`./checkpoints/`)。
54 |
55 | 2.检查`checkpoints`文件夹中是否有`vallex-checkpoint.pt`文件。如果没有,请从[这里](https://huggingface.co/Plachta/VALL-E-X/resolve/main/vallex-checkpoint.pt)
56 | 手动下载`vallex-checkpoint.pt`文件并放到`checkpoints`文件夹里。
57 |
58 | 3.检查安装目录下是否存在`whisper`文件夹,如果没有,在安装目录下手动创建`whisper`文件夹(`./whisper/`)。
59 |
60 | 4.检查`whisper`文件夹中是否有`medium.pt`文件。如果没有,请从[这里](https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt)
61 | 手动下载`medium.pt`文件并放到`whisper`文件夹里。
62 |
63 | ## 🎧 在线Demo
64 | 如果你不想在本地安装,你可以在线体验VALL-E X的功能,点击下面的任意一个链接即可开始体验。
65 |
66 | [](https://huggingface.co/spaces/Plachta/VALL-E-X)
67 | [](https://colab.research.google.com/drive/1yyD_sz531QntLKowMHo-XxorsFBCfKul?usp=sharing)
68 |
69 |
70 | ## 📢 功能特点
71 |
72 | VALL-E X 配备有一系列尖端功能:
73 |
74 | 1. **多语言 TTS**: 可使用三种语言 - 英语、中文和日语 - 进行自然、富有表现力的语音合成。
75 |
76 | 2. **零样本语音克隆**: 仅需录制任意说话人的短短的 3~10 秒录音,VALL-E X 就能生成个性化、高质量的语音,完美还原他们的声音。
77 |
78 |
79 | 查看示例
80 |
81 | [prompt.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/a7baa51d-a53a-41cc-a03d-6970f25fcca7)
82 |
83 |
84 | [output.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/b895601a-d126-4138-beff-061aabdc7985)
85 |
86 |
87 |
88 | 3. **语音情感控制**: VALL-E X 可以合成与给定说话人录音相同情感的语音,为音频增添更多表现力。
89 |
90 |
91 | 查看示例
92 |
93 | https://github.com/Plachtaa/VALL-E-X/assets/112609742/56fa9988-925e-4757-82c5-83ecb0df6266
94 |
95 |
96 | https://github.com/Plachtaa/VALL-E-X/assets/112609742/699c47a3-d502-4801-8364-bd89bcc0b8f1
97 |
98 |
99 |
100 | 4. **零样本跨语言语音合成**: VALL-E X 可以合成与给定说话人母语不同的另一种语言,在不影响口音和流利度的同时,保留该说话人的音色与情感。以下是一个使用日语母语者进行英文与中文合成的样例: 🇯🇵 🗣
101 |
102 |
103 | 查看示例
104 |
105 | [jp-prompt.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/ea6e2ee4-139a-41b4-837e-0bd04dda6e19)
106 |
107 |
108 | [en-output.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/db8f9782-923f-425e-ba94-e8c1bd48f207)
109 |
110 |
111 | [zh-output.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/15829d79-e448-44d3-8965-fafa7a3f8c28)
112 |
113 |
114 |
115 | 5. **口音控制**: VALL-E X 允许您控制所合成音频的口音,比如说中文带英语口音或反之。 🇨🇳 💬
116 |
117 |
118 | 查看示例
119 |
120 | [en-prompt.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/f688d7f6-70ef-46ec-b1cc-355c31e78b3b)
121 |
122 |
123 | [zh-accent-output.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/be59c7ca-b45b-44ca-a30d-4d800c950ccc)
124 |
125 |
126 | [en-accent-output.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/8b4f4f9b-f299-4ea4-a548-137437b71738)
127 |
128 |
129 |
130 | 6. **声学环境保留**: 当给定说话人的录音在不同的声学环境下录制时,VALL-E X 可以保留该声学环境,使合成语音听起来更加自然。
131 |
132 |
133 | 查看示例
134 |
135 | [noise-prompt.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/68986d88-abd0-4d1d-96e4-4f893eb9259e)
136 |
137 |
138 | [noise-output.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/96c4c612-4516-4683-8804-501b70938608)
139 |
140 |
141 |
142 |
143 | 你可以访问我们的[demo页面](https://plachtaa.github.io/) 来浏览更多示例!
144 |
145 | ## 💻 Python中的使用方法
146 |
147 |
148 | 🪑 基本使用
149 |
150 | ```python
151 | from utils.generation import SAMPLE_RATE, generate_audio, preload_models
152 | from scipy.io.wavfile import write as write_wav
153 | from IPython.display import Audio
154 |
155 | # download and load all models
156 | preload_models()
157 |
158 | # generate audio from text
159 | text_prompt = """
160 | Hello, my name is Nose. And uh, and I like hamburger. Hahaha... But I also have other interests such as playing tactic toast.
161 | """
162 | audio_array = generate_audio(text_prompt)
163 |
164 | # save audio to disk
165 | write_wav("vallex_generation.wav", SAMPLE_RATE, audio_array)
166 |
167 | # play text in notebook
168 | Audio(audio_array, rate=SAMPLE_RATE)
169 | ```
170 |
171 | [hamburger.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/578d7bbe-cda9-483e-898c-29646edc8f2e)
172 |
173 |
174 |
175 |
176 | 🌎 多语言
177 |
178 | 该VALL-E X实现支持三种语言:英语、中文和日语。您可以通过设置`language`参数来指定语言。默认情况下,该模型将自动检测语言。
179 |
180 |
181 | ```python
182 |
183 | text_prompt = """
184 | チュソクは私のお気に入りの祭りです。 私は数日間休んで、友人や家族との時間を過ごすことができます。
185 | """
186 | audio_array = generate_audio(text_prompt)
187 | ```
188 |
189 | [vallex_japanese.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/ee57a688-3e83-4be5-b0fe-019d16eec51c)
190 |
191 | *注意:即使在一句话中混合多种语言的情况下,VALL-E X也能完美地控制口音,但是您需要手动标记各个句子对应的语言以便于我们的G2P工具识别它们。*
192 | ```python
193 | text_prompt = """
194 | [EN]The Thirty Years' War was a devastating conflict that had a profound impact on Europe.[EN]
195 | [ZH]这是历史的开始。 如果您想听更多,请继续。[ZH]
196 | """
197 | audio_array = generate_audio(text_prompt, language='mix')
198 | ```
199 |
200 | [vallex_codeswitch.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/d8667abf-bd08-499f-a383-a861d852f98a)
201 |
202 |
203 |
204 |
205 | 📼 预设音色
206 |
207 | 我们提供十几种说话人音色,可供VALL-E X直接使用! 在[这里](/presets)浏览所有可用音色。
208 |
209 | > VALL-E X 尝试匹配给定预设音色的音调、音高、情感和韵律。该模型还尝试保留音乐、环境噪声等。
210 | ```python
211 | text_prompt = """
212 | I am an innocent boy with a smoky voice. It is a great honor for me to speak at the United Nations today.
213 | """
214 | audio_array = generate_audio(text_prompt, prompt="dingzhen")
215 | ```
216 |
217 | [smoky.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/d3f55732-b1cd-420f-87d6-eab60db14dc5)
218 |
219 |
220 |
221 |
222 | 🎙声音克隆
223 |
224 | VALL-E X 支持声音克隆!你可以使用任何人,角色,甚至是你自己的声音,来制作一个音频提示。在你使用该音频提示时,VALL-E X 将会使用与其相似的声音来合成文本。
225 |
226 | 你需要提供一段3~10秒长的语音,以及该语音对应的文本,来制作音频提示。你也可以将文本留空,让[Whisper](https://github.com/openai/whisper)模型为你生成文本。
227 | > VALL-E X 尝试匹配给定音频提示的音调、音高、情感和韵律。该模型还尝试保留音乐、环境噪声等。
228 |
229 | ```python
230 | from utils.prompt_making import make_prompt
231 |
232 | ### Use given transcript
233 | make_prompt(name="paimon", audio_prompt_path="paimon_prompt.wav",
234 | transcript="Just, what was that? Paimon thought we were gonna get eaten.")
235 |
236 | ### Alternatively, use whisper
237 | make_prompt(name="paimon", audio_prompt_path="paimon_prompt.wav")
238 | ```
239 | 来尝试一下刚刚做好的音频提示吧!
240 | ```python
241 | from utils.generation import SAMPLE_RATE, generate_audio, preload_models
242 | from scipy.io.wavfile import write as write_wav
243 |
244 | # download and load all models
245 | preload_models()
246 |
247 | text_prompt = """
248 | Hey, Traveler, Listen to this, This machine has taken my voice, and now it can talk just like me!
249 | """
250 | audio_array = generate_audio(text_prompt, prompt="paimon")
251 |
252 | write_wav("paimon_cloned.wav", SAMPLE_RATE, audio_array)
253 |
254 | ```
255 |
256 | [paimon_prompt.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/e7922859-9d12-4e2a-8651-e156e4280311)
257 |
258 |
259 | [paimon_cloned.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/60d3b7e9-5ead-4024-b499-a897ce5f3d5e)
260 |
261 |
262 |
263 |
264 |
265 |
266 | 🎢用户界面
267 |
268 | 如果你不擅长代码,我们还为VALL-E X创建了一个用户友好的图形界面。它可以让您轻松地与模型进行交互,使语音克隆和多语言语音合成变得轻而易举。
269 |
270 | 使用以下命令启动用户界面:
271 | ```commandline
272 | python -X utf8 launch-ui.py
273 | ```
274 |
275 |
276 | ## 🛠️ 硬件要求及推理速度
277 |
278 | VALL-E X 可以在CPU或GPU上运行 (`pytorch 2.0+`, CUDA 11.7 ~ CUDA 12.0).
279 |
280 | 若使用GPU运行,你需要至少6GB的显存。
281 |
282 | ## ⚙️ Details
283 |
284 | VALL-E X 与 [Bark](https://github.com/suno-ai/bark)、[VALL-E](https://arxiv.org/abs/2301.02111) 和 [AudioLM](https://arxiv.org/abs/2209.03143)类似, 使用GPT风格的模型以自回归方式预测量化音频token,并由[EnCodec](https://github.com/facebookresearch/encodec)解码.
285 |
286 | 与 [Bark](https://github.com/suno-ai/bark) 相比:
287 | - ✔ **轻量**: 3️⃣ ✖ 更小,
288 | - ✔ **快速**: 4️⃣ ✖ 更快,
289 | - ✔ **中文&日文的更高质量**
290 | - ✔ **跨语言合成时没有外国口音**
291 | - ✔ **开放且易于操作的声音克隆**
292 | - ❌ **支持的语言较少**
293 | - ❌ **没有用于合成音乐及特殊音效的token**
294 |
295 | ### 支持的语言
296 |
297 | | 语言 | 状态 |
298 | |---------| :---: |
299 | | 英语 (en) | ✅ |
300 | | 日语 (ja) | ✅ |
301 | | 中文 (zh) | ✅ |
302 |
303 | ## ❓ FAQ
304 |
305 | #### 在哪里可以下载checkpoint?
306 | * 当您第一次运行程序时,我们使用`wget`将模型下载到`./checkpoints/`目录里。
307 | * 如果第一次运行时下载失败,请从[这里](https://huggingface.co/Plachta/VALL-E-X/resolve/main/vallex-checkpoint.pt)手动下载模型,并将文件放在`./checkpoints/`里。
308 |
309 | #### 需要多少显存?
310 | * 6GB 显存(GPU VRAM) - 几乎所有NVIDIA GPU都满足要求.
311 |
312 | #### 为什么模型无法生成长文本?
313 | * 当序列长度增加时,Transformer的计算复杂度呈二次方增长。因此,所有训练音频都保持在22秒以下。请确保音频提示(audio prompt)和生成的音频的总长度小于22秒以确保可接受的性能。
314 |
315 | #### 更多...
316 |
317 | ## 🧠 待办事项
318 | - [x] 添加中文 README
319 | - [x] 长文本生成
320 | - [x] 用Vocos解码器替换Encodec解码器
321 | - [ ] 微调以实现更好的语音自适应
322 | - [ ] 给非python用户的`.bat`脚本
323 | - [ ] 更多...
324 |
325 | ## 🙏 感谢
326 | - [VALL-E X paper](https://arxiv.org/pdf/2303.03926) for the brilliant idea
327 | - [lifeiteng's vall-e](https://github.com/lifeiteng/vall-e) for related training code
328 | - [bark](https://github.com/suno-ai/bark) for the amazing pioneering work in neural-codec TTS models
329 |
330 | ## ⭐️ 表示出你的支持
331 |
332 | 如果您觉得VALL-E X有趣且有用,请在GitHub上给我们一颗星! ⭐️ 它鼓励我们不断改进模型并添加令人兴奋的功能。
333 |
334 | ## 📜 License
335 |
336 | VALL-E X 使用 [MIT License](./LICENSE).
337 |
338 | ---
339 |
340 | 有问题或需要帮助? 可以随便 [open an issue](https://github.com/Plachtaa/VALL-E-X/issues/new) 或加入我们的 [Discord](https://discord.gg/qCBRmAnTxg)
341 |
342 | Happy voice cloning! 🎤
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # VALL-E X: Multilingual Text-to-Speech Synthesis and Voice Cloning 🔊
2 |
3 | ## Fork README
4 | This repository eliminates the cumbersome dependencies of VALL-E-X and allows fine-tuning on custom datasets.
5 | Please refer to the original README below; basic operation is unchanged from the original.
6 |
7 | ## Current Accomplishments
8 | The training code runs successfully.
9 | Training on custom datasets has been verified.
10 |
11 | ## How to create and use CustomDataset
12 | ```python
13 | from customs.make_custom_dataset import create_dataset
14 |
15 | '''
16 | How should the data_dir be created?
17 | Place the necessary audio files in data_dir.
18 | Transcription, tokenization, etc. of the audio files are done by the create_dataset function.
19 |
20 | data_dir
21 | ├── bpe_69.json
22 | ├── utt1.wav
23 | ├── utt2.wav
24 | ├── utt3.wav
25 | ......
26 | └── utt{n}.wav
27 | '''
28 |
29 | data_dir = "your data_dir"
30 | create_dataset(data_dir, dataloader_process_only=True)
31 | ```
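
Once the directory has been processed, the same helper can also return a training dataloader. A minimal sketch based on `create_dataset` in `customs/make_custom_dataset.py` (which builds a `DataLoader` via `create_dataloader` when `dataloader_process_only=False`); the directory path is a placeholder:

```python
from customs.make_custom_dataset import create_dataset

# Pass 1: transcribe and tokenize the .wav files, writing
# audio_sum.hdf5 and audio_ann_sum.txt into data_dir.
data_dir = "your data_dir"
create_dataset(data_dir, dataloader_process_only=True)

# Pass 2: build a DataLoader over the processed directory.
dataloader = create_dataset(data_dir, dataloader_process_only=False)
for batch in dataloader:
    print(batch["text_tokens"].shape, batch["audio_features"].shape)
    break
```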
32 |
33 | ## When Training
34 | When training, pass the directory of training data and the directory of validation data on the command line via the `--train_dir` and `--valid_dir` arguments, for example:
35 |
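A minimal example invocation (a sketch; the paths are placeholders and `train.py` may accept additional options not shown here):

```commandline
python train.py --train_dir ./your_train_data_dir --valid_dir ./your_valid_data_dir
```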
36 |
37 | ## Original README
38 | [](https://discord.gg/qCBRmAnTxg)
39 |
40 | English | [中文](README-ZH.md)
41 |
42 | An open source implementation of Microsoft's [VALL-E X](https://arxiv.org/pdf/2303.03926) zero-shot TTS model.
43 | **We release our trained model to the public for research or application usage.**
44 |
45 | 
46 |
47 | VALL-E X is an amazing multilingual text-to-speech (TTS) model proposed by Microsoft. While Microsoft initially described it in their research paper, they did not release any code or pretrained models. Recognizing the potential and value of this technology, our team took on the challenge to reproduce the results and train our own model. We are glad to share our trained VALL-E X model with the community, allowing everyone to experience the power of next-generation TTS! 🎧
48 |
49 |
50 | More details about the model are presented in [model card](./model-card.md).
51 |
52 | ## 📖 Quick Index
53 | * [🚀 Updates](#-updates)
54 | * [📢 Features](#-features)
55 | * [💻 Installation](#-installation)
56 | * [🎧 Demos](#-demos)
57 | * [🐍 Usage](#-usage-in-python)
58 | * [❓ FAQ](#-faq)
59 | * [🧠 TODO](#-todo)
60 |
61 | ## 🚀 Updates
62 | **2023.09.10**
63 | - Added AR decoder batch decoding for more stable generation result.
64 |
65 | **2023.08.30**
66 | - Replaced EnCodec decoder with Vocos decoder, improved audio quality. (Thanks to [@v0xie](https://github.com/v0xie))
67 |
68 | **2023.08.23**
69 | - Added long text generation.
70 |
71 | **2023.08.20**
72 | - Added [Chinese README](README-ZH.md).
73 |
74 | **2023.08.14**
75 | - Pretrained VALL-E X checkpoint is now released. Download it [here](https://drive.google.com/file/d/10gdQWvP-K_e1undkvv0p2b7SU6I4Egyl/view?usp=sharing)
76 |
77 | ## 💻 Installation
78 | ### Install with pip, recommended with Python 3.10, CUDA 11.7 ~ 12.0, PyTorch 2.0+
79 | ```commandline
80 | git clone https://github.com/Plachtaa/VALL-E-X.git
81 | cd VALL-E-X
82 | pip install -r requirements.txt
83 | ```
84 |
85 | > Note: If you want to make prompts, you need to install ffmpeg and add its folder to the PATH environment variable.
86 |
87 | When you run the program for the first time, it will automatically download the corresponding model.
88 |
89 | If the download fails and reports an error, please follow the steps below to manually download the model.
90 |
91 | (Please pay attention to the capitalization of folders)
92 |
93 | 1. Check whether there is a `checkpoints` folder in the installation directory.
94 | If not, manually create a `checkpoints` folder (`./checkpoints/`) in the installation directory.
95 |
96 | 2. Check whether there is a `vallex-checkpoint.pt` file in the `checkpoints` folder.
97 | If not, please manually download the `vallex-checkpoint.pt` file from [here](https://huggingface.co/Plachta/VALL-E-X/resolve/main/vallex-checkpoint.pt) and put it in the `checkpoints` folder.
98 |
99 | 3. Check whether there is a `whisper` folder in the installation directory.
100 | If not, manually create a `whisper` folder (`./whisper/`) in the installation directory.
101 |
102 | 4. Check whether there is a `medium.pt` file in the `whisper` folder.
103 | If not, please manually download the `medium.pt` file from [here](https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt) and put it in the `whisper` folder.
104 |
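If you prefer to script the manual steps above, something like the following works on Linux/macOS (a sketch; it uses the same URLs linked above and requires `wget`):

```commandline
mkdir -p checkpoints whisper
wget -O ./checkpoints/vallex-checkpoint.pt https://huggingface.co/Plachta/VALL-E-X/resolve/main/vallex-checkpoint.pt
wget -O ./whisper/medium.pt https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt
```
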
105 | ## 🎧 Demos
106 | Not ready to set up the environment on your local machine just yet? No problem! We've got you covered with our online demos. You can try out VALL-E X directly on Hugging Face or Google Colab, experiencing the model's capabilities hassle-free!
107 |
108 | [](https://huggingface.co/spaces/Plachta/VALL-E-X)
109 | [](https://colab.research.google.com/drive/1yyD_sz531QntLKowMHo-XxorsFBCfKul?usp=sharing)
110 |
111 |
112 | ## 📢 Features
113 |
114 | VALL-E X comes packed with cutting-edge functionalities:
115 |
116 | 1. **Multilingual TTS**: Speak in three languages - English, Chinese, and Japanese - with natural and expressive speech synthesis.
117 |
118 | 2. **Zero-shot Voice Cloning**: Enroll a short 3~10 seconds recording of an unseen speaker, and watch VALL-E X create personalized, high-quality speech that sounds just like them!
119 |
120 |
121 | see example
122 |
123 | [prompt.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/a7baa51d-a53a-41cc-a03d-6970f25fcca7)
124 |
125 |
126 | [output.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/b895601a-d126-4138-beff-061aabdc7985)
127 |
128 |
129 |
130 | 3. **Speech Emotion Control**: Experience the power of emotions! VALL-E X can synthesize speech with the same emotion as the acoustic prompt provided, adding an extra layer of expressiveness to your audio.
131 |
132 |
133 | see example
134 |
135 | https://github.com/Plachtaa/VALL-E-X/assets/112609742/56fa9988-925e-4757-82c5-83ecb0df6266
136 |
137 |
138 | https://github.com/Plachtaa/VALL-E-X/assets/112609742/699c47a3-d502-4801-8364-bd89bcc0b8f1
139 |
140 |
141 |
142 | 4. **Zero-shot Cross-Lingual Speech Synthesis**: Take monolingual speakers on a linguistic journey! VALL-E X can produce personalized speech in another language without compromising on fluency or accent. Below is an example of a Japanese speaker talking in Chinese & English. 🇯🇵 🗣
143 |
144 |
145 | see example
146 |
147 | [jp-prompt.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/ea6e2ee4-139a-41b4-837e-0bd04dda6e19)
148 |
149 |
150 | [en-output.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/db8f9782-923f-425e-ba94-e8c1bd48f207)
151 |
152 |
153 | [zh-output.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/15829d79-e448-44d3-8965-fafa7a3f8c28)
154 |
155 |
156 |
157 | 5. **Accent Control**: Get creative with accents! VALL-E X allows you to experiment with different accents, like speaking Chinese with an English accent or vice versa. 🇨🇳 💬
158 |
159 |
160 | see example
161 |
162 | [en-prompt.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/f688d7f6-70ef-46ec-b1cc-355c31e78b3b)
163 |
164 |
165 | [zh-accent-output.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/be59c7ca-b45b-44ca-a30d-4d800c950ccc)
166 |
167 |
168 | [en-accent-output.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/8b4f4f9b-f299-4ea4-a548-137437b71738)
169 |
170 |
171 |
172 | 6. **Acoustic Environment Maintenance**: No need for perfectly clean audio prompts! VALL-E X adapts to the acoustic environment of the input, making speech generation feel natural and immersive.
173 |
174 |
175 | see example
176 |
177 | [noise-prompt.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/68986d88-abd0-4d1d-96e4-4f893eb9259e)
178 |
179 |
180 | [noise-output.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/96c4c612-4516-4683-8804-501b70938608)
181 |
182 |
183 |
184 |
185 | Explore our [demo page](https://plachtaa.github.io/) for a lot more examples!
186 |
187 | ## 🐍 Usage in Python
188 |
189 |
190 | 🪑 Basics
191 |
192 | ```python
193 | from utils.generation import SAMPLE_RATE, generate_audio, preload_models
194 | from scipy.io.wavfile import write as write_wav
195 | from IPython.display import Audio
196 |
197 | # download and load all models
198 | preload_models()
199 |
200 | # generate audio from text
201 | text_prompt = """
202 | Hello, my name is Nose. And uh, and I like hamburger. Hahaha... But I also have other interests such as playing tactic toast.
203 | """
204 | audio_array = generate_audio(text_prompt)
205 |
206 | # save audio to disk
207 | write_wav("vallex_generation.wav", SAMPLE_RATE, audio_array)
208 |
209 | # play text in notebook
210 | Audio(audio_array, rate=SAMPLE_RATE)
211 | ```
212 |
213 | [hamburger.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/578d7bbe-cda9-483e-898c-29646edc8f2e)
214 |
215 |
216 |
217 |
218 | 🌎 Foreign Language
219 |
220 | This VALL-E X implementation also supports Chinese and Japanese. All three languages have equally awesome performance!
221 |
222 |
223 | ```python
224 |
225 | text_prompt = """
226 | チュソクは私のお気に入りの祭りです。 私は数日間休んで、友人や家族との時間を過ごすことができます。
227 | """
228 | audio_array = generate_audio(text_prompt)
229 | ```
230 |
231 | [vallex_japanese.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/ee57a688-3e83-4be5-b0fe-019d16eec51c)
232 |
233 | *Note: VALL-E X controls accent perfectly even when synthesizing code-switched text. However, you need to manually denote the language of the respective sentences (since our g2p tool is rule-based).*
234 | ```python
235 | text_prompt = """
236 | [EN]The Thirty Years' War was a devastating conflict that had a profound impact on Europe.[EN]
237 | [ZH]这是历史的开始。 如果您想听更多,请继续。[ZH]
238 | """
239 | audio_array = generate_audio(text_prompt, language='mix')
240 | ```
241 |
242 | [vallex_codeswitch.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/d8667abf-bd08-499f-a383-a861d852f98a)
243 |
244 |
245 |
246 |
247 | 📼 Voice Presets
248 |
249 | VALL-E X provides tens of speaker voices which you can directly use for inference! Browse all voices in the [presets directory](/presets).
250 |
251 | > VALL-E X tries to match the tone, pitch, emotion and prosody of a given preset. The model also attempts to preserve music, ambient noise, etc.
252 |
253 | ```python
254 | text_prompt = """
255 | I am an innocent boy with a smoky voice. It is a great honor for me to speak at the United Nations today.
256 | """
257 | audio_array = generate_audio(text_prompt, prompt="dingzhen")
258 | ```
259 |
260 | [smoky.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/d3f55732-b1cd-420f-87d6-eab60db14dc5)
261 |
262 |
263 |
264 |
265 | 🎙Voice Cloning
266 |
267 | VALL-E X supports voice cloning! You can make a voice prompt from any person, character, or even your own voice, and use it like any other voice preset.
268 | To make a voice prompt, you need to provide a speech recording 3~10 seconds long, as well as the transcript of the speech.
269 | You can also leave the transcript blank and let the [Whisper](https://github.com/openai/whisper) model generate the transcript.
270 | > VALL-E X tries to match the tone, pitch, emotion and prosody of a given prompt. The model also attempts to preserve music, ambient noise, etc.
271 |
272 | ```python
273 | from utils.prompt_making import make_prompt
274 |
275 | ### Use given transcript
276 | make_prompt(name="paimon", audio_prompt_path="paimon_prompt.wav",
277 | transcript="Just, what was that? Paimon thought we were gonna get eaten.")
278 |
279 | ### Alternatively, use whisper
280 | make_prompt(name="paimon", audio_prompt_path="paimon_prompt.wav")
281 | ```
282 | Now let's try out the prompt we've just made!
283 | ```python
284 | from utils.generation import SAMPLE_RATE, generate_audio, preload_models
285 | from scipy.io.wavfile import write as write_wav
286 |
287 | # download and load all models
288 | preload_models()
289 |
290 | text_prompt = """
291 | Hey, Traveler, Listen to this, This machine has taken my voice, and now it can talk just like me!
292 | """
293 | audio_array = generate_audio(text_prompt, prompt="paimon")
294 |
295 | write_wav("paimon_cloned.wav", SAMPLE_RATE, audio_array)
296 |
297 | ```
298 |
299 | [paimon_prompt.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/e7922859-9d12-4e2a-8651-e156e4280311)
300 |
301 |
302 | [paimon_cloned.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/60d3b7e9-5ead-4024-b499-a897ce5f3d5e)
303 |
304 |
305 |
306 |
307 |
308 |
309 | 🎢User Interface
310 |
311 | Not comfortable with code? No problem! We've also created a user-friendly graphical interface for VALL-E X. It allows you to interact with the model effortlessly, making voice cloning and multilingual speech synthesis a breeze.
312 |
313 | You can launch the UI with the following command:
314 | ```commandline
315 | python -X utf8 launch-ui.py
316 | ```
317 |
318 |
319 | ## 🛠️ Hardware and Inference Speed
320 |
321 | VALL-E X works well on both CPU and GPU (`pytorch 2.0+`, CUDA 11.7 and CUDA 12.0).
322 |
323 | A GPU VRAM of 6GB is enough for running VALL-E X without offloading.
324 |
325 | ## ⚙️ Details
326 |
327 | VALL-E X is similar to [Bark](https://github.com/suno-ai/bark), [VALL-E](https://arxiv.org/abs/2301.02111) and [AudioLM](https://arxiv.org/abs/2209.03143): it generates audio in GPT style by autoregressively predicting audio tokens quantized by [EnCodec](https://github.com/facebookresearch/encodec).
328 |
329 | Compared to [Bark](https://github.com/suno-ai/bark):
330 | - ✔ **Lightweight**: 3️⃣ ✖ smaller,
331 | - ✔ **Efficient**: 4️⃣ ✖ faster,
332 | - ✔ **Better quality on Chinese & Japanese**
333 | - ✔ **Cross-lingual speech without foreign accent**
334 | - ✔ **Easy voice-cloning**
335 | - ❌ **Fewer languages**
336 | - ❌ **No special tokens for music / sound effects**
337 |
338 | ### Supported Languages
339 |
340 | | Language | Status |
341 | | --- | :---: |
342 | | English (en) | ✅ |
343 | | Japanese (ja) | ✅ |
344 | | Chinese, simplified (zh) | ✅ |
345 |
346 | ## ❓ FAQ
347 |
348 | #### Where can I download the model checkpoint?
349 | * We use `wget` to download the model to directory `./checkpoints/` when you run the program for the first time.
350 | * If the download fails on the first run, please manually download from [this link](https://huggingface.co/Plachta/VALL-E-X/resolve/main/vallex-checkpoint.pt), and put the file under directory `./checkpoints/`.
351 |
352 | #### How much VRAM do I need?
353 | * 6GB GPU VRAM - Almost all NVIDIA GPUs satisfy the requirement.
354 |
355 | #### Why does the model fail to generate long text?
356 | * The Transformer's computational complexity grows quadratically with sequence length, so all training audio
357 | is kept under 22 seconds. Please make sure the total length of the audio prompt and the generated audio is less than 22 seconds
358 | to ensure acceptable performance.
359 |
360 |
361 | #### MORE TO BE ADDED...
362 |
363 | ## 🧠 TODO
364 | - [x] Add Chinese README
365 | - [x] Long text generation
366 | - [x] Replace Encodec decoder with Vocos decoder
367 | - [ ] Fine-tuning for better voice adaptation
368 | - [ ] `.bat` scripts for non-python users
369 | - [ ] To be added...
370 |
371 | ## 🙏 Appreciation
372 | - [VALL-E X paper](https://arxiv.org/pdf/2303.03926) for the brilliant idea
373 | - [lifeiteng's vall-e](https://github.com/lifeiteng/vall-e) for related training code
374 | - [bark](https://github.com/suno-ai/bark) for the amazing pioneering work in neural-codec TTS models
375 |
376 | ## ⭐️ Show Your Support
377 |
378 | If you find VALL-E X interesting and useful, give us a star on GitHub! ⭐️ It encourages us to keep improving the model and adding exciting features.
379 |
380 | ## 📜 License
381 |
382 | VALL-E X is licensed under the [MIT License](./LICENSE).
383 |
384 | ---
385 |
386 | Have questions or need assistance? Feel free to [open an issue](https://github.com/Plachtaa/VALL-E-X/issues/new) or join our [Discord](https://discord.gg/qCBRmAnTxg)
387 |
388 | Happy voice cloning! 🎤
389 |
--------------------------------------------------------------------------------
/customs/make_custom_dataset.py:
--------------------------------------------------------------------------------
1 | import h5py
2 | import glob
3 | import torch
4 | import numpy as np
5 | import os
6 | import torchaudio
7 | import soundfile as sf
8 | from utils.g2p.symbols import symbols
9 | from utils.g2p import PhonemeBpeTokenizer
10 | from utils.prompt_making import make_prompt, make_transcript
11 | from data.collation import get_text_token_collater
12 | from data.dataset import create_dataloader
13 |
14 | # Mappings from symbol to numeric ID and vice versa:
15 | _symbol_to_id = {s: i for i, s in enumerate(symbols)}
16 | _id_to_symbol = {i: s for i, s in enumerate(symbols)}
17 | from data.tokenizer import (
18 | AudioTokenizer,
19 | tokenize_audio,
20 | )
21 |
22 | tokenizer_path = "./utils/g2p/bpe_69.json"
23 | tokenizer = PhonemeBpeTokenizer(tokenizer_path)
24 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
25 |
26 | def make_prompts(name, audio_prompt_path, transcript=None):
27 | text_tokenizer = PhonemeBpeTokenizer(tokenizer_path="./utils/g2p/bpe_69.json")
28 | text_collater = get_text_token_collater()
29 | codec = AudioTokenizer(device)
30 | wav_pr, sr = torchaudio.load(audio_prompt_path)
31 | # check length
32 | if wav_pr.size(-1) / sr > 15:
33 |         raise ValueError(f"Prompt too long, expect length below 15 seconds, got {wav_pr.size(-1) / sr} seconds.")
34 | if wav_pr.size(0) == 2:
35 | wav_pr = wav_pr.mean(0, keepdim=True)
36 | text_pr, lang_pr = make_transcript(name, wav_pr, sr, transcript)
37 |
38 | # tokenize audio
39 | encoded_frames = tokenize_audio(codec, (wav_pr, sr))
40 | audio_tokens = encoded_frames[0][0].transpose(2, 1).cpu().numpy()
41 |
42 | # tokenize text
43 | phonemes, langs = text_tokenizer.tokenize(text=f"{text_pr}".strip())
44 | text_tokens, enroll_x_lens = text_collater(
45 | [
46 | phonemes
47 | ]
48 | )
49 |
50 | return audio_tokens, text_tokens, langs, text_pr
51 |
52 | def create_dataset(data_dir, dataloader_process_only):
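    """Process every .wav file in data_dir into training data.

    When dataloader_process_only is True, two files are written into data_dir:
      - audio_sum.hdf5    : EnCodec audio tokens, one HDF5 group per utterance stem
      - audio_ann_sum.txt : one "stem|duration|language|transcript" line per utterance
    Otherwise, a DataLoader over a previously processed data_dir is built and returned.
    """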
53 | if dataloader_process_only:
54 | h5_output_path=f"{data_dir}/audio_sum.hdf5"
55 | ann_output_path=f"{data_dir}/audio_ann_sum.txt"
56 | #audio_folder = os.path.join(data_dir, 'audio')
57 | audio_paths = glob.glob(f"{data_dir}/*.wav") # Change this to match your audio file extension
58 |
59 | # Create or open an HDF5 file
60 | with h5py.File(h5_output_path, 'w') as h5_file:
61 | # Loop through each audio and text file, assuming they have the same stem
62 | for audio_path in audio_paths:
63 | stem = os.path.splitext(os.path.basename(audio_path))[0]
64 | audio_tokens, text_tokens, langs, text = make_prompts(name=stem, audio_prompt_path=audio_path)
65 |
66 | text_tokens = text_tokens.squeeze(0)
67 | # Create a group for each stem
68 | grp = h5_file.create_group(stem)
69 | # Add audio and text tokens as datasets to the group
70 | grp.create_dataset('audio', data=audio_tokens)
71 | #grp.create_dataset('text', data=text_tokens)
72 |
73 | with open(ann_output_path, 'a', encoding='utf-8') as ann_file:
74 | try:
75 | audio, sample_rate = sf.read(audio_path)
76 | duration = len(audio) / sample_rate
77 |                         ann_file.write(f'{stem}|{duration}|{langs[0]}|{text}\n')  # append a trailing newline
78 | print(f"Successfully wrote to {ann_output_path}")
79 | except Exception as e:
80 | print(f"An error occurred: {e}")
81 | else:
82 | dataloader = create_dataloader(data_dir=data_dir)
83 | return dataloader
--------------------------------------------------------------------------------
/customs/ph.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/customs/ph.txt
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
1 | # from .datamodule import *
2 | # from .tokenizer import *
3 | from .collation import *
4 |
--------------------------------------------------------------------------------
/data/collation.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from typing import List, Tuple
3 |
4 | import numpy as np
5 | import torch
6 |
7 | from utils import SymbolTable
8 |
9 |
10 | class TextTokenCollater:
11 | """Collate list of text tokens
12 |
13 | Map sentences to integers. Sentences are padded to equal length.
14 | Beginning and end-of-sequence symbols can be added.
15 |
16 | Example:
17 | >>> token_collater = TextTokenCollater(text_tokens)
18 | >>> tokens_batch, tokens_lens = token_collater(text)
19 |
20 | Returns:
21 | tokens_batch: IntTensor of shape (B, L)
22 | B: batch dimension, number of input sentences
23 | L: length of the longest sentence
24 | tokens_lens: IntTensor of shape (B,)
25 |             Length of each sentence after adding <eos> and <bos>
26 |             but before padding.
27 | """
28 |
29 | def __init__(
30 | self,
31 | text_tokens: List[str],
32 | add_eos: bool = True,
33 | add_bos: bool = True,
34 |         pad_symbol: str = "<pad>",
35 |         bos_symbol: str = "<bos>",
36 |         eos_symbol: str = "<eos>",
37 | ):
38 | self.pad_symbol = pad_symbol
39 |
40 | self.add_eos = add_eos
41 | self.add_bos = add_bos
42 |
43 | self.bos_symbol = bos_symbol
44 | self.eos_symbol = eos_symbol
45 |
46 | unique_tokens = (
47 | [pad_symbol]
48 | + ([bos_symbol] if add_bos else [])
49 | + ([eos_symbol] if add_eos else [])
50 | + sorted(text_tokens)
51 | )
52 |
53 | self.token2idx = {token: idx for idx, token in enumerate(unique_tokens)}
54 | self.idx2token = [token for token in unique_tokens]
55 |
56 | def index(
57 | self, tokens_list: List[str]
58 | ) -> Tuple[torch.Tensor, torch.Tensor]:
59 | seqs, seq_lens = [], []
60 | for tokens in tokens_list:
61 | assert (
62 | all([True if s in self.token2idx else False for s in tokens])
63 | is True
64 | )
65 | seq = (
66 | ([self.bos_symbol] if self.add_bos else [])
67 | + list(tokens)
68 | + ([self.eos_symbol] if self.add_eos else [])
69 | )
70 | seqs.append(seq)
71 | seq_lens.append(len(seq))
72 |
73 | max_len = max(seq_lens)
74 | for k, (seq, seq_len) in enumerate(zip(seqs, seq_lens)):
75 | seq.extend([self.pad_symbol] * (max_len - seq_len))
76 |
77 | tokens = torch.from_numpy(
78 | np.array(
79 | [[self.token2idx[token] for token in seq] for seq in seqs],
80 | dtype=np.int64,
81 | )
82 | )
83 | tokens_lens = torch.IntTensor(seq_lens)
84 |
85 | return tokens, tokens_lens
86 |
87 | def __call__(self, texts: List[str]) -> Tuple[torch.Tensor, torch.Tensor]:
88 | tokens_seqs = [[p for p in text] for text in texts]
89 | max_len = len(max(tokens_seqs, key=len))
90 |
91 | seqs = [
92 | ([self.bos_symbol] if self.add_bos else [])
93 | + list(seq)
94 | + ([self.eos_symbol] if self.add_eos else [])
95 | + [self.pad_symbol] * (max_len - len(seq))
96 | for seq in tokens_seqs
97 | ]
98 |
99 | tokens_batch = torch.from_numpy(
100 | np.array(
101 | [seq for seq in seqs],
102 | dtype=np.int64,
103 | )
104 | )
105 |
106 | tokens_lens = torch.IntTensor(
107 | [
108 | len(seq) + int(self.add_eos) + int(self.add_bos)
109 | for seq in tokens_seqs
110 | ]
111 | )
112 |
113 | return tokens_batch, tokens_lens
114 |
115 |
116 | def get_text_token_collater() -> TextTokenCollater:
117 | collater = TextTokenCollater(
118 | ['0'], add_bos=False, add_eos=False
119 | )
120 | return collater
121 |
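# Example usage (a sketch): the collater turns a batch of token sequences into a
# padded int64 tensor of token ids plus per-sequence lengths, e.g.
#   collater = get_text_token_collater()
#   tokens_batch, tokens_lens = collater([phoneme_tokens])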
--------------------------------------------------------------------------------
/data/dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 (authors: Feiteng Li)
2 | #
3 | # See ../../../../LICENSE for clarification regarding multiple authors
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | """
18 | modified from lhotse.dataset.speech_synthesis.py
19 | """
20 |
21 | import torch
22 | import math
23 | import h5py
24 | from tokenizers import Tokenizer
25 | from typing import Union, List
26 | import numpy as np
27 | from tqdm import tqdm
28 | from utils.g2p import PhonemeBpeTokenizer
29 | from data.collation import get_text_token_collater
30 | text_collater = get_text_token_collater()
31 |
32 | _pad = '_'
33 | _punctuation = ',.!?-~…'
34 | _letters = 'NQabdefghijklmnopstuvwxyzɑæʃʑçɯɪɔɛɹðəɫɥɸʊɾʒθβŋɦ⁼ʰ`^#*=ˈˌ→↓↑ '
35 | symbols = [_pad] + list(_punctuation) + list(_letters)
36 |
37 | language_dict = {
38 | 'en': 0,
39 | 'zh': 1,
40 | 'ja': 2,
41 | }
42 | def seq2phone(tokens: Union[List, np.ndarray]):
43 | """
44 | Convert tokenized phoneme ID sequence back to phoneme string
45 | :param tokens: phoneme tokens
46 | :return: recovered phoneme sequence
47 | """
48 | phones = "".join([symbols[i] for i in tokens])
49 | return phones
50 |
51 | class DynamicBatchSampler(torch.utils.data.Sampler):
52 | def __init__(self, sampler, num_tokens_fn, num_buckets=100, min_size=0, max_size=1000,
53 | max_tokens=None, max_sentences=None, drop_last=False):
54 | """
55 |
56 |         :param sampler: base sampler that yields dataset indices
57 |         :param num_tokens_fn: function that returns the length of the sample at a given index
58 |         :param num_buckets: number of buckets; samples of similar length are batched together
59 |         :param min_size: minimum sample length; shorter samples are filtered out (buckets are built from this value)
60 |         :param max_size: maximum sample length
61 |         :param max_sentences: batch size; the final batch size is controlled jointly by max_sentences and max_tokens
62 | """
63 | super(DynamicBatchSampler, self).__init__(sampler)
64 | self.sampler = sampler
65 | self.num_tokens_fn = num_tokens_fn
66 | self.num_buckets = num_buckets
67 |
68 | self.min_size = min_size
69 | self.max_size = max_size
70 |
71 | assert max_size <= max_tokens, "max_size should be smaller than max tokens"
72 | assert max_tokens is not None or max_sentences is not None, \
73 | "max_tokens and max_sentences should not be null at the same time, please specify one parameter at least"
74 | self.max_tokens = max_tokens if max_tokens is not None else float('Inf')
75 | self.max_sentences = max_sentences if max_sentences is not None else float('Inf')
76 | self.drop_last = drop_last
77 |
78 | def set_epoch(self, epoch):
79 | self.sampler.set_epoch(epoch)
80 | def is_batch_full(self, num_tokens, batch):
81 | if len(batch) == 0:
82 | return False
83 | if len(batch) == self.max_sentences:
84 | return True
85 | if num_tokens > self.max_tokens:
86 | return True
87 | return False
88 |
89 | def __iter__(self):
90 | buckets = [[] for _ in range(self.num_buckets)]
91 | sample_len = [0] * self.num_buckets
92 |
93 | for idx in self.sampler:
94 | idx_length = self.num_tokens_fn(idx)
95 | if not (self.min_size <= idx_length <= self.max_size):
96 |                 print("sentence at index {} of size {} is outside [min_size, max_size] and is ignored".format(idx, idx_length))
97 | continue
98 |
99 | index_buckets = math.floor((idx_length - self.min_size) / (self.max_size - self.min_size + 1)
100 | * self.num_buckets)
101 | sample_len[index_buckets] = max(sample_len[index_buckets], idx_length)
102 |
103 | num_tokens = (len(buckets[index_buckets]) + 1) * sample_len[index_buckets]
104 | if self.is_batch_full(num_tokens, buckets[index_buckets]):
105 | # yield this batch
106 | yield buckets[index_buckets]
107 | buckets[index_buckets] = []
108 | sample_len[index_buckets] = 0
109 |
110 | buckets[index_buckets].append(idx)
111 |
112 | # process left-over
113 | leftover_batch = []
114 | leftover_sample_len = 0
115 | leftover = [idx for bucket in buckets for idx in bucket]
116 | for idx in leftover:
117 | idx_length = self.num_tokens_fn(idx)
118 | leftover_sample_len = max(leftover_sample_len, idx_length)
119 | num_tokens = (len(leftover_batch) + 1) * leftover_sample_len
120 | if self.is_batch_full(num_tokens, leftover_batch):
121 | yield leftover_batch
122 | leftover_batch = []
123 | leftover_sample_len = 0
124 | leftover_batch.append(idx)
125 |
126 | if len(leftover_batch) > 0 and not self.drop_last:
127 | yield leftover_batch
128 |
129 | def __len__(self):
130 |         # we do not know the exact batch size, so do not call len(dataloader)
131 | pass
132 |
133 |
134 | class AudioDataset(torch.utils.data.Dataset):
135 | def __init__(self, h5_path, ann_path, tokenizer_path):
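        # ann_path is the annotation file produced by customs/make_custom_dataset.py:
        # one "stem|duration|language|transcript" line per utterance.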
136 | self.h5_path = h5_path
137 | with open(ann_path, 'r', encoding='utf-8') as f:
138 | lines = f.readlines()
139 | ls = [l.split("|") for l in lines]
140 | ls_T = list(zip(*ls))
141 | #del ls_T[-1]
142 | self.h5_paths, self.durations, self.langs, self.texts = \
143 | list(ls_T[0]), list(ls_T[1]), list(ls_T[2]), list(ls_T[3])
144 | self.durations = [float(dur) for dur in self.durations]
145 | self.tokenizer = PhonemeBpeTokenizer(tokenizer_path)
146 | self._archive = None
147 |
148 | def __len__(self):
149 | return len(self.h5_paths)
150 |
151 | def get_dur(self, idx):
152 | return self.durations[idx]
153 |
154 | @property
155 | def archive(self):
156 | if self._archive is None: # lazy loading here!
157 | self._archive = h5py.File(self.h5_path, "r")
158 | return self._archive
159 | def __getitem__(self, idx):
160 | archive = self.archive
161 | h5_path = self.h5_paths[idx]
162 | sub = archive[h5_path]
163 | audio_tokens = sub['audio'][()]
164 | #phone_tokens = sub['text'][()]
165 | dur = self.durations[idx]
166 | lang = self.langs[idx]
167 | text = self.texts[idx]
168 | # tokenization should be done within dataloader
169 | #phones = seq2phone(phone_tokens)
170 | #phones = phones.replace(" ", "_")
171 | phonemes, langs = self.tokenizer.tokenize(text=f"{text}".strip())
172 | cptpho_tokens, enroll_x_lens = text_collater([phonemes])
173 | cptpho_tokens = cptpho_tokens.squeeze(0)
174 | text_token_lens = enroll_x_lens[0]
175 | '''
176 | if not len(phones):
177 | cptpho_tokens = self.tokenizer.encode(text).ids
178 | else:
179 | cptpho_tokens = self.tokenizer.encode(phones).ids
180 | '''
181 | assert len(cptpho_tokens)
182 | return {
183 | 'utt_id': h5_path,
184 | 'text': text,
185 | 'audio': None,
186 | 'audio_lens': None,
187 | 'audio_features': audio_tokens,
188 | 'audio_features_lens': audio_tokens.shape[1],
189 | 'text_tokens': np.array(cptpho_tokens),
190 | 'text_tokens_lens': text_token_lens,
191 | 'language': language_dict[lang],
192 | }
193 |
194 | def collate(batch):
195 | utt_id_s = [b['utt_id'] for b in batch]
196 | text_s = [b['text'] for b in batch]
197 |
198 | audio_s = [b['audio'] for b in batch]
199 | audio_lens_s = [b['audio_lens'] for b in batch]
200 |
201 | audio_features_lens_s = [b['audio_features_lens'] for b in batch]
202 | # create an empty tensor with maximum audio feature length
203 | audio_features_s = torch.zeros([len(batch), max(audio_features_lens_s), 8], dtype=torch.int64) - 1 # audio pad with -1
204 |
205 | text_tokens_lens_s = [b['text_tokens_lens'] for b in batch]
206 | # create an empty tensor with maximum text tokens length
207 | text_tokens_s = torch.zeros([len(batch), max(text_tokens_lens_s)], dtype=torch.int64) + 3 # [PAD] token id 3
208 |
209 | language_s = [b['language'] for b in batch]
210 |
211 | for i, b in enumerate(batch):
212 | audio_features = b['audio_features']
213 | audio_features_lens = b['audio_features_lens']
214 | audio_features_s[i, :audio_features_lens, :] = torch.LongTensor(audio_features)
215 |
216 | text_tokens = b['text_tokens']
217 | text_tokens_lens = b['text_tokens_lens']
218 | text_tokens_s[i, :text_tokens_lens] = torch.LongTensor(text_tokens)
219 |
220 | batch = {
221 | 'utt_id': utt_id_s,
222 | 'text': text_s,
223 | 'audio': audio_s,
224 | 'audio_lens': audio_lens_s,
225 | 'audio_features': audio_features_s,
226 | 'audio_features_lens': torch.LongTensor(np.array(audio_features_lens_s)),
227 | 'text_tokens': text_tokens_s,
228 | 'text_tokens_lens': torch.LongTensor(np.array(text_tokens_lens_s)),
229 | 'languages': torch.LongTensor(np.array(language_s)),
230 | }
231 | return batch
232 |
233 | def create_dataloader(data_dir="/root/valle/egs/mix", n_gpus=1, rank=0, num_workers=0, num_buckets=10, max_duration=120):
234 | train_dataset = AudioDataset(h5_path=f"{data_dir}/audio_sum.hdf5",
235 | ann_path=f"{data_dir}/audio_ann_sum.txt",
236 | tokenizer_path=f"{data_dir}/bpe_69.json")
237 | ran_sampler = torch.utils.data.distributed.DistributedSampler(
238 | train_dataset,
239 | num_replicas=n_gpus,
240 | rank=rank,
241 | shuffle=True,
242 | )
243 | dynamic_sampler = DynamicBatchSampler(ran_sampler, train_dataset.get_dur, num_buckets=num_buckets, max_size=20,
244 | max_tokens=max_duration)
245 |
246 |
247 | train_loader = torch.utils.data.DataLoader(train_dataset, num_workers=num_workers, collate_fn=collate,
248 | batch_sampler=dynamic_sampler)
249 |
250 | return train_loader
251 |
--------------------------------------------------------------------------------
/data/fbank.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 (authors: Feiteng Li)
2 | #
3 | # See ../../../../LICENSE for clarification regarding multiple authors
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 |
18 | from dataclasses import asdict, dataclass
19 | from typing import Any, Dict, Optional, Union
20 |
21 | import numpy as np
22 | import torch
23 | # from lhotse.features.base import FeatureExtractor
24 | # from lhotse.utils import EPSILON, Seconds, compute_num_frames
25 | from librosa.filters import mel as librosa_mel_fn
26 |
27 |
28 | @dataclass
29 | class BigVGANFbankConfig:
30 |     # Spectrogram-related part
31 | # Note that frame_length and frame_shift will be converted to milliseconds before torchaudio/Kaldi sees them
32 | frame_length: Seconds = 1024 / 24000.0
33 | frame_shift: Seconds = 256 / 24000.0
34 | remove_dc_offset: bool = True
35 | round_to_power_of_two: bool = True
36 |
37 | # Fbank-related part
38 | low_freq: float = 0.0
39 | high_freq: float = 12000.0
40 | num_mel_bins: int = 100
41 | use_energy: bool = False
42 |
43 | def to_dict(self) -> Dict[str, Any]:
44 | return asdict(self)
45 |
46 | @staticmethod
47 | def from_dict(data: Dict[str, Any]) -> "BigVGANFbankConfig":
48 | return BigVGANFbankConfig(**data)
49 |
50 |
51 | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
52 | return torch.log(torch.clamp(x, min=clip_val) * C)
53 |
54 |
55 | def spectral_normalize_torch(magnitudes):
56 | output = dynamic_range_compression_torch(magnitudes)
57 | return output
58 |
59 |
60 | # https://github.com/NVIDIA/BigVGAN
61 | # bigvgan_24khz_100band https://drive.google.com/drive/folders/1EpxX6AsxjCbbk0mmAhE0td6eYiABr8Oz
62 | class BigVGANFbank(FeatureExtractor):
63 | name = "fbank"
64 | config_type = BigVGANFbankConfig
65 |
66 | def __init__(self, config: Optional[Any] = None):
67 | super(BigVGANFbank, self).__init__(config)
68 | sampling_rate = 24000
69 | self.mel_basis = torch.from_numpy(
70 | librosa_mel_fn(
71 | sampling_rate,
72 | 1024,
73 | self.config.num_mel_bins,
74 | self.config.low_freq,
75 | self.config.high_freq,
76 | ).astype(np.float32)
77 | )
78 | self.hann_window = torch.hann_window(1024)
79 |
80 | def _feature_fn(self, samples, **kwargs):
81 | win_length, n_fft = 1024, 1024
82 | hop_size = 256
83 | if True:
84 | sampling_rate = 24000
85 | duration = round(samples.shape[-1] / sampling_rate, ndigits=12)
86 | expected_num_frames = compute_num_frames(
87 | duration=duration,
88 | frame_shift=self.frame_shift,
89 | sampling_rate=sampling_rate,
90 | )
91 | pad_size = (
92 | (expected_num_frames - 1) * hop_size
93 | + win_length
94 | - samples.shape[-1]
95 | )
96 | assert pad_size >= 0
97 |
98 | y = torch.nn.functional.pad(
99 | samples,
100 | (0, pad_size),
101 | mode="constant",
102 | )
103 | else:
104 | y = torch.nn.functional.pad(
105 | samples,
106 | (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
107 | mode="reflect",
108 | )
109 |
110 | y = y.squeeze(1)
111 |
112 | # complex tensor as default, then use view_as_real for future pytorch compatibility
113 | spec = torch.stft(
114 | y,
115 | n_fft,
116 | hop_length=hop_size,
117 | win_length=win_length,
118 | window=self.hann_window,
119 | center=False,
120 | pad_mode="reflect",
121 | normalized=False,
122 | onesided=True,
123 | return_complex=True,
124 | )
125 | spec = torch.view_as_real(spec)
126 | spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
127 |
128 | spec = torch.matmul(self.mel_basis, spec)
129 | spec = spectral_normalize_torch(spec)
130 |
131 | return spec.transpose(2, 1).squeeze(0)
132 |
133 | def extract(
134 | self, samples: Union[np.ndarray, torch.Tensor], sampling_rate: int
135 | ) -> np.ndarray:
136 | assert sampling_rate == 24000
137 | params = asdict(self.config)
138 | params.update({"sample_frequency": sampling_rate, "snip_edges": False})
139 | params["frame_shift"] *= 1000.0
140 | params["frame_length"] *= 1000.0
141 | if not isinstance(samples, torch.Tensor):
142 | samples = torch.from_numpy(samples)
143 | # Torchaudio Kaldi feature extractors expect the channel dimension to be first.
144 | if len(samples.shape) == 1:
145 | samples = samples.unsqueeze(0)
146 | features = self._feature_fn(samples, **params).to(torch.float32)
147 | return features.numpy()
148 |
149 | @property
150 | def frame_shift(self) -> Seconds:
151 | return self.config.frame_shift
152 |
153 | def feature_dim(self, sampling_rate: int) -> int:
154 | return self.config.num_mel_bins
155 |
156 | @staticmethod
157 | def mix(
158 | features_a: np.ndarray,
159 | features_b: np.ndarray,
160 | energy_scaling_factor_b: float,
161 | ) -> np.ndarray:
162 | return np.log(
163 | np.maximum(
164 | # protection against log(0); max with EPSILON is adequate since these are energies (always >= 0)
165 | EPSILON,
166 | np.exp(features_a)
167 | + energy_scaling_factor_b * np.exp(features_b),
168 | )
169 | )
170 |
171 | @staticmethod
172 | def compute_energy(features: np.ndarray) -> float:
173 | return float(np.sum(np.exp(features)))
174 |
175 |
176 | def get_fbank_extractor() -> BigVGANFbank:
177 | return BigVGANFbank(BigVGANFbankConfig())
178 |
179 |
180 | if __name__ == "__main__":
181 | extractor = BigVGANFbank(BigVGANFbankConfig())
182 |
183 | samples = torch.from_numpy(np.random.random([1000]).astype(np.float32))
184 | samples = torch.clip(samples, -1.0, 1.0)
185 | fbank = extractor.extract(samples, 24000.0)
186 | print(f"fbank {fbank.shape}")
187 |
188 | from scipy.io.wavfile import read
189 |
190 | MAX_WAV_VALUE = 32768.0
191 |
192 | sampling_rate, samples = read(
193 | "egs/libritts/prompts/5639_40744_000000_000002.wav"
194 | )
195 | print(f"samples: [{samples.min()}, {samples.max()}]")
196 | fbank = extractor.extract(samples.astype(np.float32) / MAX_WAV_VALUE, 24000)
197 | print(f"fbank {fbank.shape}")
198 |
199 | import matplotlib.pyplot as plt
200 |
201 | _ = plt.figure(figsize=(18, 10))
202 | plt.imshow(
203 | X=fbank.transpose(1, 0),
204 | cmap=plt.get_cmap("jet"),
205 | aspect="auto",
206 | interpolation="nearest",
207 | )
208 | plt.gca().invert_yaxis()
209 | plt.savefig("egs/libritts/prompts/5639_40744_000000_000002.png")
210 | plt.close()
211 |
212 | print("fbank test PASS!")
213 |
--------------------------------------------------------------------------------
/data/input_strategies.py:
--------------------------------------------------------------------------------
1 | import random
2 | from collections import defaultdict
3 | from concurrent.futures import ThreadPoolExecutor
4 | from typing import Tuple, Type
5 |
6 | # from lhotse import CutSet
7 | # from lhotse.dataset.collation import collate_features
8 | # from lhotse.dataset.input_strategies import (
9 | # ExecutorType,
10 | # PrecomputedFeatures,
11 | # _get_executor,
12 | # )
13 | # from lhotse.utils import fastcopy
14 |
15 |
16 | class PromptedFeatures:
17 | def __init__(self, prompts, features):
18 | self.prompts = prompts
19 | self.features = features
20 |
21 | def to(self, device):
22 | return PromptedFeatures(
23 | self.prompts.to(device), self.features.to(device)
24 | )
25 |
26 | def sum(self):
27 | return self.features.sum()
28 |
29 | @property
30 | def ndim(self):
31 | return self.features.ndim
32 |
33 | @property
34 | def data(self):
35 | return (self.prompts, self.features)
36 |
37 |
38 | # class PromptedPrecomputedFeatures(PrecomputedFeatures):
39 | # """
40 | # :class:`InputStrategy` that reads pre-computed features, whose manifests
41 | # are attached to cuts, from disk.
42 | #
43 | # It automatically pads the feature matrices with pre or post feature.
44 | #
45 | # .. automethod:: __call__
46 | # """
47 | #
48 | # def __init__(
49 | # self,
50 | # dataset: str,
51 | # cuts: CutSet,
52 | # num_workers: int = 0,
53 | # executor_type: Type[ExecutorType] = ThreadPoolExecutor,
54 | # ) -> None:
55 | # super(PromptedPrecomputedFeatures, self).__init__(
56 | # num_workers, executor_type
57 | # )
58 | #
59 | # self.utt2neighbors = defaultdict(lambda: [])
60 | #
61 | # if dataset.lower() == "libritts":
62 | # # 909_131041_000013_000002
63 | # # 909_131041_000013_000003
64 | # speaker2utts = defaultdict(lambda: [])
65 | #
66 | # utt2cut = {}
67 | # for cut in cuts:
68 | # speaker = cut.supervisions[0].speaker
69 | # speaker2utts[speaker].append(cut.id)
70 | # utt2cut[cut.id] = cut
71 | #
72 | # for spk in speaker2utts:
73 | # uttids = sorted(speaker2utts[spk])
74 | #                 # Using the property of sorted keys to find the previous utterance.
75 | #                 # The keys have the structure speaker_book_x_y, e.g. 1089_134691_000004_000001
76 | # if len(uttids) == 1:
77 | # self.utt2neighbors[uttids[0]].append(utt2cut[uttids[0]])
78 | # continue
79 | #
80 | # utt2prevutt = dict(zip(uttids, [uttids[1]] + uttids[:-1]))
81 | # utt2postutt = dict(zip(uttids[:-1], uttids[1:]))
82 | #
83 | # for utt in utt2prevutt:
84 | # self.utt2neighbors[utt].append(utt2cut[utt2prevutt[utt]])
85 | #
86 | # for utt in utt2postutt:
87 | # self.utt2neighbors[utt].append(utt2cut[utt2postutt[utt]])
88 | # elif dataset.lower() == "ljspeech":
89 | # utt2cut = {}
90 | # uttids = []
91 | # for cut in cuts:
92 | # uttids.append(cut.id)
93 | # utt2cut[cut.id] = cut
94 | #
95 | # if len(uttids) == 1:
96 | # self.utt2neighbors[uttids[0]].append(utt2cut[uttids[0]])
97 | # else:
98 | #             # Using the property of sorted keys to find the previous utterance.
99 | #             # The keys have the structure: LJ001-0010
100 | # utt2prevutt = dict(zip(uttids, [uttids[1]] + uttids[:-1]))
101 | # utt2postutt = dict(zip(uttids[:-1], uttids[1:]))
102 | #
103 | # for utt in utt2postutt:
104 | # postutt = utt2postutt[utt]
105 | # if utt[:5] == postutt[:5]:
106 | # self.utt2neighbors[utt].append(utt2cut[postutt])
107 | #
108 | # for utt in utt2prevutt:
109 | # prevutt = utt2prevutt[utt]
110 | # if utt[:5] == prevutt[:5] or not self.utt2neighbors[utt]:
111 | # self.utt2neighbors[utt].append(utt2cut[prevutt])
112 | # else:
113 | # raise ValueError
114 | #
115 | # def __call__(
116 | # self, cuts: CutSet
117 | # ) -> Tuple[PromptedFeatures, PromptedFeatures]:
118 | # """
119 | # Reads the pre-computed features from disk/other storage.
120 | #         The returned shape is ``(B, T, F) => (batch_size, num_frames, num_features)``.
121 | #
122 | # :return: a tensor with collated features, and a tensor of ``num_frames`` of each cut before padding.
123 | # """
124 | # features, features_lens = collate_features(
125 | # cuts,
126 | # executor=_get_executor(
127 | # self.num_workers, executor_type=self._executor_type
128 | # ),
129 | # )
130 | #
131 | # prompts_cuts = []
132 | # for k, cut in enumerate(cuts):
133 | # prompts_cut = random.choice(self.utt2neighbors[cut.id])
134 | # prompts_cuts.append(fastcopy(prompts_cut, id=f"{cut.id}-{str(k)}"))
135 | #
136 | # mini_duration = min([cut.duration for cut in prompts_cuts] + [3.0])
137 | # # prompts_cuts = CutSet.from_cuts(prompts_cuts).truncate(
138 | # # max_duration=mini_duration,
139 | # # offset_type="random",
140 | # # preserve_id=True,
141 | # # )
142 | # prompts_cuts = CutSet(
143 | # cuts={k: cut for k, cut in enumerate(prompts_cuts)}
144 | # ).truncate(
145 | # max_duration=mini_duration,
146 | # offset_type="random",
147 | # preserve_id=False,
148 | # )
149 | #
150 | # prompts, prompts_lens = collate_features(
151 | # prompts_cuts,
152 | # executor=_get_executor(
153 | # self.num_workers, executor_type=self._executor_type
154 | # ),
155 | # )
156 | #
157 | # return PromptedFeatures(prompts, features), PromptedFeatures(
158 | # prompts_lens, features_lens
159 | # )
160 |
--------------------------------------------------------------------------------
/data/tokenizer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright 2023 (authors: Feiteng Li)
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import re
17 | from dataclasses import asdict, dataclass
18 | from typing import Any, Dict, List, Optional, Pattern, Union
19 |
20 | import numpy as np
21 | import torch
22 | import torchaudio
23 | from encodec import EncodecModel
24 | from encodec.utils import convert_audio
25 | from phonemizer.backend import EspeakBackend
26 | from phonemizer.backend.espeak.language_switch import LanguageSwitch
27 | from phonemizer.backend.espeak.words_mismatch import WordMismatch
28 | from phonemizer.punctuation import Punctuation
29 | from phonemizer.separator import Separator
30 |
31 | try:
32 | from pypinyin import Style, pinyin
33 | from pypinyin.style._utils import get_finals, get_initials
34 | except Exception:
35 | pass
36 |
37 |
38 | class PypinyinBackend:
39 | """PypinyinBackend for Chinese. Most codes is referenced from espnet.
40 | There are two types pinyin or initials_finals, one is
41 | just like "ni1 hao3", the other is like "n i1 h ao3".
42 | """
43 |
44 | def __init__(
45 | self,
46 | backend="initials_finals",
47 | punctuation_marks: Union[str, Pattern] = Punctuation.default_marks(),
48 | ) -> None:
49 | self.backend = backend
50 | self.punctuation_marks = punctuation_marks
51 |
52 | def phonemize(
53 | self, text: List[str], separator: Separator, strip=True, njobs=1
54 | ) -> List[str]:
55 | assert isinstance(text, List)
56 | phonemized = []
57 | for _text in text:
58 | _text = re.sub(" +", " ", _text.strip())
59 | _text = _text.replace(" ", separator.word)
60 | phones = []
61 | if self.backend == "pypinyin":
62 | for n, py in enumerate(
63 | pinyin(
64 | _text, style=Style.TONE3, neutral_tone_with_five=True
65 | )
66 | ):
67 | if all([c in self.punctuation_marks for c in py[0]]):
68 | if len(phones):
69 | assert phones[-1] == separator.syllable
70 | phones.pop(-1)
71 |
72 | phones.extend(list(py[0]))
73 | else:
74 | phones.extend([py[0], separator.syllable])
75 | elif self.backend == "pypinyin_initials_finals":
76 | for n, py in enumerate(
77 | pinyin(
78 | _text, style=Style.TONE3, neutral_tone_with_five=True
79 | )
80 | ):
81 | if all([c in self.punctuation_marks for c in py[0]]):
82 | if len(phones):
83 | assert phones[-1] == separator.syllable
84 | phones.pop(-1)
85 | phones.extend(list(py[0]))
86 | else:
87 | if py[0][-1].isalnum():
88 | initial = get_initials(py[0], strict=False)
89 | if py[0][-1].isdigit():
90 | final = (
91 | get_finals(py[0][:-1], strict=False)
92 | + py[0][-1]
93 | )
94 | else:
95 | final = get_finals(py[0], strict=False)
96 | phones.extend(
97 | [
98 | initial,
99 | separator.phone,
100 | final,
101 | separator.syllable,
102 | ]
103 | )
104 | else:
105 |                         raise ValueError(f"Unexpected pinyin token: {py[0]}")
106 | else:
107 | raise NotImplementedError
108 | phonemized.append(
109 | "".join(phones).rstrip(f"{separator.word}{separator.syllable}")
110 | )
111 | return phonemized
112 |
113 |
114 | class TextTokenizer:
115 | """Phonemize Text."""
116 |
117 | def __init__(
118 | self,
119 | language="en-us",
120 | backend="espeak",
121 | separator=Separator(word="_", syllable="-", phone="|"),
122 | preserve_punctuation=True,
123 | punctuation_marks: Union[str, Pattern] = Punctuation.default_marks(),
124 | with_stress: bool = False,
125 | tie: Union[bool, str] = False,
126 | language_switch: LanguageSwitch = "keep-flags",
127 | words_mismatch: WordMismatch = "ignore",
128 | ) -> None:
129 | if backend == "espeak":
130 | phonemizer = EspeakBackend(
131 | language,
132 | punctuation_marks=punctuation_marks,
133 | preserve_punctuation=preserve_punctuation,
134 | with_stress=with_stress,
135 | tie=tie,
136 | language_switch=language_switch,
137 | words_mismatch=words_mismatch,
138 | )
139 | elif backend in ["pypinyin", "pypinyin_initials_finals"]:
140 | phonemizer = PypinyinBackend(
141 | backend=backend,
142 | punctuation_marks=punctuation_marks + separator.word,
143 | )
144 | else:
145 | raise NotImplementedError(f"{backend}")
146 |
147 | self.backend = phonemizer
148 | self.separator = separator
149 |
150 | def to_list(self, phonemized: str) -> List[str]:
151 | fields = []
152 | for word in phonemized.split(self.separator.word):
153 | # "ɐ m|iː|n?" ɹ|ɪ|z|ɜː|v; h|ɪ|z.
154 | pp = re.findall(r"\w+|[^\w\s]", word, re.UNICODE)
155 | fields.extend(
156 | [p for p in pp if p != self.separator.phone]
157 | + [self.separator.word]
158 | )
159 | assert len("".join(fields[:-1])) == len(phonemized) - phonemized.count(
160 | self.separator.phone
161 | )
162 | return fields[:-1]
163 |
164 | def __call__(self, text, strip=True) -> List[List[str]]:
165 | if isinstance(text, str):
166 | text = [text]
167 |
168 | phonemized = self.backend.phonemize(
169 | text, separator=self.separator, strip=strip, njobs=1
170 | )
171 | return [self.to_list(p) for p in phonemized]
172 |
173 |
174 | def tokenize_text(tokenizer: TextTokenizer, text: str) -> List[str]:
175 | phonemes = tokenizer([text.strip()])
176 | return phonemes[0] # k2symbols
177 |
178 |
179 | def remove_encodec_weight_norm(model):
180 | from encodec.modules import SConv1d
181 | from encodec.modules.seanet import SConvTranspose1d, SEANetResnetBlock
182 | from torch.nn.utils import remove_weight_norm
183 |
184 | encoder = model.encoder.model
185 | for key in encoder._modules:
186 | if isinstance(encoder._modules[key], SEANetResnetBlock):
187 | remove_weight_norm(encoder._modules[key].shortcut.conv.conv)
188 | block_modules = encoder._modules[key].block._modules
189 | for skey in block_modules:
190 | if isinstance(block_modules[skey], SConv1d):
191 | remove_weight_norm(block_modules[skey].conv.conv)
192 | elif isinstance(encoder._modules[key], SConv1d):
193 | remove_weight_norm(encoder._modules[key].conv.conv)
194 |
195 | decoder = model.decoder.model
196 | for key in decoder._modules:
197 | if isinstance(decoder._modules[key], SEANetResnetBlock):
198 | remove_weight_norm(decoder._modules[key].shortcut.conv.conv)
199 | block_modules = decoder._modules[key].block._modules
200 | for skey in block_modules:
201 | if isinstance(block_modules[skey], SConv1d):
202 | remove_weight_norm(block_modules[skey].conv.conv)
203 | elif isinstance(decoder._modules[key], SConvTranspose1d):
204 | remove_weight_norm(decoder._modules[key].convtr.convtr)
205 | elif isinstance(decoder._modules[key], SConv1d):
206 | remove_weight_norm(decoder._modules[key].conv.conv)
207 |
208 |
209 | class AudioTokenizer:
210 | """EnCodec audio."""
211 |
212 | def __init__(
213 | self,
214 | device: Any = None,
215 | ) -> None:
216 | # Instantiate a pretrained EnCodec model
217 | model = EncodecModel.encodec_model_24khz()
218 | model.set_target_bandwidth(6.0)
219 | remove_encodec_weight_norm(model)
220 |
221 | if not device:
222 | device = torch.device("cpu")
223 | if torch.cuda.is_available():
224 | device = torch.device("cuda:0")
225 |
226 | self._device = device
227 |
228 | self.codec = model.to(device)
229 | self.sample_rate = model.sample_rate
230 | self.channels = model.channels
231 |
232 | @property
233 | def device(self):
234 | return self._device
235 |
236 | def encode(self, wav: torch.Tensor) -> torch.Tensor:
237 | return self.codec.encode(wav.to(self.device))
238 |
239 | def decode(self, frames: torch.Tensor) -> torch.Tensor:
240 | return self.codec.decode(frames)
241 |
242 |
243 | def tokenize_audio(tokenizer: AudioTokenizer, audio):
244 | # Load and pre-process the audio waveform
245 | if isinstance(audio, str):
246 | wav, sr = torchaudio.load(audio)
247 | else:
248 | wav, sr = audio
249 | wav = convert_audio(wav, sr, tokenizer.sample_rate, tokenizer.channels)
250 | wav = wav.unsqueeze(0)
251 |
252 | # Extract discrete codes from EnCodec
253 | with torch.no_grad():
254 | encoded_frames = tokenizer.encode(wav)
255 | return encoded_frames
256 |
257 |
258 | # @dataclass
259 | # class AudioTokenConfig:
260 | # frame_shift: Seconds = 320.0 / 24000
261 | # num_quantizers: int = 8
262 | #
263 | # def to_dict(self) -> Dict[str, Any]:
264 | # return asdict(self)
265 | #
266 | # @staticmethod
267 | # def from_dict(data: Dict[str, Any]) -> "AudioTokenConfig":
268 | # return AudioTokenConfig(**data)
269 | #
270 | #
271 | # class AudioTokenExtractor(FeatureExtractor):
272 | # name = "encodec"
273 | # config_type = AudioTokenConfig
274 | #
275 | # def __init__(self, config: Optional[Any] = None):
276 | # super(AudioTokenExtractor, self).__init__(config)
277 | # self.tokenizer = AudioTokenizer()
278 | #
279 | # def extract(
280 | # self, samples: Union[np.ndarray, torch.Tensor], sampling_rate: int
281 | # ) -> np.ndarray:
282 | # if not isinstance(samples, torch.Tensor):
283 | # samples = torch.from_numpy(samples)
284 | # if sampling_rate != self.tokenizer.sample_rate:
285 | # samples = convert_audio(
286 | # samples,
287 | # sampling_rate,
288 | # self.tokenizer.sample_rate,
289 | # self.tokenizer.channels,
290 | # )
291 | # if len(samples.shape) == 2:
292 | # samples = samples.unsqueeze(0)
293 | # else:
294 | # raise ValueError()
295 | #
296 | # device = self.tokenizer.device
297 | # encoded_frames = self.tokenizer.encode(samples.detach().to(device))
298 | # codes = encoded_frames[0][0] # [B, n_q, T]
299 | # if True:
300 | # duration = round(samples.shape[-1] / sampling_rate, ndigits=12)
301 | # expected_num_frames = compute_num_frames(
302 | # duration=duration,
303 | # frame_shift=self.frame_shift,
304 | # sampling_rate=sampling_rate,
305 | # )
306 | # assert abs(codes.shape[-1] - expected_num_frames) <= 1
307 | # codes = codes[..., :expected_num_frames]
308 | # return codes.cpu().squeeze(0).permute(1, 0).numpy()
309 | #
310 | # @property
311 | # def frame_shift(self) -> Seconds:
312 | # return self.config.frame_shift
313 | #
314 | # def feature_dim(self, sampling_rate: int) -> int:
315 | # return self.config.num_quantizers
316 | #
317 | # def pad_tensor_list(self, tensor_list, device, padding_value=0):
318 | #         # Compute the length of each tensor
319 | #         lengths = [tensor.shape[0] for tensor in tensor_list]
320 | #         # Pad the tensors with torch.nn.utils.rnn.pad_sequence
321 | # tensor_list = [torch.Tensor(t).to(device) for t in tensor_list]
322 | # padded_tensor = torch.nn.utils.rnn.pad_sequence(
323 | # tensor_list, batch_first=True, padding_value=padding_value
324 | # )
325 | # return padded_tensor, lengths
326 | #
327 | # def extract_batch(self, samples, sampling_rate, lengths) -> np.ndarray:
328 | # samples = [wav.squeeze() for wav in samples]
329 | # device = self.tokenizer.device
330 | # samples, lengths = self.pad_tensor_list(samples, device)
331 | # samples = samples.unsqueeze(1)
332 | #
333 | # if not isinstance(samples, torch.Tensor):
334 | # samples = torch.from_numpy(samples)
335 | # if len(samples.shape) != 3:
336 | # raise ValueError()
337 | # if sampling_rate != self.tokenizer.sample_rate:
338 | # samples = [
339 | # convert_audio(
340 | # wav,
341 | # sampling_rate,
342 | # self.tokenizer.sample_rate,
343 | # self.tokenizer.channels,
344 | # )
345 | # for wav in samples
346 | # ]
347 | # # Extract discrete codes from EnCodec
348 | # with torch.no_grad():
349 | # encoded_frames = self.tokenizer.encode(samples.detach().to(device))
350 | # encoded_frames = encoded_frames[0][0] # [B, n_q, T]
351 | # batch_codes = []
352 | # for b, length in enumerate(lengths):
353 | # codes = encoded_frames[b]
354 | # duration = round(length / sampling_rate, ndigits=12)
355 | # expected_num_frames = compute_num_frames(
356 | # duration=duration,
357 | # frame_shift=self.frame_shift,
358 | # sampling_rate=sampling_rate,
359 | # )
360 | # batch_codes.append(codes[..., :expected_num_frames])
361 | # return [codes.cpu().permute(1, 0).numpy() for codes in batch_codes]
362 |
363 |
364 | if __name__ == "__main__":
365 | model = EncodecModel.encodec_model_24khz()
366 | model.set_target_bandwidth(6.0)
367 |
368 | samples = torch.from_numpy(np.random.random([4, 1, 1600])).type(
369 | torch.float32
370 | )
371 | codes_raw = model.encode(samples)
372 |
373 | remove_encodec_weight_norm(model)
374 | codes_norm = model.encode(samples)
375 |
376 | assert torch.allclose(codes_raw[0][0], codes_norm[0][0])
377 |
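378 |     # Usage sketch for the text side of this file: TextTokenizer wraps the
379 |     # espeak backend for English and PypinyinBackend for Chinese; both return
380 |     # one phoneme list per sentence using the separator defined above
381 |     # (word="_", syllable="-", phone="|"). Requires espeak-ng for "espeak"
382 |     # and pypinyin for the pypinyin backends.
383 |     #
384 |     #   en_tokenizer = TextTokenizer(language="en-us", backend="espeak")
385 |     #   print(tokenize_text(en_tokenizer, "Hello world."))
386 |     #
387 |     #   zh_tokenizer = TextTokenizer(backend="pypinyin_initials_finals")
388 |     #   print(tokenize_text(zh_tokenizer, "你好世界"))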
--------------------------------------------------------------------------------
/descriptions.py:
--------------------------------------------------------------------------------
1 | top_md = """
2 | # VALL-E X
3 | VALL-E X can synthesize high-quality personalized speech with only a 3-second enrolled recording of
4 | an unseen speaker as an acoustic prompt, even in another language for a monolingual speaker.
5 | This implementation supports zero-shot, monolingual and cross-lingual text-to-speech in three languages (English, Chinese, Japanese).
6 | See the [demo](https://plachtaa.github.io/) page for more details.
7 | """
8 |
9 | infer_from_audio_md = """
10 | Upload a speech recording of 3 to 10 seconds as the audio prompt and type in the text you'd like to synthesize.
11 | The model will synthesize speech of the given text with the same voice as your audio prompt.
12 | The model also tends to preserve the emotion and acoustic environment of the given speech.
13 | For faster inference, use **"Make prompt"** to get a `.npz` file as the encoded audio prompt, then use it with **"Infer from prompt"**.
14 | """
15 |
16 | make_prompt_md = """
17 | Upload a speech recording of 3 to 10 seconds as the audio prompt.
18 | You will get a `.npz` file as the encoded audio prompt. Use it with **"Infer from prompt"**.
19 | """
20 |
21 | infer_from_prompt_md = """
22 | Faster than **"Infer from audio"**.
23 | You need to **"Make prompt"** first, then upload the encoded prompt (a `.npz` file).
24 | """
25 |
26 | long_text_md = """
27 | Very long text is chunked into several sentences, and each sentence is synthesized separately.
28 | Please make a prompt or use a preset prompt to synthesize long text.
29 | """
30 |
31 | long_text_example = "Just a few years ago, there were no legions of deep learning scientists developing intelligent products and services at major companies and startups. When we entered the field, machine learning did not command headlines in daily newspapers. Our parents had no idea what machine learning was, let alone why we might prefer it to a career in medicine or law. Machine learning was a blue skies academic discipline whose industrial significance was limited to a narrow set of real-world applications, including speech recognition and computer vision. Moreover, many of these applications required so much domain knowledge that they were often regarded as entirely separate areas for which machine learning was one small component. At that time, neural networks—the predecessors of the deep learning methods that we focus on in this book—were generally regarded as outmoded."
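32 | 
33 | # Usage sketch for the workflow described above, assuming utils/prompt_making.py
34 | # and utils/generation.py expose make_prompt, preload_models, generate_audio and
35 | # SAMPLE_RATE as in the upstream VALL-E X project (treat these names as assumptions):
36 | #
37 | #   from utils.prompt_making import make_prompt
38 | #   from utils.generation import SAMPLE_RATE, generate_audio, preload_models
39 | #   from scipy.io.wavfile import write as write_wav
40 | #
41 | #   make_prompt(name="my_voice", audio_prompt_path="prompts/en-1.wav")  # "Make prompt"
42 | #   preload_models()
43 | #   audio = generate_audio("This is synthesized speech.", prompt="my_voice")  # "Infer from prompt"
44 | #   write_wav("output.wav", SAMPLE_RATE, audio)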
--------------------------------------------------------------------------------
/examples.py:
--------------------------------------------------------------------------------
1 | infer_from_audio_examples = [
2 | ["This is how this machine has taken my voice.", 'English', 'no-accent', "prompts/en-2.wav", None, "Wow, look at that! That's no ordinary Teddy bear!"],
3 | ["我喜欢抽电子烟,尤其是锐刻五代。", '中文', 'no-accent', "prompts/zh-1.wav", None, "今天我很荣幸,"],
4 | ["私の声を真似するのはそんなに面白いですか?", '日本語', 'no-accent', "prompts/ja-2.ogg", None, "初めまして、朝武よしのです。"],
5 | ["你可以听得出来我有多困。", '中文', 'no-accent', "prompts/en-1.wav", None, ""],
6 | ["この文は、クロスリンガル合成の例です。", '日本語', 'no-accent', "prompts/zh-2.wav", None, ""],
7 | ["Actually, I can't speak English, but this machine helped me do it.", 'English', 'no-accent', "prompts/ja-1.wav", None, ""],
8 | ]
9 |
10 | make_npz_prompt_examples = [
11 | ["Gem-trader", "prompts/en-2.wav", None, "Wow, look at that! That's no ordinary Teddy bear!"],
12 | ["Ding Zhen", "prompts/zh-1.wav", None, "今天我很荣幸,"],
13 | ["Yoshino", "prompts/ja-2.ogg", None, "初めまして、朝武よしのです。"],
14 | ["Sleepy-woman", "prompts/en-1.wav", None, ""],
15 | ["Yae", "prompts/zh-2.wav", None, ""],
16 | ["Cafe", "prompts/ja-1.wav", None, ""],
17 | ]
18 |
19 | infer_from_prompt_examples = [
20 | ["A prompt contains voice, prosody and emotion information of a certain speaker.", "English", "no-accent", "vctk_1", None],
21 | ["This prompt is made with an audio of three seconds.", "English", "no-accent", "librispeech_1", None],
22 | ["This prompt is made with Chinese speech", "English", "no-accent", "seel", None],
23 | ]
24 |
25 |
--------------------------------------------------------------------------------
/images/vallex_framework.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/images/vallex_framework.jpg
--------------------------------------------------------------------------------
/macros.py:
--------------------------------------------------------------------------------
1 | NUM_LAYERS = 12
2 | NUM_HEAD = 16
3 | N_DIM = 1024
4 | PREFIX_MODE = 1
5 | NUM_QUANTIZERS = 8
6 | SAMPLE_RATE = 24000
7 |
8 | lang2token = {
9 | 'zh': "[ZH]",
10 | 'ja': "[JA]",
11 | "en": "[EN]",
12 | 'mix': "",
13 | }
14 |
15 | lang2code = {
16 | 'zh': 0,
17 | 'ja': 1,
18 | "en": 2,
19 | }
20 |
21 | token2lang = {
22 | '[ZH]': "zh",
23 | '[JA]': "ja",
24 | "[EN]": "en",
25 | "": "mix"
26 | }
27 |
28 | code2lang = {
29 | 0: 'zh',
30 | 1: 'ja',
31 | 2: "en",
32 | }
33 |
34 | langdropdown2token = {
35 | 'English': "[EN]",
36 | '中文': "[ZH]",
37 | '日本語': "[JA]",
38 | 'Mix': "",
39 | }
--------------------------------------------------------------------------------
/model-card.md:
--------------------------------------------------------------------------------
1 | # Model Card: VALL-E X
2 |
3 | **Author**: [Songting](https://github.com/Plachtaa).
4 |
5 | This is the official codebase for running open-sourced VALL-E X.
6 |
7 | The following is additional information about the models released here.
8 |
9 | ## Model Details
10 |
11 | VALL-E X is a series of two transformer models that turn text into audio.
12 |
13 | ### Phoneme to acoustic tokens
14 | - Input: IPAs converted from input text by a rule-based G2P tool.
15 | - Output: tokens from the first codebook of the [EnCodec](https://github.com/facebookresearch/encodec) codec from Facebook
16 |
17 | ### Coarse to fine tokens
18 | - Input: IPAs converted from input text by a rule-based G2P tool & the first codebook from EnCodec
19 | - Output: 8 codebooks from EnCodec
20 |
21 | ### Architecture
22 | | Model | Parameters | Attention | Output Vocab size |
23 | |:------------------------:|:----------:|------------|:-----------------:|
24 | | G2P tool | - | - | 69 |
25 | | Phoneme to coarse tokens | 150 M | Causal | 1x 1,024 |
26 | | Coarse to fine tokens | 150 M | Non-causal | 7x 1,024 |
27 |
28 | ### Release date
29 | August 2023
30 |
31 | ## Broader Implications
32 | We anticipate that this model's text-to-audio capabilities can be used to improve accessibility tools in a variety of languages.
33 | Straightforward improvements will allow models to run faster than realtime, rendering them useful for applications such as virtual assistants.
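34 | 
35 | ## Example: inspecting the acoustic token space
36 | 
37 | A minimal sketch of the EnCodec token layout that the two stages operate on, using
38 | `AudioTokenizer` and `tokenize_audio` from `data/tokenizer.py`:
39 | 
40 | ```python
41 | from data.tokenizer import AudioTokenizer, tokenize_audio
42 | 
43 | tokenizer = AudioTokenizer()                    # 24 kHz EnCodec at 6 kbps -> 8 codebooks
44 | frames = tokenize_audio(tokenizer, "prompts/en-1.wav")
45 | codes = frames[0][0]                            # (1, 8, T), ~75 token frames per second
46 | coarse = codes[:, :1]                           # codebook 1: target of the AR (phoneme-to-coarse) stage
47 | fine = codes[:, 1:]                             # codebooks 2-8: filled in by the NAR (coarse-to-fine) stage
48 | print(codes.shape, coarse.shape, fine.shape)
49 | ```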
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import torch.nn as nn
4 | # from icefall.utils import AttributeDict, str2bool
5 |
6 | from .macros import (
7 | NUM_AUDIO_TOKENS,
8 | NUM_MEL_BINS,
9 | NUM_SPEAKER_CLASSES,
10 | NUM_TEXT_TOKENS,
11 | SPEAKER_EMBEDDING_DIM,
12 | )
13 | from .transformer import Transformer
14 | from .vallex import VALLE, VALLF
15 | from .visualizer import visualize
16 |
17 |
18 | def add_model_arguments(parser: argparse.ArgumentParser):
19 | parser.add_argument(
20 | "--model-name",
21 | type=str,
22 | default="VALL-E",
23 | help="VALL-E, VALL-F, Transformer.",
24 | )
25 | parser.add_argument(
26 | "--decoder-dim",
27 | type=int,
28 | default=1024,
29 | help="Embedding dimension in the decoder model.",
30 | )
31 | parser.add_argument(
32 | "--nhead",
33 | type=int,
34 | default=16,
35 | help="Number of attention heads in the Decoder layers.",
36 | )
37 | parser.add_argument(
38 | "--num-decoder-layers",
39 | type=int,
40 | default=12,
41 | help="Number of Decoder layers.",
42 | )
43 | parser.add_argument(
44 | "--scale-factor",
45 | type=float,
46 | default=1.0,
47 | help="Model scale factor which will be assigned different meanings in different models.",
48 | )
49 | parser.add_argument(
50 | "--norm-first",
51 | type=bool,
52 | default=True,
53 | help="Pre or Post Normalization.",
54 | )
55 | parser.add_argument(
56 | "--add-prenet",
57 | type=bool,
58 | default=False,
59 | help="Whether add PreNet after Inputs.",
60 | )
61 |
62 | # VALL-E & F
63 | parser.add_argument(
64 | "--prefix-mode",
65 | type=int,
66 | default=1,
67 | help="The mode for how to prefix VALL-E NAR Decoder, "
68 | "0: no prefix, 1: 0 to random, 2: random to random, 4: chunk of pre or post utterance.",
69 | )
70 | parser.add_argument(
71 | "--share-embedding",
72 | type=bool,
73 | default=True,
74 | help="Share the parameters of the output projection layer with the parameters of the acoustic embedding.",
75 | )
76 | parser.add_argument(
77 | "--prepend-bos",
78 | type=bool,
79 | default=False,
80 | help="Whether prepend to the acoustic tokens -> AR Decoder inputs.",
81 | )
82 | parser.add_argument(
83 | "--num-quantizers",
84 | type=int,
85 | default=8,
86 | help="Number of Audio/Semantic quantization layers.",
87 | )
88 |
89 | # Transformer
90 | parser.add_argument(
91 | "--scaling-xformers",
92 | type=bool,
93 | default=False,
94 | help="Apply Reworked Conformer scaling on Transformers.",
95 | )
96 |
97 |
98 | def get_model(params) -> nn.Module:
99 | if params.model_name.lower() in ["vall-f", "vallf"]:
100 | model = VALLF(
101 | params.decoder_dim,
102 | params.nhead,
103 | params.num_decoder_layers,
104 | norm_first=params.norm_first,
105 | add_prenet=params.add_prenet,
106 | prefix_mode=params.prefix_mode,
107 | share_embedding=params.share_embedding,
108 | nar_scale_factor=params.scale_factor,
109 | prepend_bos=params.prepend_bos,
110 | num_quantizers=params.num_quantizers,
111 | )
112 | elif params.model_name.lower() in ["vall-e", "valle"]:
113 | model = VALLE(
114 | params.decoder_dim,
115 | params.nhead,
116 | params.num_decoder_layers,
117 | norm_first=params.norm_first,
118 | add_prenet=params.add_prenet,
119 | prefix_mode=params.prefix_mode,
120 | share_embedding=params.share_embedding,
121 | nar_scale_factor=params.scale_factor,
122 | prepend_bos=params.prepend_bos,
123 | num_quantizers=params.num_quantizers,
124 | )
125 | else:
126 | assert params.model_name in ["Transformer"]
127 | model = Transformer(
128 | params.decoder_dim,
129 | params.nhead,
130 | params.num_decoder_layers,
131 | norm_first=params.norm_first,
132 | add_prenet=params.add_prenet,
133 | scaling_xformers=params.scaling_xformers,
134 | )
135 |
136 | return model
137 |
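138 | 
139 | # Usage sketch: building the default model from the CLI arguments defined above
140 | # (all defaults select "VALL-E" with 1024-dim embeddings, 16 heads, 12 layers):
141 | #
142 | #   parser = argparse.ArgumentParser()
143 | #   add_model_arguments(parser)
144 | #   params = parser.parse_args([])
145 | #   model = get_model(params)
146 | #   print(sum(p.numel() for p in model.parameters()))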
--------------------------------------------------------------------------------
/models/macros.py:
--------------------------------------------------------------------------------
1 | # Text
2 | NUM_TEXT_TOKENS = 2048
3 |
4 | # Audio
5 | NUM_AUDIO_TOKENS = 1024 # EnCodec RVQ bins
6 | NUM_MEL_BINS = 100 # BigVGAN bigvgan_24khz_100band
7 |
8 |
9 | # Speaker
10 | NUM_SPEAKER_CLASSES = 4096
11 | SPEAKER_EMBEDDING_DIM = 64
12 |
--------------------------------------------------------------------------------
/models/transformer.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 (authors: Feiteng Li)
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from functools import partial
16 | from typing import Any, Dict, List, Tuple, Union
17 |
18 | import torch
19 | import torch.nn as nn
20 | import torch.nn.functional as F
21 | # from icefall.utils import make_pad_mask
22 | # from torchmetrics.classification import BinaryAccuracy
23 |
24 | from models.vallex import Transpose
25 | from modules.embedding import SinePositionalEmbedding, TokenEmbedding
26 | from modules.scaling import BalancedDoubleSwish, ScaledLinear
27 | from modules.transformer import (
28 | BalancedBasicNorm,
29 | IdentityNorm,
30 | TransformerDecoderLayer,
31 | TransformerEncoder,
32 | TransformerEncoderLayer,
33 | )
34 |
35 | from .macros import NUM_MEL_BINS, NUM_TEXT_TOKENS
36 | from .visualizer import visualize
37 |
38 | IdentityNorm = IdentityNorm
39 |
40 |
41 | class Transformer(nn.Module):
42 | """It implements seq2seq Transformer TTS for debug(No StopPredictor and SpeakerEmbeding)
43 | Neural Speech Synthesis with Transformer Network
44 | https://arxiv.org/abs/1809.08895
45 | """
46 |
47 | def __init__(
48 | self,
49 | d_model: int,
50 | nhead: int,
51 | num_layers: int,
52 | norm_first: bool = True,
53 | add_prenet: bool = False,
54 | scaling_xformers: bool = False,
55 | ):
56 | """
57 | Args:
58 | d_model:
59 | The number of expected features in the input (required).
60 | nhead:
61 | The number of heads in the multiheadattention models (required).
62 | num_layers:
63 | The number of sub-decoder-layers in the decoder (required).
64 | """
65 | super().__init__()
66 | self.text_embedding = TokenEmbedding(d_model, NUM_TEXT_TOKENS) # W_x
67 |
68 | if add_prenet:
69 | self.encoder_prenet = nn.Sequential(
70 | Transpose(),
71 | nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"),
72 | nn.BatchNorm1d(d_model),
73 | nn.ReLU(),
74 | nn.Dropout(0.5),
75 | nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"),
76 | nn.BatchNorm1d(d_model),
77 | nn.ReLU(),
78 | nn.Dropout(0.5),
79 | nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"),
80 | nn.BatchNorm1d(d_model),
81 | nn.ReLU(),
82 | nn.Dropout(0.5),
83 | Transpose(),
84 | nn.Linear(d_model, d_model),
85 | )
86 |
87 | self.decoder_prenet = nn.Sequential(
88 | nn.Linear(NUM_MEL_BINS, 256),
89 | nn.ReLU(),
90 | nn.Dropout(0.5),
91 | nn.Linear(256, 256),
92 | nn.ReLU(),
93 | nn.Dropout(0.5),
94 | nn.Linear(256, d_model),
95 | )
96 |
97 | assert scaling_xformers is False # TODO: update this block
98 | else:
99 | self.encoder_prenet = nn.Identity()
100 | if scaling_xformers:
101 | self.decoder_prenet = ScaledLinear(NUM_MEL_BINS, d_model)
102 | else:
103 | self.decoder_prenet = nn.Linear(NUM_MEL_BINS, d_model)
104 |
105 | self.encoder_position = SinePositionalEmbedding(
106 | d_model,
107 | dropout=0.1,
108 | scale=False,
109 | )
110 | self.decoder_position = SinePositionalEmbedding(
111 | d_model, dropout=0.1, scale=False
112 | )
113 |
114 | if scaling_xformers:
115 | self.encoder = TransformerEncoder(
116 | TransformerEncoderLayer(
117 | d_model,
118 | nhead,
119 | dim_feedforward=d_model * 4,
120 | dropout=0.1,
121 | batch_first=True,
122 | norm_first=norm_first,
123 | linear1_self_attention_cls=ScaledLinear,
124 | linear2_self_attention_cls=partial(
125 | ScaledLinear, initial_scale=0.01
126 | ),
127 | linear1_feedforward_cls=ScaledLinear,
128 | linear2_feedforward_cls=partial(
129 | ScaledLinear, initial_scale=0.01
130 | ),
131 | activation=partial(
132 | BalancedDoubleSwish,
133 | channel_dim=-1,
134 | max_abs=10.0,
135 | min_prob=0.25,
136 | ),
137 | layer_norm_cls=IdentityNorm,
138 | ),
139 | num_layers=num_layers,
140 | norm=BalancedBasicNorm(d_model) if norm_first else None,
141 | )
142 |
143 | self.decoder = nn.TransformerDecoder(
144 | TransformerDecoderLayer(
145 | d_model,
146 | nhead,
147 | dim_feedforward=d_model * 4,
148 | dropout=0.1,
149 | batch_first=True,
150 | norm_first=norm_first,
151 | linear1_self_attention_cls=ScaledLinear,
152 | linear2_self_attention_cls=partial(
153 | ScaledLinear, initial_scale=0.01
154 | ),
155 | linear1_feedforward_cls=ScaledLinear,
156 | linear2_feedforward_cls=partial(
157 | ScaledLinear, initial_scale=0.01
158 | ),
159 | activation=partial(
160 | BalancedDoubleSwish,
161 | channel_dim=-1,
162 | max_abs=10.0,
163 | min_prob=0.25,
164 | ),
165 | layer_norm_cls=IdentityNorm,
166 | ),
167 | num_layers=num_layers,
168 | norm=BalancedBasicNorm(d_model) if norm_first else None,
169 | )
170 |
171 | self.predict_layer = ScaledLinear(d_model, NUM_MEL_BINS)
172 | self.stop_layer = nn.Linear(d_model, 1)
173 | else:
174 | self.encoder = nn.TransformerEncoder(
175 | nn.TransformerEncoderLayer(
176 | d_model,
177 | nhead,
178 | dim_feedforward=d_model * 4,
179 | activation=F.relu,
180 | dropout=0.1,
181 | batch_first=True,
182 | norm_first=norm_first,
183 | ),
184 | num_layers=num_layers,
185 | norm=nn.LayerNorm(d_model) if norm_first else None,
186 | )
187 |
188 | self.decoder = nn.TransformerDecoder(
189 | nn.TransformerDecoderLayer(
190 | d_model,
191 | nhead,
192 | dim_feedforward=d_model * 4,
193 | activation=F.relu,
194 | dropout=0.1,
195 | batch_first=True,
196 | norm_first=norm_first,
197 | ),
198 | num_layers=num_layers,
199 | norm=nn.LayerNorm(d_model) if norm_first else None,
200 | )
201 |
202 | self.predict_layer = nn.Linear(d_model, NUM_MEL_BINS)
203 | self.stop_layer = nn.Linear(d_model, 1)
204 |
205 | self.stop_accuracy_metric = BinaryAccuracy(
206 | threshold=0.5, multidim_average="global"
207 | )
208 |
209 | # self.apply(self._init_weights)
210 |
211 | # def _init_weights(self, module):
212 | # if isinstance(module, (nn.Linear)):
213 | # module.weight.data.normal_(mean=0.0, std=0.02)
214 | # if isinstance(module, nn.Linear) and module.bias is not None:
215 | # module.bias.data.zero_()
216 | # elif isinstance(module, nn.LayerNorm):
217 | # module.bias.data.zero_()
218 | # module.weight.data.fill_(1.0)
219 | # elif isinstance(module, nn.Embedding):
220 | # module.weight.data.normal_(mean=0.0, std=0.02)
221 |
222 | def forward(
223 | self,
224 | x: torch.Tensor,
225 | x_lens: torch.Tensor,
226 | y: torch.Tensor,
227 | y_lens: torch.Tensor,
228 | reduction: str = "sum",
229 | train_stage: int = 0,
230 | **kwargs,
231 | ) -> Tuple[torch.Tensor, Union[torch.Tensor, None]]:
232 | """
233 | Args:
234 | x:
235 | A 2-D tensor of shape (N, S).
236 | x_lens:
237 | A 1-D tensor of shape (N,). It contains the number of tokens in `x`
238 | before padding.
239 | y:
240 |             A 3-D tensor of shape (N, T, NUM_MEL_BINS).
241 |           y_lens:
242 |             A 1-D tensor of shape (N,). It contains the number of frames in `y`
243 | before padding.
244 | train_stage:
245 | Not used in this model.
246 | Returns:
247 |             Return the encoder/decoder outputs, the total loss (MSE plus weighted stop loss), and metrics.
248 | """
249 | del train_stage
250 |
251 | assert x.ndim == 2, x.shape
252 | assert x_lens.ndim == 1, x_lens.shape
253 | assert y.ndim == 3, y.shape
254 | assert y_lens.ndim == 1, y_lens.shape
255 |
256 | assert torch.all(x_lens > 0)
257 |
258 | # NOTE: x has been padded in TextTokenCollater
259 | x_mask = make_pad_mask(x_lens).to(x.device)
260 |
261 | x = self.text_embedding(x)
262 | x = self.encoder_prenet(x)
263 | x = self.encoder_position(x)
264 | x = self.encoder(x, src_key_padding_mask=x_mask)
265 |
266 | total_loss, metrics = 0.0, {}
267 |
268 | y_mask = make_pad_mask(y_lens).to(y.device)
269 | y_mask_float = y_mask.type(torch.float32)
270 | data_mask = 1.0 - y_mask_float.unsqueeze(-1)
271 |
272 | # Training
273 | # AR Decoder
274 | def pad_y(y):
275 | y = F.pad(y, (0, 0, 1, 0, 0, 0), value=0).detach()
276 | # inputs, targets
277 | return y[:, :-1], y[:, 1:]
278 |
279 | y, targets = pad_y(y * data_mask) # mask padding as zeros
280 |
281 | y_emb = self.decoder_prenet(y)
282 | y_pos = self.decoder_position(y_emb)
283 |
284 | y_len = y_lens.max()
285 | tgt_mask = torch.triu(
286 | torch.ones(y_len, y_len, device=y.device, dtype=torch.bool),
287 | diagonal=1,
288 | )
289 | y_dec = self.decoder(
290 | y_pos,
291 | x,
292 | tgt_mask=tgt_mask,
293 | memory_key_padding_mask=x_mask,
294 | )
295 |
296 | predict = self.predict_layer(y_dec)
297 | # loss
298 | total_loss = F.mse_loss(predict, targets, reduction=reduction)
299 |
300 | logits = self.stop_layer(y_dec).squeeze(-1)
301 | stop_loss = F.binary_cross_entropy_with_logits(
302 | logits,
303 | y_mask_float.detach(),
304 | weight=1.0 + y_mask_float.detach() * 4.0,
305 | reduction=reduction,
306 | )
307 | metrics["stop_loss"] = stop_loss.detach()
308 |
309 | stop_accuracy = self.stop_accuracy_metric(
310 | (torch.sigmoid(logits) >= 0.5).type(torch.int64),
311 | y_mask.type(torch.int64),
312 | )
313 | # icefall MetricsTracker.norm_items()
314 | metrics["stop_accuracy"] = stop_accuracy.item() * y_lens.sum().type(
315 | torch.float32
316 | )
317 |
318 | return ((x, predict), total_loss + 100.0 * stop_loss, metrics)
319 |
320 | def inference(
321 | self,
322 | x: torch.Tensor,
323 | x_lens: torch.Tensor,
324 | y: Any = None,
325 | **kwargs,
326 | ) -> torch.Tensor:
327 | """
328 | Args:
329 | x:
330 | A 2-D tensor of shape (1, S).
331 | x_lens:
332 | A 1-D tensor of shape (1,). It contains the number of tokens in `x`
333 | before padding.
334 | Returns:
335 |             Return the predicted mel-spectrogram frames.
336 | """
337 | assert x.ndim == 2, x.shape
338 | assert x_lens.ndim == 1, x_lens.shape
339 |
340 | assert torch.all(x_lens > 0)
341 |
342 | x_mask = make_pad_mask(x_lens).to(x.device)
343 |
344 | x = self.text_embedding(x)
345 | x = self.encoder_prenet(x)
346 | x = self.encoder_position(x)
347 | x = self.encoder(x, src_key_padding_mask=x_mask)
348 |
349 | x_mask = make_pad_mask(x_lens).to(x.device)
350 |
351 | # AR Decoder
352 | # TODO: Managing decoder steps avoid repetitive computation
353 | y = torch.zeros(
354 | [x.shape[0], 1, NUM_MEL_BINS], dtype=torch.float32, device=x.device
355 | )
356 | while True:
357 | y_emb = self.decoder_prenet(y)
358 | y_pos = self.decoder_position(y_emb)
359 |
360 | tgt_mask = torch.triu(
361 | torch.ones(
362 | y.shape[1], y.shape[1], device=y.device, dtype=torch.bool
363 | ),
364 | diagonal=1,
365 | )
366 |
367 | y_dec = self.decoder(
368 | y_pos,
369 | x,
370 | tgt_mask=tgt_mask,
371 | memory_mask=None,
372 | memory_key_padding_mask=x_mask,
373 | )
374 | predict = self.predict_layer(y_dec[:, -1:])
375 |
376 | logits = self.stop_layer(y_dec[:, -1:]) > 0 # sigmoid(0.0) = 0.5
377 | if y.shape[1] > x_lens.max() * 10 or all(logits.cpu().numpy()):
378 | print(
379 | f"TransformerTTS EOS [Text {x_lens[0]} -> Audio {y.shape[1]}]"
380 | )
381 | break
382 |
383 | y = torch.concat([y, predict], dim=1)
384 |
385 | return y[:, 1:]
386 |
387 | def visualize(
388 | self,
389 | predicts: Tuple[torch.Tensor],
390 | batch: Dict[str, Union[List, torch.Tensor]],
391 | output_dir: str,
392 | limit: int = 4,
393 | ) -> None:
394 | visualize(predicts, batch, output_dir, limit=limit)
395 |
--------------------------------------------------------------------------------
/models/visualizer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright 2023 (authors: Feiteng Li)
3 | #
4 | # See ../../../../LICENSE for clarification regarding multiple authors
5 | #
6 | # Licensed under the Apache License, Version 2.0 (the "License");
7 | # you may not use this file except in compliance with the License.
8 | # You may obtain a copy of the License at
9 | #
10 | # http://www.apache.org/licenses/LICENSE-2.0
11 | #
12 | # Unless required by applicable law or agreed to in writing, software
13 | # distributed under the License is distributed on an "AS IS" BASIS,
14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | # See the License for the specific language governing permissions and
16 | # limitations under the License.
17 |
18 |
19 | from typing import Dict, List, Tuple, Union
20 |
21 | import matplotlib.pyplot as plt
22 | import numpy as np
23 | import torch
24 |
25 |
26 | def visualize(
27 | predicts: Tuple[torch.Tensor],
28 | batch: Dict[str, Union[List, torch.Tensor]],
29 | output_dir: str,
30 | limit: int = 4,
31 | ) -> None:
32 | text_tokens = batch["text_tokens"].to("cpu").detach().numpy()
33 | text_tokens_lens = batch["text_tokens_lens"].to("cpu").detach().numpy()
34 | audio_features = batch["audio_features"].to("cpu").detach().numpy()
35 | audio_features_lens = (
36 | batch["audio_features_lens"].to("cpu").detach().numpy()
37 | )
38 | assert text_tokens.ndim == 2
39 |
40 | utt_ids, texts = batch["utt_id"], batch["text"]
41 |
42 | encoder_outputs = predicts[0].to("cpu").type(torch.float32).detach().numpy()
43 | decoder_outputs = predicts[1]
44 | if isinstance(decoder_outputs, list):
45 | decoder_outputs = decoder_outputs[-1]
46 | decoder_outputs = (
47 | decoder_outputs.to("cpu").type(torch.float32).detach().numpy()
48 | )
49 |
50 | vmin, vmax = 0, 1024 # Encodec
51 | if decoder_outputs.dtype == np.float32:
52 | vmin, vmax = -6, 0 # Fbank
53 |
54 | num_figures = 3
55 | for b, (utt_id, text) in enumerate(zip(utt_ids[:limit], texts[:limit])):
56 | _ = plt.figure(figsize=(14, 8 * num_figures))
57 |
58 | S = text_tokens_lens[b]
59 | T = audio_features_lens[b]
60 |
61 | # encoder
62 | plt.subplot(num_figures, 1, 1)
63 | plt.title(f"Text: {text}")
64 | plt.imshow(
65 | X=np.transpose(encoder_outputs[b]),
66 | cmap=plt.get_cmap("jet"),
67 | aspect="auto",
68 | interpolation="nearest",
69 | )
70 | plt.gca().invert_yaxis()
71 | plt.axvline(x=S - 0.4, linewidth=2, color="r")
72 | plt.xlabel("Encoder Output")
73 | plt.colorbar()
74 |
75 | # decoder
76 | plt.subplot(num_figures, 1, 2)
77 | plt.imshow(
78 | X=np.transpose(decoder_outputs[b]),
79 | cmap=plt.get_cmap("jet"),
80 | aspect="auto",
81 | interpolation="nearest",
82 | vmin=vmin,
83 | vmax=vmax,
84 | )
85 | plt.gca().invert_yaxis()
86 | plt.axvline(x=T - 0.4, linewidth=2, color="r")
87 | plt.xlabel("Decoder Output")
88 | plt.colorbar()
89 |
90 | # target
91 | plt.subplot(num_figures, 1, 3)
92 | plt.imshow(
93 | X=np.transpose(audio_features[b]),
94 | cmap=plt.get_cmap("jet"),
95 | aspect="auto",
96 | interpolation="nearest",
97 | vmin=vmin,
98 | vmax=vmax,
99 | )
100 | plt.gca().invert_yaxis()
101 | plt.axvline(x=T - 0.4, linewidth=2, color="r")
102 | plt.xlabel("Decoder Target")
103 | plt.colorbar()
104 |
105 | plt.savefig(f"{output_dir}/{utt_id}.png")
106 | plt.close()
107 |
--------------------------------------------------------------------------------
/modules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/modules/__init__.py
--------------------------------------------------------------------------------
/modules/embedding.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 (authors: Feiteng Li)
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import math
16 |
17 | import torch
18 | import torch.nn as nn
19 |
20 |
21 | class TokenEmbedding(nn.Module):
22 | def __init__(
23 | self,
24 | dim_model: int,
25 | vocab_size: int,
26 | dropout: float = 0.0,
27 | ):
28 | super().__init__()
29 |
30 | self.vocab_size = vocab_size
31 | self.dim_model = dim_model
32 |
33 | self.dropout = torch.nn.Dropout(p=dropout)
34 | self.word_embeddings = nn.Embedding(self.vocab_size, self.dim_model)
35 |
36 | @property
37 | def weight(self) -> torch.Tensor:
38 | return self.word_embeddings.weight
39 |
40 | def embedding(self, index: int) -> torch.Tensor:
41 | return self.word_embeddings.weight[index : index + 1]
42 |
43 | def forward(self, x: torch.Tensor):
44 | X = self.word_embeddings(x)
45 | X = self.dropout(X)
46 |
47 | return X
48 |
49 |
50 | class SinePositionalEmbedding(nn.Module):
51 | def __init__(
52 | self,
53 | dim_model: int,
54 | dropout: float = 0.0,
55 | scale: bool = False,
56 | alpha: bool = False,
57 | ):
58 | super().__init__()
59 | self.dim_model = dim_model
60 | self.x_scale = math.sqrt(dim_model) if scale else 1.0
61 | self.alpha = nn.Parameter(torch.ones(1), requires_grad=alpha)
62 | self.dropout = torch.nn.Dropout(p=dropout)
63 |
64 | self.reverse = False
65 | self.pe = None
66 | self.extend_pe(torch.tensor(0.0).expand(1, 4000))
67 |
68 | def extend_pe(self, x):
69 | """Reset the positional encodings."""
70 | if self.pe is not None:
71 | if self.pe.size(1) >= x.size(1):
72 | if self.pe.dtype != x.dtype or self.pe.device != x.device:
73 | self.pe = self.pe.to(dtype=x.dtype, device=x.device)
74 | return
75 | pe = torch.zeros(x.size(1), self.dim_model)
76 | if self.reverse:
77 | position = torch.arange(
78 | x.size(1) - 1, -1, -1.0, dtype=torch.float32
79 | ).unsqueeze(1)
80 | else:
81 | position = torch.arange(
82 | 0, x.size(1), dtype=torch.float32
83 | ).unsqueeze(1)
84 | div_term = torch.exp(
85 | torch.arange(0, self.dim_model, 2, dtype=torch.float32)
86 | * -(math.log(10000.0) / self.dim_model)
87 | )
88 | pe[:, 0::2] = torch.sin(position * div_term)
89 | pe[:, 1::2] = torch.cos(position * div_term)
90 | pe = pe.unsqueeze(0)
91 | self.pe = pe.to(device=x.device, dtype=x.dtype).detach()
92 |
93 | def forward(self, x: torch.Tensor) -> torch.Tensor:
94 | self.extend_pe(x)
95 | output = x.unsqueeze(-1) if x.ndim == 2 else x
96 | output = output * self.x_scale + self.alpha * self.pe[:, : x.size(1)]
97 | return self.dropout(output)
98 |
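99 | 
100 | # Usage sketch: token embeddings plus sinusoidal positions, as combined by the
101 | # encoder/decoder front-ends in models/transformer.py:
102 | #
103 | #   emb = TokenEmbedding(dim_model=1024, vocab_size=2048)
104 | #   pos = SinePositionalEmbedding(1024, dropout=0.1, scale=False)
105 | #   x = pos(emb(torch.randint(0, 2048, (2, 50))))  # -> (2, 50, 1024)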
--------------------------------------------------------------------------------
/modules/scheduler.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright 2023 (authors: Feiteng Li)
3 | #
4 | # See ../../../../LICENSE for clarification regarding multiple authors
5 | #
6 | # Licensed under the Apache License, Version 2.0 (the "License");
7 | # you may not use this file except in compliance with the License.
8 | # You may obtain a copy of the License at
9 | #
10 | # http://www.apache.org/licenses/LICENSE-2.0
11 | #
12 | # Unless required by applicable law or agreed to in writing, software
13 | # distributed under the License is distributed on an "AS IS" BASIS,
14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | # See the License for the specific language governing permissions and
16 | # limitations under the License.
17 |
18 |
19 | import torch
20 |
21 | from modules.optim import Eden
22 |
23 |
24 | def calc_lr(step, dim_embed, warmup_steps):
25 | return dim_embed ** (-0.5) * min(
26 | step ** (-0.5), step * warmup_steps ** (-1.5)
27 | )
28 |
29 |
30 | class NoamScheduler(torch.optim.lr_scheduler._LRScheduler):
31 | def __init__(
32 | self,
33 | base_lr: float,
34 | optimizer: torch.optim.Optimizer,
35 | dim_embed: int,
36 | warmup_steps: int,
37 | last_epoch: int = -1,
38 | verbose: bool = False,
39 | ) -> None:
40 |
41 | self.dim_embed = dim_embed
42 | self.base_lr = base_lr
43 | self.warmup_steps = warmup_steps
44 | self.num_param_groups = len(optimizer.param_groups)
45 |
46 | super().__init__(optimizer, last_epoch, verbose)
47 |
48 |     def get_lr(self) -> list:
49 | lr = self.base_lr * calc_lr(
50 | self._step_count, self.dim_embed, self.warmup_steps
51 | )
52 | return [lr] * self.num_param_groups
53 |
54 | def set_step(self, step: int):
55 | self._step_count = step
56 |
57 |
58 | def get_scheduler(params, optimizer):
59 | if params.scheduler_name.lower() == "eden":
60 | scheduler = Eden(optimizer, 5000, 4, warmup_batches=params.warmup_steps)
61 | elif params.scheduler_name.lower() == "noam":
62 | scheduler = NoamScheduler(
63 | params.base_lr,
64 | optimizer,
65 | params.decoder_dim,
66 | warmup_steps=params.warmup_steps,
67 | )
68 | # scheduler.set_step(params.start_batch or params.batch_idx_train)
69 | elif params.scheduler_name.lower() == "cosine":
70 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
71 |             optimizer,
72 |             T_max=params.warmup_steps,
73 | eta_min=params.base_lr,
74 | )
75 | else:
76 | raise NotImplementedError(f"{params.scheduler_name}")
77 |
78 | return scheduler
79 |
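80 | 
81 | if __name__ == "__main__":
82 |     # Illustrative check of the Noam schedule defined by calc_lr(): the learning
83 |     # rate warms up roughly linearly for `warmup_steps` and then decays as
84 |     # step ** -0.5. dim_embed=1024 and warmup_steps=200 are arbitrary demo values.
85 |     dummy = torch.nn.Linear(4, 4)
86 |     optimizer = torch.optim.AdamW(dummy.parameters(), lr=1.0)
87 |     scheduler = NoamScheduler(
88 |         base_lr=0.05, optimizer=optimizer, dim_embed=1024, warmup_steps=200
89 |     )
90 |     for step in (1, 100, 200, 400, 800):
91 |         scheduler.set_step(step)
92 |         print(f"step {step}: lr = {scheduler.get_lr()[0]:.6f}")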
--------------------------------------------------------------------------------
/nltk_data/tokenizers/punkt/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/nltk_data/tokenizers/punkt/.DS_Store
--------------------------------------------------------------------------------
/nltk_data/tokenizers/punkt/PY3/README:
--------------------------------------------------------------------------------
1 | Pretrained Punkt Models -- Jan Strunk (New version trained after issues 313 and 514 had been corrected)
2 |
3 | Most models were prepared using the test corpora from Kiss and Strunk (2006). Additional models have
4 | been contributed by various people using NLTK for sentence boundary detection.
5 |
6 | For information about how to use these models, please confer the tokenization HOWTO:
7 | http://nltk.googlecode.com/svn/trunk/doc/howto/tokenize.html
8 | and chapter 3.8 of the NLTK book:
9 | http://nltk.googlecode.com/svn/trunk/doc/book/ch03.html#sec-segmentation
10 |
11 | There are pretrained tokenizers for the following languages:
12 |
13 | File Language Source Contents Size of training corpus(in tokens) Model contributed by
14 | =======================================================================================================================================================================
15 | czech.pickle Czech Multilingual Corpus 1 (ECI) Lidove Noviny ~345,000 Jan Strunk / Tibor Kiss
16 | Literarni Noviny
17 | -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
18 | danish.pickle Danish Avisdata CD-Rom Ver. 1.1. 1995 Berlingske Tidende ~550,000 Jan Strunk / Tibor Kiss
19 | (Berlingske Avisdata, Copenhagen) Weekend Avisen
20 | -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
21 | dutch.pickle Dutch Multilingual Corpus 1 (ECI) De Limburger ~340,000 Jan Strunk / Tibor Kiss
22 | -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
23 | english.pickle English Penn Treebank (LDC) Wall Street Journal ~469,000 Jan Strunk / Tibor Kiss
24 | (American)
25 | -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
26 | estonian.pickle Estonian University of Tartu, Estonia Eesti Ekspress ~359,000 Jan Strunk / Tibor Kiss
27 | -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
28 | finnish.pickle Finnish Finnish Parole Corpus, Finnish Books and major national ~364,000 Jan Strunk / Tibor Kiss
29 | Text Bank (Suomen Kielen newspapers
30 | Tekstipankki)
31 | Finnish Center for IT Science
32 | (CSC)
33 | -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
34 | french.pickle French Multilingual Corpus 1 (ECI) Le Monde ~370,000 Jan Strunk / Tibor Kiss
35 | (European)
36 | -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
37 | german.pickle German Neue Zürcher Zeitung AG Neue Zürcher Zeitung ~847,000 Jan Strunk / Tibor Kiss
38 | (Switzerland) CD-ROM
39 | (Uses "ss"
40 | instead of "ß")
41 | -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
42 | greek.pickle Greek Efstathios Stamatatos To Vima (TO BHMA) ~227,000 Jan Strunk / Tibor Kiss
43 | -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
44 | italian.pickle Italian Multilingual Corpus 1 (ECI) La Stampa, Il Mattino ~312,000 Jan Strunk / Tibor Kiss
45 | -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
46 | norwegian.pickle Norwegian Centre for Humanities Bergens Tidende ~479,000 Jan Strunk / Tibor Kiss
47 | (Bokmål and Information Technologies,
48 | Nynorsk) Bergen
49 | -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
50 | polish.pickle Polish Polish National Corpus Literature, newspapers, etc. ~1,000,000 Krzysztof Langner
51 | (http://www.nkjp.pl/)
52 | -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
53 | portuguese.pickle Portuguese CETENFolha Corpus Folha de São Paulo ~321,000 Jan Strunk / Tibor Kiss
54 | (Brazilian) (Linguateca)
55 | -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
56 | slovene.pickle Slovene TRACTOR Delo ~354,000 Jan Strunk / Tibor Kiss
57 | Slovene Academy for Arts
58 | and Sciences
59 | -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
60 | spanish.pickle Spanish Multilingual Corpus 1 (ECI) Sur ~353,000 Jan Strunk / Tibor Kiss
61 | (European)
62 | -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
63 | swedish.pickle Swedish Multilingual Corpus 1 (ECI) Dagens Nyheter ~339,000 Jan Strunk / Tibor Kiss
64 | (and some other texts)
65 | -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
66 | turkish.pickle Turkish METU Turkish Corpus Milliyet ~333,000 Jan Strunk / Tibor Kiss
67 | (Türkçe Derlem Projesi)
68 | University of Ankara
69 | -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
70 |
71 | The corpora contained about 400,000 tokens on average and mostly consisted of newspaper text converted to
72 | Unicode using the codecs module.
73 |
74 | Kiss, Tibor and Strunk, Jan (2006): Unsupervised Multilingual Sentence Boundary Detection.
75 | Computational Linguistics 32: 485-525.
76 |
77 | ---- Training Code ----
78 |
79 | # import punkt
80 | import nltk.tokenize.punkt
81 |
82 | # Make a new Tokenizer
83 | tokenizer = nltk.tokenize.punkt.PunktSentenceTokenizer()
84 |
85 | # Read in training corpus (one example: Slovene)
86 | import codecs
87 | text = codecs.open("slovene.plain","Ur","iso-8859-2").read()
88 |
89 | # Train tokenizer
90 | tokenizer.train(text)
91 |
92 | # Dump pickled tokenizer
93 | import pickle
94 | out = open("slovene.pickle","wb")
95 | pickle.dump(tokenizer, out)
96 | out.close()
97 |
98 | ---------
99 |
--------------------------------------------------------------------------------
/nltk_data/tokenizers/punkt/PY3/czech.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/nltk_data/tokenizers/punkt/PY3/czech.pickle
--------------------------------------------------------------------------------
/nltk_data/tokenizers/punkt/PY3/danish.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/nltk_data/tokenizers/punkt/PY3/danish.pickle
--------------------------------------------------------------------------------
/nltk_data/tokenizers/punkt/PY3/dutch.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/nltk_data/tokenizers/punkt/PY3/dutch.pickle
--------------------------------------------------------------------------------
/nltk_data/tokenizers/punkt/PY3/english.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/nltk_data/tokenizers/punkt/PY3/english.pickle
--------------------------------------------------------------------------------
/nltk_data/tokenizers/punkt/PY3/estonian.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/nltk_data/tokenizers/punkt/PY3/estonian.pickle
--------------------------------------------------------------------------------
/nltk_data/tokenizers/punkt/PY3/finnish.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/nltk_data/tokenizers/punkt/PY3/finnish.pickle
--------------------------------------------------------------------------------
/nltk_data/tokenizers/punkt/PY3/french.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/nltk_data/tokenizers/punkt/PY3/french.pickle
--------------------------------------------------------------------------------
/nltk_data/tokenizers/punkt/PY3/german.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/nltk_data/tokenizers/punkt/PY3/german.pickle
--------------------------------------------------------------------------------
/nltk_data/tokenizers/punkt/PY3/greek.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/nltk_data/tokenizers/punkt/PY3/greek.pickle
--------------------------------------------------------------------------------
/nltk_data/tokenizers/punkt/PY3/italian.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/nltk_data/tokenizers/punkt/PY3/italian.pickle
--------------------------------------------------------------------------------
/nltk_data/tokenizers/punkt/PY3/malayalam.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/nltk_data/tokenizers/punkt/PY3/malayalam.pickle
--------------------------------------------------------------------------------
/nltk_data/tokenizers/punkt/PY3/norwegian.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/nltk_data/tokenizers/punkt/PY3/norwegian.pickle
--------------------------------------------------------------------------------
/nltk_data/tokenizers/punkt/PY3/polish.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/nltk_data/tokenizers/punkt/PY3/polish.pickle
--------------------------------------------------------------------------------
/nltk_data/tokenizers/punkt/PY3/portuguese.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/nltk_data/tokenizers/punkt/PY3/portuguese.pickle
--------------------------------------------------------------------------------
/nltk_data/tokenizers/punkt/PY3/russian.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/nltk_data/tokenizers/punkt/PY3/russian.pickle
--------------------------------------------------------------------------------
/nltk_data/tokenizers/punkt/PY3/slovene.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/nltk_data/tokenizers/punkt/PY3/slovene.pickle
--------------------------------------------------------------------------------
/nltk_data/tokenizers/punkt/PY3/spanish.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/nltk_data/tokenizers/punkt/PY3/spanish.pickle
--------------------------------------------------------------------------------
/nltk_data/tokenizers/punkt/PY3/swedish.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/nltk_data/tokenizers/punkt/PY3/swedish.pickle
--------------------------------------------------------------------------------
/nltk_data/tokenizers/punkt/PY3/turkish.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/nltk_data/tokenizers/punkt/PY3/turkish.pickle
--------------------------------------------------------------------------------
/nltk_data/tokenizers/punkt/README:
--------------------------------------------------------------------------------
1 | Pretrained Punkt Models -- Jan Strunk (New version trained after issues 313 and 514 had been corrected)
2 |
3 | Most models were prepared using the test corpora from Kiss and Strunk (2006). Additional models have
4 | been contributed by various people using NLTK for sentence boundary detection.
5 |
6 | For information about how to use these models, please confer the tokenization HOWTO:
7 | http://nltk.googlecode.com/svn/trunk/doc/howto/tokenize.html
8 | and chapter 3.8 of the NLTK book:
9 | http://nltk.googlecode.com/svn/trunk/doc/book/ch03.html#sec-segmentation
10 |
11 | There are pretrained tokenizers for the following languages:
12 |
13 | File Language Source Contents Size of training corpus(in tokens) Model contributed by
14 | =======================================================================================================================================================================
15 | czech.pickle Czech Multilingual Corpus 1 (ECI) Lidove Noviny ~345,000 Jan Strunk / Tibor Kiss
16 | Literarni Noviny
17 | -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
18 | danish.pickle Danish Avisdata CD-Rom Ver. 1.1. 1995 Berlingske Tidende ~550,000 Jan Strunk / Tibor Kiss
19 | (Berlingske Avisdata, Copenhagen) Weekend Avisen
20 | -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
21 | dutch.pickle Dutch Multilingual Corpus 1 (ECI) De Limburger ~340,000 Jan Strunk / Tibor Kiss
22 | -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
23 | english.pickle English Penn Treebank (LDC) Wall Street Journal ~469,000 Jan Strunk / Tibor Kiss
24 | (American)
25 | -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
26 | estonian.pickle Estonian University of Tartu, Estonia Eesti Ekspress ~359,000 Jan Strunk / Tibor Kiss
27 | -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
28 | finnish.pickle Finnish Finnish Parole Corpus, Finnish Books and major national ~364,000 Jan Strunk / Tibor Kiss
29 | Text Bank (Suomen Kielen newspapers
30 | Tekstipankki)
31 | Finnish Center for IT Science
32 | (CSC)
33 | -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
34 | french.pickle French Multilingual Corpus 1 (ECI) Le Monde ~370,000 Jan Strunk / Tibor Kiss
35 | (European)
36 | -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
37 | german.pickle German Neue Zürcher Zeitung AG Neue Zürcher Zeitung ~847,000 Jan Strunk / Tibor Kiss
38 | (Switzerland) CD-ROM
39 | (Uses "ss"
40 | instead of "ß")
41 | -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
42 | greek.pickle Greek Efstathios Stamatatos To Vima (TO BHMA) ~227,000 Jan Strunk / Tibor Kiss
43 | -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
44 | italian.pickle Italian Multilingual Corpus 1 (ECI) La Stampa, Il Mattino ~312,000 Jan Strunk / Tibor Kiss
45 | -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
46 | norwegian.pickle Norwegian Centre for Humanities Bergens Tidende ~479,000 Jan Strunk / Tibor Kiss
47 | (Bokmål and Information Technologies,
48 | Nynorsk) Bergen
49 | -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
50 | polish.pickle Polish Polish National Corpus Literature, newspapers, etc. ~1,000,000 Krzysztof Langner
51 | (http://www.nkjp.pl/)
52 | -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
53 | portuguese.pickle Portuguese CETENFolha Corpus Folha de São Paulo ~321,000 Jan Strunk / Tibor Kiss
54 | (Brazilian) (Linguateca)
55 | -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
56 | slovene.pickle Slovene TRACTOR Delo ~354,000 Jan Strunk / Tibor Kiss
57 | Slovene Academy for Arts
58 | and Sciences
59 | -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
60 | spanish.pickle Spanish Multilingual Corpus 1 (ECI) Sur ~353,000 Jan Strunk / Tibor Kiss
61 | (European)
62 | -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
63 | swedish.pickle Swedish Multilingual Corpus 1 (ECI) Dagens Nyheter ~339,000 Jan Strunk / Tibor Kiss
64 | (and some other texts)
65 | -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
66 | turkish.pickle Turkish METU Turkish Corpus Milliyet ~333,000 Jan Strunk / Tibor Kiss
67 | (Türkçe Derlem Projesi)
68 | University of Ankara
69 | -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
70 |
71 | The corpora contained about 400,000 tokens on average and mostly consisted of newspaper text converted to
72 | Unicode using the codecs module.
73 |
74 | Kiss, Tibor and Strunk, Jan (2006): Unsupervised Multilingual Sentence Boundary Detection.
75 | Computational Linguistics 32: 485-525.
76 |
77 | ---- Training Code ----
78 |
79 | # import punkt
80 | import nltk.tokenize.punkt
81 |
82 | # Make a new Tokenizer
83 | tokenizer = nltk.tokenize.punkt.PunktSentenceTokenizer()
84 |
85 | # Read in training corpus (one example: Slovene)
86 | import codecs
87 | text = codecs.open("slovene.plain","Ur","iso-8859-2").read()
88 |
89 | # Train tokenizer
90 | tokenizer.train(text)
91 |
92 | # Dump pickled tokenizer
93 | import pickle
94 | out = open("slovene.pickle","wb")
95 | pickle.dump(tokenizer, out)
96 | out.close()
97 |
98 | ---------
99 |
--------------------------------------------------------------------------------
/nltk_data/tokenizers/punkt/czech.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/nltk_data/tokenizers/punkt/czech.pickle
--------------------------------------------------------------------------------
/nltk_data/tokenizers/punkt/danish.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/nltk_data/tokenizers/punkt/danish.pickle
--------------------------------------------------------------------------------
/nltk_data/tokenizers/punkt/dutch.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/nltk_data/tokenizers/punkt/dutch.pickle
--------------------------------------------------------------------------------
/nltk_data/tokenizers/punkt/estonian.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/nltk_data/tokenizers/punkt/estonian.pickle
--------------------------------------------------------------------------------
/nltk_data/tokenizers/punkt/finnish.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/nltk_data/tokenizers/punkt/finnish.pickle
--------------------------------------------------------------------------------
/nltk_data/tokenizers/punkt/french.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/nltk_data/tokenizers/punkt/french.pickle
--------------------------------------------------------------------------------
/nltk_data/tokenizers/punkt/german.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/nltk_data/tokenizers/punkt/german.pickle
--------------------------------------------------------------------------------
/nltk_data/tokenizers/punkt/italian.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/nltk_data/tokenizers/punkt/italian.pickle
--------------------------------------------------------------------------------
/nltk_data/tokenizers/punkt/malayalam.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/nltk_data/tokenizers/punkt/malayalam.pickle
--------------------------------------------------------------------------------
/nltk_data/tokenizers/punkt/norwegian.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/nltk_data/tokenizers/punkt/norwegian.pickle
--------------------------------------------------------------------------------
/nltk_data/tokenizers/punkt/polish.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/nltk_data/tokenizers/punkt/polish.pickle
--------------------------------------------------------------------------------
/nltk_data/tokenizers/punkt/portuguese.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/nltk_data/tokenizers/punkt/portuguese.pickle
--------------------------------------------------------------------------------
/nltk_data/tokenizers/punkt/russian.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/nltk_data/tokenizers/punkt/russian.pickle
--------------------------------------------------------------------------------
/nltk_data/tokenizers/punkt/slovene.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/nltk_data/tokenizers/punkt/slovene.pickle
--------------------------------------------------------------------------------
/nltk_data/tokenizers/punkt/spanish.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/nltk_data/tokenizers/punkt/spanish.pickle
--------------------------------------------------------------------------------
/nltk_data/tokenizers/punkt/swedish.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/nltk_data/tokenizers/punkt/swedish.pickle
--------------------------------------------------------------------------------
/nltk_data/tokenizers/punkt/turkish.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/nltk_data/tokenizers/punkt/turkish.pickle
--------------------------------------------------------------------------------
/presets/acou_1.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/presets/acou_1.npz
--------------------------------------------------------------------------------
/presets/acou_2.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/presets/acou_2.npz
--------------------------------------------------------------------------------
/presets/acou_3.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/presets/acou_3.npz
--------------------------------------------------------------------------------
/presets/acou_4.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/presets/acou_4.npz
--------------------------------------------------------------------------------
/presets/alan.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/presets/alan.npz
--------------------------------------------------------------------------------
/presets/amused.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/presets/amused.npz
--------------------------------------------------------------------------------
/presets/anger.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/presets/anger.npz
--------------------------------------------------------------------------------
/presets/babara.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/presets/babara.npz
--------------------------------------------------------------------------------
/presets/bronya.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/presets/bronya.npz
--------------------------------------------------------------------------------
/presets/cafe.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/presets/cafe.npz
--------------------------------------------------------------------------------
/presets/dingzhen.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/presets/dingzhen.npz
--------------------------------------------------------------------------------
/presets/disgust.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/presets/disgust.npz
--------------------------------------------------------------------------------
/presets/emo_amused.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/presets/emo_amused.npz
--------------------------------------------------------------------------------
/presets/emo_anger.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/presets/emo_anger.npz
--------------------------------------------------------------------------------
/presets/emo_neutral.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/presets/emo_neutral.npz
--------------------------------------------------------------------------------
/presets/emo_sleepy.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/presets/emo_sleepy.npz
--------------------------------------------------------------------------------
/presets/emotion_sleepiness.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/presets/emotion_sleepiness.npz
--------------------------------------------------------------------------------
/presets/en2zh_tts_1.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/presets/en2zh_tts_1.npz
--------------------------------------------------------------------------------
/presets/en2zh_tts_2.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/presets/en2zh_tts_2.npz
--------------------------------------------------------------------------------
/presets/en2zh_tts_3.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/presets/en2zh_tts_3.npz
--------------------------------------------------------------------------------
/presets/en2zh_tts_4.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/presets/en2zh_tts_4.npz
--------------------------------------------------------------------------------
/presets/esta.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/presets/esta.npz
--------------------------------------------------------------------------------
/presets/fuxuan.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/presets/fuxuan.npz
--------------------------------------------------------------------------------
/presets/librispeech_1.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/presets/librispeech_1.npz
--------------------------------------------------------------------------------
/presets/librispeech_2.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/presets/librispeech_2.npz
--------------------------------------------------------------------------------
/presets/librispeech_3.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/presets/librispeech_3.npz
--------------------------------------------------------------------------------
/presets/librispeech_4.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/presets/librispeech_4.npz
--------------------------------------------------------------------------------
/presets/neutral.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/presets/neutral.npz
--------------------------------------------------------------------------------
/presets/paimon.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/presets/paimon.npz
--------------------------------------------------------------------------------
/presets/rosalia.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/presets/rosalia.npz
--------------------------------------------------------------------------------
/presets/seel.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/presets/seel.npz
--------------------------------------------------------------------------------
/presets/sleepiness.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/presets/sleepiness.npz
--------------------------------------------------------------------------------
/presets/vctk_1.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/presets/vctk_1.npz
--------------------------------------------------------------------------------
/presets/vctk_2.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/presets/vctk_2.npz
--------------------------------------------------------------------------------
/presets/vctk_3.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/presets/vctk_3.npz
--------------------------------------------------------------------------------
/presets/vctk_4.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/presets/vctk_4.npz
--------------------------------------------------------------------------------
/presets/yaesakura.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/presets/yaesakura.npz
--------------------------------------------------------------------------------
/presets/zh2en_tts_1.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/presets/zh2en_tts_1.npz
--------------------------------------------------------------------------------
/presets/zh2en_tts_2.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/presets/zh2en_tts_2.npz
--------------------------------------------------------------------------------
/presets/zh2en_tts_3.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/presets/zh2en_tts_3.npz
--------------------------------------------------------------------------------
/presets/zh2en_tts_4.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/presets/zh2en_tts_4.npz
--------------------------------------------------------------------------------
/prompts/en-1.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/prompts/en-1.wav
--------------------------------------------------------------------------------
/prompts/en-2.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/prompts/en-2.wav
--------------------------------------------------------------------------------
/prompts/ja-1.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/prompts/ja-1.wav
--------------------------------------------------------------------------------
/prompts/ja-2.ogg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/prompts/ja-2.ogg
--------------------------------------------------------------------------------
/prompts/ph.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/prompts/ph.txt
--------------------------------------------------------------------------------
/prompts/zh-1.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/prompts/zh-1.wav
--------------------------------------------------------------------------------
/prompts/zh-2.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/prompts/zh-2.wav
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | soundfile
2 | numpy
3 | torch
4 | torchvision
5 | torchaudio
6 | tokenizers
7 | encodec
8 | langid
9 | wget
10 | unidecode
11 | pyopenjtalk-prebuilt
12 | pypinyin
13 | inflect
14 | cn2an
15 | jieba
16 | eng_to_ipa
17 | openai-whisper
18 | phonemizer==3.2.0
19 | matplotlib
20 | gradio
21 | nltk
22 | sudachipy
23 | sudachidict_core
24 | vocos
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import logging
4 | from data.dataset import create_dataloader
5 | from macros import *
6 | from data.tokenizer import (
7 | AudioTokenizer,
8 | tokenize_audio,
9 | )
10 | from data.collation import get_text_token_collater
11 | from models.vallex import VALLE
12 | if torch.cuda.is_available():
13 | device = torch.device("cuda", 0)
14 | from vocos import Vocos
15 |
16 | def get_model(device):
17 | url = 'https://huggingface.co/Plachta/VALL-E-X/resolve/main/vallex-checkpoint.pt'
18 |
19 | checkpoints_dir = "./checkpoints"
20 |
21 | model_checkpoint_name = "vallex-checkpoint_modified.pt"
22 | if not os.path.exists(checkpoints_dir): os.mkdir(checkpoints_dir)
23 | if not os.path.exists(os.path.join(checkpoints_dir, model_checkpoint_name)):
24 | import wget
25 | print("3")
26 | try:
27 | logging.info(
28 | "Downloading model from https://huggingface.co/Plachta/VALL-E-X/resolve/main/vallex-checkpoint.pt ...")
29 | # download from https://huggingface.co/Plachta/VALL-E-X/resolve/main/vallex-checkpoint.pt to ./checkpoints/vallex-checkpoint.pt
30 | wget.download("https://huggingface.co/Plachta/VALL-E-X/resolve/main/vallex-checkpoint.pt",
31 | out="./checkpoints/vallex-checkpoint.pt", bar=wget.bar_adaptive)
32 | except Exception as e:
33 | logging.info(e)
34 | raise Exception(
35 | "\n Model weights download failed, please go to 'https://huggingface.co/Plachta/VALL-E-X/resolve/main/vallex-checkpoint.pt'"
36 |                 "\n manually download the model weights and put them in {} .".format(os.path.join(os.getcwd(), "checkpoints")))
37 | # VALL-E
38 | model = VALLE(
39 | N_DIM,
40 | NUM_HEAD,
41 | NUM_LAYERS,
42 | norm_first=True,
43 | add_prenet=False,
44 | prefix_mode=PREFIX_MODE,
45 | share_embedding=True,
46 | nar_scale_factor=1.0,
47 | prepend_bos=True,
48 | num_quantizers=NUM_QUANTIZERS,
49 | ).to(device)
50 | checkpoint = torch.load(os.path.join(checkpoints_dir, model_checkpoint_name), map_location='cpu')
51 | missing_keys, unexpected_keys = model.load_state_dict(
52 | checkpoint["model"], strict=True
53 | )
54 | assert not missing_keys
55 |
56 | # Encodec
57 | codec = AudioTokenizer(device)
58 |
59 | vocos = Vocos.from_pretrained('charactr/vocos-encodec-24khz').to(device)
60 |
61 | return model, codec, vocos
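A minimal sketch of calling the loader above (assuming the checkpoint is present under ./checkpoints with the name get_model expects; CPU fallback shown only for illustration):

    import torch
    from test import get_model

    device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")
    model, codec, vocos = get_model(device)   # VALLE model, EnCodec tokenizer, Vocos vocoder
    model.eval()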
--------------------------------------------------------------------------------
/train_utils/__pycache__/utils.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/train_utils/__pycache__/utils.cpython-310.pyc
--------------------------------------------------------------------------------
/train_utils/lhotse/utils.py:
--------------------------------------------------------------------------------
1 | import random
2 | import numpy as np
3 | import torch
4 | import uuid
5 |
6 | def fix_random_seed(random_seed: int):
7 | """
8 | Set the same random seed for the libraries and modules that Lhotse interacts with.
9 | Includes the ``random`` module, numpy, torch, and ``uuid4()`` function defined in this file.
10 | """
11 | global _lhotse_uuid
12 | random.seed(random_seed)
13 | np.random.seed(random_seed)
14 | torch.random.manual_seed(random_seed)
15 | # Ensure deterministic ID creation
16 | rd = random.Random()
17 | rd.seed(random_seed)
18 | _lhotse_uuid = lambda: uuid.UUID(int=rd.getrandbits(128))
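A short usage sketch for the seeding helper defined above (the seed value 42 is arbitrary):

    from train_utils.lhotse.utils import fix_random_seed

    fix_random_seed(42)   # seeds random, numpy, torch, and the deterministic lhotse-style UUID generator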
--------------------------------------------------------------------------------
/train_utils/model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import logging
4 | from macros import *
5 | from data.tokenizer import (
6 | AudioTokenizer,
7 | tokenize_audio,
8 | )
9 | from models.vallex import VALLE
10 | from vocos import Vocos
11 |
12 | def get_model(device):
13 | url = 'https://huggingface.co/Plachta/VALL-E-X/resolve/main/vallex-checkpoint.pt'
14 |
15 | checkpoints_dir = "./checkpoints"
16 |
17 | model_checkpoint_name = "vallex-checkpoint_modified.pt"
18 | if not os.path.exists(checkpoints_dir): os.mkdir(checkpoints_dir)
19 | if not os.path.exists(os.path.join(checkpoints_dir, model_checkpoint_name)):
20 | import wget
21 | print("3")
22 | try:
23 | logging.info(
24 | "Downloading model from https://huggingface.co/Plachta/VALL-E-X/resolve/main/vallex-checkpoint.pt ...")
25 | # download from https://huggingface.co/Plachta/VALL-E-X/resolve/main/vallex-checkpoint.pt to ./checkpoints/vallex-checkpoint.pt
26 | wget.download("https://huggingface.co/Plachta/VALL-E-X/resolve/main/vallex-checkpoint.pt",
27 | out="./checkpoints/vallex-checkpoint.pt", bar=wget.bar_adaptive)
28 | except Exception as e:
29 | logging.info(e)
30 | raise Exception(
31 | "\n Model weights download failed, please go to 'https://huggingface.co/Plachta/VALL-E-X/resolve/main/vallex-checkpoint.pt'"
32 |                 "\n manually download the model weights and put them in {} .".format(os.path.join(os.getcwd(), "checkpoints")))
33 | # VALL-E
34 | model = VALLE(
35 | N_DIM,
36 | NUM_HEAD,
37 | NUM_LAYERS,
38 | norm_first=True,
39 | add_prenet=False,
40 | prefix_mode=PREFIX_MODE,
41 | share_embedding=True,
42 | nar_scale_factor=1.0,
43 | prepend_bos=True,
44 | num_quantizers=NUM_QUANTIZERS,
45 | ).to(device)
46 | checkpoint = torch.load(os.path.join(checkpoints_dir, model_checkpoint_name), map_location='cpu')
47 | missing_keys, unexpected_keys = model.load_state_dict(
48 | checkpoint["model"], strict=True
49 | )
50 | assert not missing_keys
51 |
52 | # Encodec
53 | codec = AudioTokenizer(device)
54 |
55 | vocos = Vocos.from_pretrained('charactr/vocos-encodec-24khz').to(device)
56 |
57 | return model, codec, vocos
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | # from icefall.utils import make_pad_mask
4 |
5 | from .symbol_table import SymbolTable
6 |
7 | # make_pad_mask = make_pad_mask
8 | SymbolTable = SymbolTable
9 |
10 |
11 | class Transpose(nn.Identity):
12 | """(N, T, D) -> (N, D, T)"""
13 |
14 | def forward(self, input: torch.Tensor) -> torch.Tensor:
15 | return input.transpose(1, 2)
16 |
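A quick illustration of the Transpose helper (shapes chosen arbitrarily):

    import torch
    from utils import Transpose

    x = torch.randn(2, 100, 512)   # (N, T, D)
    y = Transpose()(x)             # (N, D, T) -> torch.Size([2, 512, 100])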
--------------------------------------------------------------------------------
/utils/download.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import requests
3 |
4 |
5 | def download_file_from_google_drive(id, destination):
6 | URL = "https://docs.google.com/uc?export=download&confirm=1"
7 |
8 | session = requests.Session()
9 |
10 | response = session.get(URL, params={"id": id}, stream=True)
11 | token = get_confirm_token(response)
12 |
13 | if token:
14 | params = {"id": id, "confirm": token}
15 | response = session.get(URL, params=params, stream=True)
16 |
17 | save_response_content(response, destination)
18 |
19 |
20 | def get_confirm_token(response):
21 | for key, value in response.cookies.items():
22 | if key.startswith("download_warning"):
23 | return value
24 |
25 | return None
26 |
27 |
28 | def save_response_content(response, destination):
29 | CHUNK_SIZE = 32768
30 |
31 |     with open(destination, "wb") as f:
32 | for chunk in response.iter_content(CHUNK_SIZE):
33 | if chunk: # filter out keep-alive new chunks
34 | f.write(chunk)
35 |
36 |
37 | def main():
38 | if len(sys.argv) >= 3:
39 | file_id = sys.argv[1]
40 | destination = sys.argv[2]
41 | else:
42 | file_id = "TAKE_ID_FROM_SHAREABLE_LINK"
43 | destination = "DESTINATION_FILE_ON_YOUR_DISK"
44 |     print(f"download {file_id} to {destination}")
45 | download_file_from_google_drive(file_id, destination)
46 |
47 |
48 | if __name__ == "__main__":
49 | main()
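The downloader can be driven from the command line; a hedged example (the file id and destination below are placeholders, not real values, and only Google Drive share ids are handled):

    python utils/download.py <GOOGLE_DRIVE_FILE_ID> ./checkpoints/downloaded_file.bin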
--------------------------------------------------------------------------------
/utils/g2p/__init__.py:
--------------------------------------------------------------------------------
1 | """ from https://github.com/keithito/tacotron """
2 | import utils.g2p.cleaners
3 | from utils.g2p.symbols import symbols
4 | from tokenizers import Tokenizer
5 |
6 | # Mappings from symbol to numeric ID and vice versa:
7 | _symbol_to_id = {s: i for i, s in enumerate(symbols)}
8 | _id_to_symbol = {i: s for i, s in enumerate(symbols)}
9 |
10 |
11 | class PhonemeBpeTokenizer:
12 | def __init__(self, tokenizer_path = "./utils/g2p/bpe_1024.json"):
13 | self.tokenizer = Tokenizer.from_file(tokenizer_path)
14 |
15 | def tokenize(self, text):
16 | # 1. convert text to phoneme
17 | phonemes, langs = _clean_text(text, ['cje_cleaners'])
18 | # 2. replace blank space " " with "_"
19 | phonemes = phonemes.replace(" ", "_")
20 | # 3. tokenize phonemes
21 | phoneme_tokens = self.tokenizer.encode(phonemes).ids
22 | assert(len(phoneme_tokens) == len(langs))
23 | if not len(phoneme_tokens):
24 | raise ValueError("Empty text is given")
25 | return phoneme_tokens, langs
26 |
27 | def text_to_sequence(text, cleaner_names):
28 | '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
29 | Args:
30 | text: string to convert to a sequence
31 | cleaner_names: names of the cleaner functions to run the text through
32 | Returns:
33 | List of integers corresponding to the symbols in the text
34 | '''
35 | sequence = []
36 | symbol_to_id = {s: i for i, s in enumerate(symbols)}
37 | clean_text = _clean_text(text, cleaner_names)
38 | for symbol in clean_text:
39 | if symbol not in symbol_to_id.keys():
40 | continue
41 | symbol_id = symbol_to_id[symbol]
42 | sequence += [symbol_id]
43 | return sequence
44 |
45 |
46 | def cleaned_text_to_sequence(cleaned_text):
47 | '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
48 | Args:
49 | text: string to convert to a sequence
50 | Returns:
51 | List of integers corresponding to the symbols in the text
52 | '''
53 | sequence = [_symbol_to_id[symbol] for symbol in cleaned_text if symbol in _symbol_to_id.keys()]
54 | return sequence
55 |
56 |
57 | def sequence_to_text(sequence):
58 | '''Converts a sequence of IDs back to a string'''
59 | result = ''
60 | for symbol_id in sequence:
61 | s = _id_to_symbol[symbol_id]
62 | result += s
63 | return result
64 |
65 |
66 | def _clean_text(text, cleaner_names):
67 | for name in cleaner_names:
68 | cleaner = getattr(utils.g2p.cleaners, name)
69 | if not cleaner:
70 | raise Exception('Unknown cleaner: %s' % name)
71 | text, langs = cleaner(text)
72 | return text, langs
73 |
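A minimal sketch of the phoneme BPE tokenizer defined above (assuming bpe_1024.json is present at the default path and the input uses the language tags expected by cje_cleaners):

    from utils.g2p import PhonemeBpeTokenizer

    tokenizer = PhonemeBpeTokenizer()
    tokens, langs = tokenizer.tokenize("[EN]Good morning.[EN]")
    # tokens: BPE ids over the cleaned phoneme string; langs: per-symbol language codes
    # (the code asserts that the two sequences have matching lengths)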
--------------------------------------------------------------------------------
/utils/g2p/bpe_69.json:
--------------------------------------------------------------------------------
1 | {
2 | "version": "1.0",
3 | "truncation": null,
4 | "padding": null,
5 | "added_tokens": [
6 | {
7 | "id": 0,
8 | "content": "[UNK]",
9 | "single_word": false,
10 | "lstrip": false,
11 | "rstrip": false,
12 | "normalized": false,
13 | "special": true
14 | },
15 | {
16 | "id": 1,
17 | "content": "[CLS]",
18 | "single_word": false,
19 | "lstrip": false,
20 | "rstrip": false,
21 | "normalized": false,
22 | "special": true
23 | },
24 | {
25 | "id": 2,
26 | "content": "[SEP]",
27 | "single_word": false,
28 | "lstrip": false,
29 | "rstrip": false,
30 | "normalized": false,
31 | "special": true
32 | },
33 | {
34 | "id": 3,
35 | "content": "[PAD]",
36 | "single_word": false,
37 | "lstrip": false,
38 | "rstrip": false,
39 | "normalized": false,
40 | "special": true
41 | },
42 | {
43 | "id": 4,
44 | "content": "[MASK]",
45 | "single_word": false,
46 | "lstrip": false,
47 | "rstrip": false,
48 | "normalized": false,
49 | "special": true
50 | }
51 | ],
52 | "normalizer": null,
53 | "pre_tokenizer": {
54 | "type": "Whitespace"
55 | },
56 | "post_processor": null,
57 | "decoder": null,
58 | "model": {
59 | "type": "BPE",
60 | "dropout": null,
61 | "unk_token": "[UNK]",
62 | "continuing_subword_prefix": null,
63 | "end_of_word_suffix": null,
64 | "fuse_unk": false,
65 | "byte_fallback": false,
66 | "vocab": {
67 | "[UNK]": 0,
68 | "[CLS]": 1,
69 | "[SEP]": 2,
70 | "[PAD]": 3,
71 | "[MASK]": 4,
72 | "!": 5,
73 | "#": 6,
74 | "*": 7,
75 | ",": 8,
76 | "-": 9,
77 | ".": 10,
78 | "=": 11,
79 | "?": 12,
80 | "N": 13,
81 | "Q": 14,
82 | "^": 15,
83 | "_": 16,
84 | "`": 17,
85 | "a": 18,
86 | "b": 19,
87 | "d": 20,
88 | "e": 21,
89 | "f": 22,
90 | "g": 23,
91 | "h": 24,
92 | "i": 25,
93 | "j": 26,
94 | "k": 27,
95 | "l": 28,
96 | "m": 29,
97 | "n": 30,
98 | "o": 31,
99 | "p": 32,
100 | "s": 33,
101 | "t": 34,
102 | "u": 35,
103 | "v": 36,
104 | "w": 37,
105 | "x": 38,
106 | "y": 39,
107 | "z": 40,
108 | "~": 41,
109 | "æ": 42,
110 | "ç": 43,
111 | "ð": 44,
112 | "ŋ": 45,
113 | "ɑ": 46,
114 | "ɔ": 47,
115 | "ə": 48,
116 | "ɛ": 49,
117 | "ɥ": 50,
118 | "ɪ": 51,
119 | "ɫ": 52,
120 | "ɯ": 53,
121 | "ɸ": 54,
122 | "ɹ": 55,
123 | "ɾ": 56,
124 | "ʃ": 57,
125 | "ʊ": 58,
126 | "ʑ": 59,
127 | "ʒ": 60,
128 | "ʰ": 61,
129 | "ˈ": 62,
130 | "ˌ": 63,
131 | "θ": 64,
132 | "…": 65,
133 | "⁼": 66,
134 | "↑": 67,
135 | "→": 68,
136 | "↓": 69
137 | },
138 | "merges": [
139 | ]
140 | }
141 | }
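Because the merges list in this file is empty, the tokenizer effectively encodes a phoneme string symbol by symbol against the vocabulary above; a small sketch (the input string is an arbitrary phoneme sequence):

    from tokenizers import Tokenizer

    tok = Tokenizer.from_file("./utils/g2p/bpe_69.json")
    ids = tok.encode("minato").ids   # one id per symbol; symbols outside the vocabulary map to [UNK]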
--------------------------------------------------------------------------------
/utils/g2p/cleaners.py:
--------------------------------------------------------------------------------
1 | import re
2 | from utils.g2p.japanese import japanese_to_romaji_with_accent, japanese_to_ipa, japanese_to_ipa2, japanese_to_ipa3
3 | from utils.g2p.mandarin import number_to_chinese, chinese_to_bopomofo, latin_to_bopomofo, chinese_to_romaji, chinese_to_lazy_ipa, chinese_to_ipa, chinese_to_ipa2
4 | from utils.g2p.english import english_to_lazy_ipa, english_to_ipa2, english_to_lazy_ipa2
5 | patterns = [r'\[EN\](.*?)\[EN\]', r'\[ZH\](.*?)\[ZH\]', r'\[JA\](.*?)\[JA\]']
6 | def japanese_cleaners(text):
7 | text = japanese_to_romaji_with_accent(text)
8 | text = re.sub(r'([A-Za-z])$', r'\1.', text)
9 | return text
10 |
11 | def japanese_cleaners2(text):
12 | return japanese_cleaners(text).replace('ts', 'ʦ').replace('...', '…')
13 |
14 | def chinese_cleaners(text):
15 | '''Pipeline for Chinese text'''
16 | text = number_to_chinese(text)
17 | text = chinese_to_bopomofo(text)
18 | text = latin_to_bopomofo(text)
19 | text = re.sub(r'([ˉˊˇˋ˙])$', r'\1。', text)
20 | return text
21 |
22 | def cje_cleaners(text):
23 | matches = []
24 | for pattern in patterns:
25 | matches.extend(re.finditer(pattern, text))
26 |
27 | matches.sort(key=lambda x: x.start()) # Sort matches by their start positions
28 |
29 | outputs = ""
30 | output_langs = []
31 |
32 | for match in matches:
33 | text_segment = text[match.start():match.end()]
34 | phon = clean_one(text_segment)
35 | if "[EN]" in text_segment:
36 | lang = 'en'
37 | elif "[ZH]" in text_segment:
38 | lang = 'zh'
39 | elif "[JA]" in text_segment:
40 | lang = 'ja'
41 | else:
42 |             raise ValueError("If you see this error, please report this bug in the issue tracker.")
43 | outputs += phon
44 | output_langs += [lang] * len(phon)
45 | assert len(outputs) == len(output_langs)
46 | return outputs, output_langs
47 |
48 |
49 | def clean_one(text):
50 | if text.find('[ZH]') != -1:
51 | text = re.sub(r'\[ZH\](.*?)\[ZH\]',
52 | lambda x: chinese_to_ipa(x.group(1))+' ', text)
53 | if text.find('[JA]') != -1:
54 | text = re.sub(r'\[JA\](.*?)\[JA\]',
55 | lambda x: japanese_to_ipa2(x.group(1))+' ', text)
56 | if text.find('[EN]') != -1:
57 | text = re.sub(r'\[EN\](.*?)\[EN\]',
58 | lambda x: english_to_ipa2(x.group(1))+' ', text)
59 | text = re.sub(r'\s+$', '', text)
60 | text = re.sub(r'([^\.,!\?\-…~])$', r'\1.', text)
61 | return text
62 |
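A small illustration of the tag-based cleaner (it relies on the english/mandarin/japanese converters imported above, so eng_to_ipa, pypinyin/jieba, and pyopenjtalk must be installed; the output is only indicated schematically):

    from utils.g2p.cleaners import cje_cleaners

    phonemes, langs = cje_cleaners("[EN]Good morning.[EN][ZH]早上好。[ZH]")
    # phonemes: concatenated IPA for both segments; langs: one 'en'/'zh'/'ja' entry per output character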
--------------------------------------------------------------------------------
/utils/g2p/english.py:
--------------------------------------------------------------------------------
1 | """ from https://github.com/keithito/tacotron """
2 |
3 | '''
4 | Cleaners are transformations that run over the input text at both training and eval time.
5 |
6 | Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners"
7 | hyperparameter. Some cleaners are English-specific. You'll typically want to use:
8 | 1. "english_cleaners" for English text
9 | 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using
10 | the Unidecode library (https://pypi.python.org/pypi/Unidecode)
11 | 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update
12 | the symbols in symbols.py to match your data).
13 | '''
14 |
15 |
16 | # Regular expression matching whitespace:
17 |
18 |
19 | import re
20 | from unidecode import unidecode
21 | import inflect
22 | _inflect = inflect.engine()
23 | _comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])')
24 | _decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)')
25 | _pounds_re = re.compile(r'£([0-9\,]*[0-9]+)')
26 | _dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)')
27 | _ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)')
28 | _number_re = re.compile(r'[0-9]+')
29 |
30 | # List of (regular expression, replacement) pairs for abbreviations:
31 | _abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [
32 | ('mrs', 'misess'),
33 | ('mr', 'mister'),
34 | ('dr', 'doctor'),
35 | ('st', 'saint'),
36 | ('co', 'company'),
37 | ('jr', 'junior'),
38 | ('maj', 'major'),
39 | ('gen', 'general'),
40 | ('drs', 'doctors'),
41 | ('rev', 'reverend'),
42 | ('lt', 'lieutenant'),
43 | ('hon', 'honorable'),
44 | ('sgt', 'sergeant'),
45 | ('capt', 'captain'),
46 | ('esq', 'esquire'),
47 | ('ltd', 'limited'),
48 | ('col', 'colonel'),
49 | ('ft', 'fort'),
50 | ]]
51 |
52 |
53 | # List of (ipa, lazy ipa) pairs:
54 | _lazy_ipa = [(re.compile('%s' % x[0]), x[1]) for x in [
55 | ('r', 'ɹ'),
56 | ('æ', 'e'),
57 | ('ɑ', 'a'),
58 | ('ɔ', 'o'),
59 | ('ð', 'z'),
60 | ('θ', 's'),
61 | ('ɛ', 'e'),
62 | ('ɪ', 'i'),
63 | ('ʊ', 'u'),
64 | ('ʒ', 'ʥ'),
65 | ('ʤ', 'ʥ'),
66 | ('ˈ', '↓'),
67 | ]]
68 |
69 | # List of (ipa, lazy ipa2) pairs:
70 | _lazy_ipa2 = [(re.compile('%s' % x[0]), x[1]) for x in [
71 | ('r', 'ɹ'),
72 | ('ð', 'z'),
73 | ('θ', 's'),
74 | ('ʒ', 'ʑ'),
75 | ('ʤ', 'dʑ'),
76 | ('ˈ', '↓'),
77 | ]]
78 |
79 | # List of (ipa, ipa2) pairs
80 | _ipa_to_ipa2 = [(re.compile('%s' % x[0]), x[1]) for x in [
81 | ('r', 'ɹ'),
82 | ('ʤ', 'dʒ'),
83 | ('ʧ', 'tʃ')
84 | ]]
85 |
86 |
87 | def expand_abbreviations(text):
88 | for regex, replacement in _abbreviations:
89 | text = re.sub(regex, replacement, text)
90 | return text
91 |
92 |
93 | def collapse_whitespace(text):
94 | return re.sub(r'\s+', ' ', text)
95 |
96 |
97 | def _remove_commas(m):
98 | return m.group(1).replace(',', '')
99 |
100 |
101 | def _expand_decimal_point(m):
102 | return m.group(1).replace('.', ' point ')
103 |
104 |
105 | def _expand_dollars(m):
106 | match = m.group(1)
107 | parts = match.split('.')
108 | if len(parts) > 2:
109 | return match + ' dollars' # Unexpected format
110 | dollars = int(parts[0]) if parts[0] else 0
111 | cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
112 | if dollars and cents:
113 | dollar_unit = 'dollar' if dollars == 1 else 'dollars'
114 | cent_unit = 'cent' if cents == 1 else 'cents'
115 | return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit)
116 | elif dollars:
117 | dollar_unit = 'dollar' if dollars == 1 else 'dollars'
118 | return '%s %s' % (dollars, dollar_unit)
119 | elif cents:
120 | cent_unit = 'cent' if cents == 1 else 'cents'
121 | return '%s %s' % (cents, cent_unit)
122 | else:
123 | return 'zero dollars'
124 |
125 |
126 | def _expand_ordinal(m):
127 | return _inflect.number_to_words(m.group(0))
128 |
129 |
130 | def _expand_number(m):
131 | num = int(m.group(0))
132 | if num > 1000 and num < 3000:
133 | if num == 2000:
134 | return 'two thousand'
135 | elif num > 2000 and num < 2010:
136 | return 'two thousand ' + _inflect.number_to_words(num % 100)
137 | elif num % 100 == 0:
138 | return _inflect.number_to_words(num // 100) + ' hundred'
139 | else:
140 | return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ')
141 | else:
142 | return _inflect.number_to_words(num, andword='')
143 |
144 |
145 | def normalize_numbers(text):
146 | text = re.sub(_comma_number_re, _remove_commas, text)
147 | text = re.sub(_pounds_re, r'\1 pounds', text)
148 | text = re.sub(_dollars_re, _expand_dollars, text)
149 | text = re.sub(_decimal_number_re, _expand_decimal_point, text)
150 | text = re.sub(_ordinal_re, _expand_ordinal, text)
151 | text = re.sub(_number_re, _expand_number, text)
152 | return text
153 |
154 |
155 | def mark_dark_l(text):
156 | return re.sub(r'l([^aeiouæɑɔəɛɪʊ ]*(?: |$))', lambda x: 'ɫ'+x.group(1), text)
157 |
158 |
159 | def english_to_ipa(text):
160 | import eng_to_ipa as ipa
161 | text = unidecode(text).lower()
162 | text = expand_abbreviations(text)
163 | text = normalize_numbers(text)
164 | phonemes = ipa.convert(text)
165 | phonemes = collapse_whitespace(phonemes)
166 | return phonemes
167 |
168 |
169 | def english_to_lazy_ipa(text):
170 | text = english_to_ipa(text)
171 | for regex, replacement in _lazy_ipa:
172 | text = re.sub(regex, replacement, text)
173 | return text
174 |
175 |
176 | def english_to_ipa2(text):
177 | text = english_to_ipa(text)
178 | text = mark_dark_l(text)
179 | for regex, replacement in _ipa_to_ipa2:
180 | text = re.sub(regex, replacement, text)
181 | return text.replace('...', '…')
182 |
183 |
184 | def english_to_lazy_ipa2(text):
185 | text = english_to_ipa(text)
186 | for regex, replacement in _lazy_ipa2:
187 | text = re.sub(regex, replacement, text)
188 | return text
189 |
--------------------------------------------------------------------------------
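
A minimal usage sketch for the English G2P helpers above (not part of the repo itself); it assumes the repo root is on PYTHONPATH so the module imports as utils.g2p.english, and that its dependencies (unidecode, inflect, eng_to_ipa) are installed:

    from utils.g2p.english import english_to_ipa2, english_to_lazy_ipa2

    text = "Capt. Smith arrived at the ft. in 1999 with $2.50."
    # Pipeline: unidecode -> expand_abbreviations -> normalize_numbers -> eng_to_ipa,
    # then dark-l marking plus the _ipa_to_ipa2 substitutions
    print(english_to_ipa2(text))
    # Same base conversion mapped onto the reduced "lazy" IPA symbol set (_lazy_ipa2)
    print(english_to_lazy_ipa2(text))
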
/utils/g2p/japanese.py:
--------------------------------------------------------------------------------
1 | import re
2 | from unidecode import unidecode
3 |
4 |
5 |
6 | # Regular expression matching Japanese without punctuation marks:
7 | _japanese_characters = re.compile(
8 | r'[A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]')
9 |
10 | # Regular expression matching non-Japanese characters or punctuation marks:
11 | _japanese_marks = re.compile(
12 | r'[^A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]')
13 |
14 | # List of (symbol, Japanese) pairs for marks:
15 | _symbols_to_japanese = [(re.compile('%s' % x[0]), x[1]) for x in [
16 | ('%', 'パーセント')
17 | ]]
18 |
19 | # List of (romaji, ipa) pairs for marks:
20 | _romaji_to_ipa = [(re.compile('%s' % x[0]), x[1]) for x in [
21 | ('ts', 'ʦ'),
22 | ('u', 'ɯ'),
23 | ('j', 'ʥ'),
24 | ('y', 'j'),
25 | ('ni', 'n^i'),
26 | ('nj', 'n^'),
27 | ('hi', 'çi'),
28 | ('hj', 'ç'),
29 | ('f', 'ɸ'),
30 | ('I', 'i*'),
31 | ('U', 'ɯ*'),
32 | ('r', 'ɾ')
33 | ]]
34 |
35 | # List of (romaji, ipa2) pairs for marks:
36 | _romaji_to_ipa2 = [(re.compile('%s' % x[0]), x[1]) for x in [
37 | ('u', 'ɯ'),
38 | ('ʧ', 'tʃ'),
39 | ('j', 'dʑ'),
40 | ('y', 'j'),
41 | ('ni', 'n^i'),
42 | ('nj', 'n^'),
43 | ('hi', 'çi'),
44 | ('hj', 'ç'),
45 | ('f', 'ɸ'),
46 | ('I', 'i*'),
47 | ('U', 'ɯ*'),
48 | ('r', 'ɾ')
49 | ]]
50 |
51 | # List of (consonant, sokuon) pairs:
52 | _real_sokuon = [(re.compile('%s' % x[0]), x[1]) for x in [
53 | (r'Q([↑↓]*[kg])', r'k#\1'),
54 | (r'Q([↑↓]*[tdjʧ])', r't#\1'),
55 | (r'Q([↑↓]*[sʃ])', r's\1'),
56 | (r'Q([↑↓]*[pb])', r'p#\1')
57 | ]]
58 |
59 | # List of (consonant, hatsuon) pairs:
60 | _real_hatsuon = [(re.compile('%s' % x[0]), x[1]) for x in [
61 | (r'N([↑↓]*[pbm])', r'm\1'),
62 | (r'N([↑↓]*[ʧʥj])', r'n^\1'),
63 | (r'N([↑↓]*[tdn])', r'n\1'),
64 | (r'N([↑↓]*[kg])', r'ŋ\1')
65 | ]]
66 |
67 |
68 | def symbols_to_japanese(text):
69 | for regex, replacement in _symbols_to_japanese:
70 | text = re.sub(regex, replacement, text)
71 | return text
72 |
73 |
74 | def japanese_to_romaji_with_accent(text):
75 | '''Reference https://r9y9.github.io/ttslearn/latest/notebooks/ch10_Recipe-Tacotron.html'''
76 | import pyopenjtalk
77 | text = symbols_to_japanese(text)
78 | sentences = re.split(_japanese_marks, text)
79 | marks = re.findall(_japanese_marks, text)
80 | text = ''
81 | for i, sentence in enumerate(sentences):
82 | if re.match(_japanese_characters, sentence):
83 | if text != '':
84 | text += ' '
85 | labels = pyopenjtalk.extract_fullcontext(sentence)
86 | for n, label in enumerate(labels):
87 | phoneme = re.search(r'\-([^\+]*)\+', label).group(1)
88 | if phoneme not in ['sil', 'pau']:
89 | text += phoneme.replace('ch', 'ʧ').replace('sh',
90 | 'ʃ').replace('cl', 'Q')
91 | else:
92 | continue
93 | # n_moras = int(re.search(r'/F:(\d+)_', label).group(1))
94 | a1 = int(re.search(r"/A:(\-?[0-9]+)\+", label).group(1))
95 | a2 = int(re.search(r"\+(\d+)\+", label).group(1))
96 | a3 = int(re.search(r"\+(\d+)/", label).group(1))
97 | if re.search(r'\-([^\+]*)\+', labels[n + 1]).group(1) in ['sil', 'pau']:
98 | a2_next = -1
99 | else:
100 | a2_next = int(
101 | re.search(r"\+(\d+)\+", labels[n + 1]).group(1))
102 | # Accent phrase boundary
103 | if a3 == 1 and a2_next == 1:
104 | text += ' '
105 | # Falling
106 | elif a1 == 0 and a2_next == a2 + 1:
107 | text += '↓'
108 | # Rising
109 | elif a2 == 1 and a2_next == 2:
110 | text += '↑'
111 | if i < len(marks):
112 | text += unidecode(marks[i]).replace(' ', '')
113 | return text
114 |
115 |
116 | def get_real_sokuon(text):
117 | for regex, replacement in _real_sokuon:
118 | text = re.sub(regex, replacement, text)
119 | return text
120 |
121 |
122 | def get_real_hatsuon(text):
123 | for regex, replacement in _real_hatsuon:
124 | text = re.sub(regex, replacement, text)
125 | return text
126 |
127 |
128 | def japanese_to_ipa(text):
129 | text = japanese_to_romaji_with_accent(text).replace('...', '…')
130 | text = re.sub(
131 | r'([aiueo])\1+', lambda x: x.group(0)[0]+'ː'*(len(x.group(0))-1), text)
132 | text = get_real_sokuon(text)
133 | text = get_real_hatsuon(text)
134 | for regex, replacement in _romaji_to_ipa:
135 | text = re.sub(regex, replacement, text)
136 | return text
137 |
138 |
139 | def japanese_to_ipa2(text):
140 | text = japanese_to_romaji_with_accent(text).replace('...', '…')
141 | text = get_real_sokuon(text)
142 | text = get_real_hatsuon(text)
143 | for regex, replacement in _romaji_to_ipa2:
144 | text = re.sub(regex, replacement, text)
145 | return text
146 |
147 |
148 | def japanese_to_ipa3(text):
149 | text = japanese_to_ipa2(text).replace('n^', 'ȵ').replace(
150 | 'ʃ', 'ɕ').replace('*', '\u0325').replace('#', '\u031a')
151 | text = re.sub(
152 | r'([aiɯeo])\1+', lambda x: x.group(0)[0]+'ː'*(len(x.group(0))-1), text)
153 | text = re.sub(r'((?:^|\s)(?:ts|tɕ|[kpt]))', r'\1ʰ', text)
154 | return text
155 |
--------------------------------------------------------------------------------
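
A minimal usage sketch for the Japanese pipeline above, assuming pyopenjtalk and unidecode are installed and the module imports as utils.g2p.japanese:

    from utils.g2p.japanese import japanese_to_romaji_with_accent, japanese_to_ipa2

    text = "こんにちは、世界。"
    # Romaji with pitch-accent arrows (↑/↓) derived from pyopenjtalk full-context labels
    print(japanese_to_romaji_with_accent(text))
    # Same output mapped to IPA, with sokuon (Q) and hatsuon (N) resolved to phones
    print(japanese_to_ipa2(text))
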
/utils/g2p/mandarin.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import re
4 | import jieba
5 | import cn2an
6 | import logging
7 |
8 |
9 | # List of (Latin alphabet, bopomofo) pairs:
10 | _latin_to_bopomofo = [(re.compile('%s' % x[0], re.IGNORECASE), x[1]) for x in [
11 | ('a', 'ㄟˉ'),
12 | ('b', 'ㄅㄧˋ'),
13 | ('c', 'ㄙㄧˉ'),
14 | ('d', 'ㄉㄧˋ'),
15 | ('e', 'ㄧˋ'),
16 | ('f', 'ㄝˊㄈㄨˋ'),
17 | ('g', 'ㄐㄧˋ'),
18 | ('h', 'ㄝˇㄑㄩˋ'),
19 | ('i', 'ㄞˋ'),
20 | ('j', 'ㄐㄟˋ'),
21 | ('k', 'ㄎㄟˋ'),
22 | ('l', 'ㄝˊㄛˋ'),
23 | ('m', 'ㄝˊㄇㄨˋ'),
24 | ('n', 'ㄣˉ'),
25 | ('o', 'ㄡˉ'),
26 | ('p', 'ㄆㄧˉ'),
27 | ('q', 'ㄎㄧㄡˉ'),
28 | ('r', 'ㄚˋ'),
29 | ('s', 'ㄝˊㄙˋ'),
30 | ('t', 'ㄊㄧˋ'),
31 | ('u', 'ㄧㄡˉ'),
32 | ('v', 'ㄨㄧˉ'),
33 | ('w', 'ㄉㄚˋㄅㄨˋㄌㄧㄡˋ'),
34 | ('x', 'ㄝˉㄎㄨˋㄙˋ'),
35 | ('y', 'ㄨㄞˋ'),
36 | ('z', 'ㄗㄟˋ')
37 | ]]
38 |
39 | # List of (bopomofo, romaji) pairs:
40 | _bopomofo_to_romaji = [(re.compile('%s' % x[0]), x[1]) for x in [
41 | ('ㄅㄛ', 'p⁼wo'),
42 | ('ㄆㄛ', 'pʰwo'),
43 | ('ㄇㄛ', 'mwo'),
44 | ('ㄈㄛ', 'fwo'),
45 | ('ㄅ', 'p⁼'),
46 | ('ㄆ', 'pʰ'),
47 | ('ㄇ', 'm'),
48 | ('ㄈ', 'f'),
49 | ('ㄉ', 't⁼'),
50 | ('ㄊ', 'tʰ'),
51 | ('ㄋ', 'n'),
52 | ('ㄌ', 'l'),
53 | ('ㄍ', 'k⁼'),
54 | ('ㄎ', 'kʰ'),
55 | ('ㄏ', 'h'),
56 | ('ㄐ', 'ʧ⁼'),
57 | ('ㄑ', 'ʧʰ'),
58 | ('ㄒ', 'ʃ'),
59 | ('ㄓ', 'ʦ`⁼'),
60 | ('ㄔ', 'ʦ`ʰ'),
61 | ('ㄕ', 's`'),
62 | ('ㄖ', 'ɹ`'),
63 | ('ㄗ', 'ʦ⁼'),
64 | ('ㄘ', 'ʦʰ'),
65 | ('ㄙ', 's'),
66 | ('ㄚ', 'a'),
67 | ('ㄛ', 'o'),
68 | ('ㄜ', 'ə'),
69 | ('ㄝ', 'e'),
70 | ('ㄞ', 'ai'),
71 | ('ㄟ', 'ei'),
72 | ('ㄠ', 'au'),
73 | ('ㄡ', 'ou'),
74 | ('ㄧㄢ', 'yeNN'),
75 | ('ㄢ', 'aNN'),
76 | ('ㄧㄣ', 'iNN'),
77 | ('ㄣ', 'əNN'),
78 | ('ㄤ', 'aNg'),
79 | ('ㄧㄥ', 'iNg'),
80 | ('ㄨㄥ', 'uNg'),
81 | ('ㄩㄥ', 'yuNg'),
82 | ('ㄥ', 'əNg'),
83 | ('ㄦ', 'əɻ'),
84 | ('ㄧ', 'i'),
85 | ('ㄨ', 'u'),
86 | ('ㄩ', 'ɥ'),
87 | ('ˉ', '→'),
88 | ('ˊ', '↑'),
89 | ('ˇ', '↓↑'),
90 | ('ˋ', '↓'),
91 | ('˙', ''),
92 | (',', ','),
93 | ('。', '.'),
94 | ('!', '!'),
95 | ('?', '?'),
96 | ('—', '-')
97 | ]]
98 |
99 | # List of (romaji, ipa) pairs:
100 | _romaji_to_ipa = [(re.compile('%s' % x[0], re.IGNORECASE), x[1]) for x in [
101 | ('ʃy', 'ʃ'),
102 | ('ʧʰy', 'ʧʰ'),
103 | ('ʧ⁼y', 'ʧ⁼'),
104 | ('NN', 'n'),
105 | ('Ng', 'ŋ'),
106 | ('y', 'j'),
107 | ('h', 'x')
108 | ]]
109 |
110 | # List of (bopomofo, ipa) pairs:
111 | _bopomofo_to_ipa = [(re.compile('%s' % x[0]), x[1]) for x in [
112 | ('ㄅㄛ', 'p⁼wo'),
113 | ('ㄆㄛ', 'pʰwo'),
114 | ('ㄇㄛ', 'mwo'),
115 | ('ㄈㄛ', 'fwo'),
116 | ('ㄅ', 'p⁼'),
117 | ('ㄆ', 'pʰ'),
118 | ('ㄇ', 'm'),
119 | ('ㄈ', 'f'),
120 | ('ㄉ', 't⁼'),
121 | ('ㄊ', 'tʰ'),
122 | ('ㄋ', 'n'),
123 | ('ㄌ', 'l'),
124 | ('ㄍ', 'k⁼'),
125 | ('ㄎ', 'kʰ'),
126 | ('ㄏ', 'x'),
127 | ('ㄐ', 'tʃ⁼'),
128 | ('ㄑ', 'tʃʰ'),
129 | ('ㄒ', 'ʃ'),
130 | ('ㄓ', 'ts`⁼'),
131 | ('ㄔ', 'ts`ʰ'),
132 | ('ㄕ', 's`'),
133 | ('ㄖ', 'ɹ`'),
134 | ('ㄗ', 'ts⁼'),
135 | ('ㄘ', 'tsʰ'),
136 | ('ㄙ', 's'),
137 | ('ㄚ', 'a'),
138 | ('ㄛ', 'o'),
139 | ('ㄜ', 'ə'),
140 | ('ㄝ', 'ɛ'),
141 | ('ㄞ', 'aɪ'),
142 | ('ㄟ', 'eɪ'),
143 | ('ㄠ', 'ɑʊ'),
144 | ('ㄡ', 'oʊ'),
145 | ('ㄧㄢ', 'jɛn'),
146 | ('ㄩㄢ', 'ɥæn'),
147 | ('ㄢ', 'an'),
148 | ('ㄧㄣ', 'in'),
149 | ('ㄩㄣ', 'ɥn'),
150 | ('ㄣ', 'ən'),
151 | ('ㄤ', 'ɑŋ'),
152 | ('ㄧㄥ', 'iŋ'),
153 | ('ㄨㄥ', 'ʊŋ'),
154 | ('ㄩㄥ', 'jʊŋ'),
155 | ('ㄥ', 'əŋ'),
156 | ('ㄦ', 'əɻ'),
157 | ('ㄧ', 'i'),
158 | ('ㄨ', 'u'),
159 | ('ㄩ', 'ɥ'),
160 | ('ˉ', '→'),
161 | ('ˊ', '↑'),
162 | ('ˇ', '↓↑'),
163 | ('ˋ', '↓'),
164 | ('˙', ''),
165 | (',', ','),
166 | ('。', '.'),
167 | ('!', '!'),
168 | ('?', '?'),
169 | ('—', '-')
170 | ]]
171 |
172 | # List of (bopomofo, ipa2) pairs:
173 | _bopomofo_to_ipa2 = [(re.compile('%s' % x[0]), x[1]) for x in [
174 | ('ㄅㄛ', 'pwo'),
175 | ('ㄆㄛ', 'pʰwo'),
176 | ('ㄇㄛ', 'mwo'),
177 | ('ㄈㄛ', 'fwo'),
178 | ('ㄅ', 'p'),
179 | ('ㄆ', 'pʰ'),
180 | ('ㄇ', 'm'),
181 | ('ㄈ', 'f'),
182 | ('ㄉ', 't'),
183 | ('ㄊ', 'tʰ'),
184 | ('ㄋ', 'n'),
185 | ('ㄌ', 'l'),
186 | ('ㄍ', 'k'),
187 | ('ㄎ', 'kʰ'),
188 | ('ㄏ', 'h'),
189 | ('ㄐ', 'tɕ'),
190 | ('ㄑ', 'tɕʰ'),
191 | ('ㄒ', 'ɕ'),
192 | ('ㄓ', 'tʂ'),
193 | ('ㄔ', 'tʂʰ'),
194 | ('ㄕ', 'ʂ'),
195 | ('ㄖ', 'ɻ'),
196 | ('ㄗ', 'ts'),
197 | ('ㄘ', 'tsʰ'),
198 | ('ㄙ', 's'),
199 | ('ㄚ', 'a'),
200 | ('ㄛ', 'o'),
201 | ('ㄜ', 'ɤ'),
202 | ('ㄝ', 'ɛ'),
203 | ('ㄞ', 'aɪ'),
204 | ('ㄟ', 'eɪ'),
205 | ('ㄠ', 'ɑʊ'),
206 | ('ㄡ', 'oʊ'),
207 | ('ㄧㄢ', 'jɛn'),
208 | ('ㄩㄢ', 'yæn'),
209 | ('ㄢ', 'an'),
210 | ('ㄧㄣ', 'in'),
211 | ('ㄩㄣ', 'yn'),
212 | ('ㄣ', 'ən'),
213 | ('ㄤ', 'ɑŋ'),
214 | ('ㄧㄥ', 'iŋ'),
215 | ('ㄨㄥ', 'ʊŋ'),
216 | ('ㄩㄥ', 'jʊŋ'),
217 | ('ㄥ', 'ɤŋ'),
218 | ('ㄦ', 'əɻ'),
219 | ('ㄧ', 'i'),
220 | ('ㄨ', 'u'),
221 | ('ㄩ', 'y'),
222 | ('ˉ', '˥'),
223 | ('ˊ', '˧˥'),
224 | ('ˇ', '˨˩˦'),
225 | ('ˋ', '˥˩'),
226 | ('˙', ''),
227 | (',', ','),
228 | ('。', '.'),
229 | ('!', '!'),
230 | ('?', '?'),
231 | ('—', '-')
232 | ]]
233 |
234 |
235 | def number_to_chinese(text):
236 | numbers = re.findall(r'\d+(?:\.?\d+)?', text)
237 | for number in numbers:
238 | text = text.replace(number, cn2an.an2cn(number), 1)
239 | return text
240 |
241 |
242 | def chinese_to_bopomofo(text):
243 | from pypinyin import lazy_pinyin, BOPOMOFO
244 | text = text.replace('、', ',').replace(';', ',').replace(':', ',')
245 | words = jieba.lcut(text, cut_all=False)
246 | text = ''
247 | for word in words:
248 | bopomofos = lazy_pinyin(word, BOPOMOFO)
249 | if not re.search('[\u4e00-\u9fff]', word):
250 | text += word
251 | continue
252 | for i in range(len(bopomofos)):
253 | bopomofos[i] = re.sub(r'([\u3105-\u3129])$', r'\1ˉ', bopomofos[i])
254 | if text != '':
255 | text += ' '
256 | text += ''.join(bopomofos)
257 | return text
258 |
259 |
260 | def latin_to_bopomofo(text):
261 | for regex, replacement in _latin_to_bopomofo:
262 | text = re.sub(regex, replacement, text)
263 | return text
264 |
265 |
266 | def bopomofo_to_romaji(text):
267 | for regex, replacement in _bopomofo_to_romaji:
268 | text = re.sub(regex, replacement, text)
269 | return text
270 |
271 |
272 | def bopomofo_to_ipa(text):
273 | for regex, replacement in _bopomofo_to_ipa:
274 | text = re.sub(regex, replacement, text)
275 | return text
276 |
277 |
278 | def bopomofo_to_ipa2(text):
279 | for regex, replacement in _bopomofo_to_ipa2:
280 | text = re.sub(regex, replacement, text)
281 | return text
282 |
283 |
284 | def chinese_to_romaji(text):
285 | text = number_to_chinese(text)
286 | text = chinese_to_bopomofo(text)
287 | text = latin_to_bopomofo(text)
288 | text = bopomofo_to_romaji(text)
289 | text = re.sub('i([aoe])', r'y\1', text)
290 | text = re.sub('u([aoəe])', r'w\1', text)
291 | text = re.sub('([ʦsɹ]`[⁼ʰ]?)([→↓↑ ]+|$)',
292 | r'\1ɹ`\2', text).replace('ɻ', 'ɹ`')
293 | text = re.sub('([ʦs][⁼ʰ]?)([→↓↑ ]+|$)', r'\1ɹ\2', text)
294 | return text
295 |
296 |
297 | def chinese_to_lazy_ipa(text):
298 | text = chinese_to_romaji(text)
299 | for regex, replacement in _romaji_to_ipa:
300 | text = re.sub(regex, replacement, text)
301 | return text
302 |
303 |
304 | def chinese_to_ipa(text):
305 | text = number_to_chinese(text)
306 | text = chinese_to_bopomofo(text)
307 | text = latin_to_bopomofo(text)
308 | text = bopomofo_to_ipa(text)
309 | text = re.sub('i([aoe])', r'j\1', text)
310 | text = re.sub('u([aoəe])', r'w\1', text)
311 | text = re.sub('([sɹ]`[⁼ʰ]?)([→↓↑ ]+|$)',
312 | r'\1ɹ`\2', text).replace('ɻ', 'ɹ`')
313 | text = re.sub('([s][⁼ʰ]?)([→↓↑ ]+|$)', r'\1ɹ\2', text)
314 | return text
315 |
316 |
317 | def chinese_to_ipa2(text):
318 | text = number_to_chinese(text)
319 | text = chinese_to_bopomofo(text)
320 | text = latin_to_bopomofo(text)
321 | text = bopomofo_to_ipa2(text)
322 | text = re.sub(r'i([aoe])', r'j\1', text)
323 | text = re.sub(r'u([aoəe])', r'w\1', text)
324 | text = re.sub(r'([ʂɹ]ʰ?)([˩˨˧˦˥ ]+|$)', r'\1ʅ\2', text)
325 | text = re.sub(r'(sʰ?)([˩˨˧˦˥ ]+|$)', r'\1ɿ\2', text)
326 | return text
327 |
--------------------------------------------------------------------------------
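
A minimal usage sketch for the Mandarin pipeline above, assuming jieba, cn2an and pypinyin are installed and the module imports as utils.g2p.mandarin:

    from utils.g2p.mandarin import chinese_to_ipa, chinese_to_ipa2

    text = "今天是2023年。"
    # Arabic numerals -> Chinese numerals (cn2an), jieba segmentation -> bopomofo (pypinyin),
    # then bopomofo -> IPA with tone arrows (→ ↑ ↓↑ ↓)
    print(chinese_to_ipa(text))
    # Alternative mapping that uses tone letters (˥ ˧˥ ˨˩˦ ˥˩) instead of arrows
    print(chinese_to_ipa2(text))
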
/utils/g2p/symbols.py:
--------------------------------------------------------------------------------
1 | '''
2 | Defines the set of symbols used in text input to the model.
3 | '''
4 |
5 | # japanese_cleaners
6 | # _pad = '_'
7 | # _punctuation = ',.!?-'
8 | # _letters = 'AEINOQUabdefghijkmnoprstuvwyzʃʧ↓↑ '
9 |
10 |
11 | '''# japanese_cleaners2
12 | _pad = '_'
13 | _punctuation = ',.!?-~…'
14 | _letters = 'AEINOQUabdefghijkmnoprstuvwyzʃʧʦ↓↑ '
15 | '''
16 |
17 |
18 | '''# korean_cleaners
19 | _pad = '_'
20 | _punctuation = ',.!?…~'
21 | _letters = 'ㄱㄴㄷㄹㅁㅂㅅㅇㅈㅊㅋㅌㅍㅎㄲㄸㅃㅆㅉㅏㅓㅗㅜㅡㅣㅐㅔ '
22 | '''
23 |
24 | '''# chinese_cleaners
25 | _pad = '_'
26 | _punctuation = ',。!?—…'
27 | _letters = 'ㄅㄆㄇㄈㄉㄊㄋㄌㄍㄎㄏㄐㄑㄒㄓㄔㄕㄖㄗㄘㄙㄚㄛㄜㄝㄞㄟㄠㄡㄢㄣㄤㄥㄦㄧㄨㄩˉˊˇˋ˙ '
28 | '''
29 |
30 | # # zh_ja_mixture_cleaners
31 | # _pad = '_'
32 | # _punctuation = ',.!?-~…'
33 | # _letters = 'AEINOQUabdefghijklmnoprstuvwyzʃʧʦɯɹəɥ⁼ʰ`→↓↑ '
34 |
35 |
36 | '''# sanskrit_cleaners
37 | _pad = '_'
38 | _punctuation = '।'
39 | _letters = 'ँंःअआइईउऊऋएऐओऔकखगघङचछजझञटठडढणतथदधनपफबभमयरलळवशषसहऽािीुूृॄेैोौ्ॠॢ '
40 | '''
41 |
42 | '''# cjks_cleaners
43 | _pad = '_'
44 | _punctuation = ',.!?-~…'
45 | _letters = 'NQabdefghijklmnopstuvwxyzʃʧʥʦɯɹəɥçɸɾβŋɦː⁼ʰ`^#*=→↓↑ '
46 | '''
47 |
48 | '''# thai_cleaners
49 | _pad = '_'
50 | _punctuation = '.!? '
51 | _letters = 'กขฃคฆงจฉชซฌญฎฏฐฑฒณดตถทธนบปผฝพฟภมยรฤลวศษสหฬอฮฯะัาำิีึืุูเแโใไๅๆ็่้๊๋์'
52 | '''
53 |
54 | # # cjke_cleaners2
55 | _pad = '_'
56 | _punctuation = ',.!?-~…'
57 | _letters = 'NQabdefghijklmnopstuvwxyzɑæʃʑçɯɪɔɛɹðəɫɥɸʊɾʒθβŋɦ⁼ʰ`^#*=ˈˌ→↓↑ '
58 |
59 |
60 | '''# shanghainese_cleaners
61 | _pad = '_'
62 | _punctuation = ',.!?…'
63 | _letters = 'abdfghiklmnopstuvyzøŋȵɑɔɕəɤɦɪɿʑʔʰ̩̃ᴀᴇ15678 '
64 | '''
65 |
66 | '''# chinese_dialect_cleaners
67 | _pad = '_'
68 | _punctuation = ',.!?~…─'
69 | _letters = '#Nabdefghijklmnoprstuvwxyzæçøŋœȵɐɑɒɓɔɕɗɘəɚɛɜɣɤɦɪɭɯɵɷɸɻɾɿʂʅʊʋʌʏʑʔʦʮʰʷˀː˥˦˧˨˩̥̩̃̚ᴀᴇ↑↓∅ⱼ '
70 | '''
71 |
72 | # Export all symbols:
73 | symbols = [_pad] + list(_punctuation) + list(_letters)
74 |
75 | # Special symbol ids
76 | SPACE_ID = symbols.index(" ")
77 |
--------------------------------------------------------------------------------
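
The active symbol set above (cjke_cleaners2) defines the model's text vocabulary. A small sketch of the usual lookup-table pattern built on top of it; the two dictionaries are illustrative, not part of this repo:

    from utils.g2p.symbols import symbols, SPACE_ID

    # Index 0 is the pad symbol '_'; SPACE_ID is the index of ' '
    symbol_to_id = {s: i for i, s in enumerate(symbols)}
    id_to_symbol = {i: s for i, s in enumerate(symbols)}

    assert symbol_to_id['_'] == 0
    assert symbol_to_id[' '] == SPACE_ID
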
/utils/generation.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 | import os
3 | import torch
4 | from vocos import Vocos
5 | import logging
6 | import langid
7 | langid.set_languages(['en', 'zh', 'ja'])
8 |
9 | import pathlib
10 | import platform
11 | if platform.system().lower() == 'windows':
12 | temp = pathlib.PosixPath
13 | pathlib.PosixPath = pathlib.WindowsPath
14 | else:
15 | temp = pathlib.WindowsPath
16 | pathlib.WindowsPath = pathlib.PosixPath
17 |
18 | import numpy as np
19 | from data.tokenizer import (
20 | AudioTokenizer,
21 | tokenize_audio,
22 | )
23 | from data.collation import get_text_token_collater
24 | from models.vallex import VALLE
25 | from utils.g2p import PhonemeBpeTokenizer
26 | from utils.sentence_cutter import split_text_into_sentences
27 |
28 | from macros import *
29 |
30 | device = torch.device("cpu")
31 | if torch.cuda.is_available():
32 | device = torch.device("cuda", 0)
33 |
34 | url = 'https://huggingface.co/Plachta/VALL-E-X/resolve/main/vallex-checkpoint.pt'
35 |
36 | checkpoints_dir = "./checkpoints/"
37 |
38 | model_checkpoint_name = "vallex-checkpoint.pt"
39 |
40 | model = None
41 |
42 | codec = None
43 |
44 | vocos = None
45 |
46 | text_tokenizer = PhonemeBpeTokenizer(tokenizer_path="./utils/g2p/bpe_69.json")
47 | text_collater = get_text_token_collater()
48 |
49 | def preload_models():
50 | global model, codec, vocos
51 | if not os.path.exists(checkpoints_dir): os.mkdir(checkpoints_dir)
52 | if not os.path.exists(os.path.join(checkpoints_dir, model_checkpoint_name)):
53 | import wget
54 | try:
55 | logging.info(
56 | "Downloading model from https://huggingface.co/Plachta/VALL-E-X/resolve/main/vallex-checkpoint.pt ...")
57 | # download from https://huggingface.co/Plachta/VALL-E-X/resolve/main/vallex-checkpoint.pt to ./checkpoints/vallex-checkpoint.pt
58 | wget.download("https://huggingface.co/Plachta/VALL-E-X/resolve/main/vallex-checkpoint.pt",
59 | out="./checkpoints/vallex-checkpoint.pt", bar=wget.bar_adaptive)
60 | except Exception as e:
61 | logging.info(e)
62 | raise Exception(
63 |                 "\n Model weights download failed, please go to 'https://huggingface.co/Plachta/VALL-E-X/resolve/main/vallex-checkpoint.pt',"
64 |                 "\n download the model weights manually and put them in {} .".format(os.path.join(os.getcwd(), "checkpoints")))
65 | # VALL-E
66 | model = VALLE(
67 | N_DIM,
68 | NUM_HEAD,
69 | NUM_LAYERS,
70 | norm_first=True,
71 | add_prenet=False,
72 | prefix_mode=PREFIX_MODE,
73 | share_embedding=True,
74 | nar_scale_factor=1.0,
75 | prepend_bos=True,
76 | num_quantizers=NUM_QUANTIZERS,
77 | ).to(device)
78 | checkpoint = torch.load(os.path.join(checkpoints_dir, model_checkpoint_name), map_location='cpu')
79 | missing_keys, unexpected_keys = model.load_state_dict(
80 | checkpoint["model"], strict=True
81 | )
82 | assert not missing_keys
83 | model.eval()
84 |
85 | # Encodec
86 | codec = AudioTokenizer(device)
87 |
88 | vocos = Vocos.from_pretrained('charactr/vocos-encodec-24khz').to(device)
89 |
90 | @torch.no_grad()
91 | def generate_audio(text, prompt=None, language='auto', accent='no-accent'):
92 | global model, codec, vocos, text_tokenizer, text_collater
93 | text = text.replace("\n", "").strip(" ")
94 | # detect language
95 | if language == "auto":
96 | language = langid.classify(text)[0]
97 | lang_token = lang2token[language]
98 | lang = token2lang[lang_token]
99 | text = lang_token + text + lang_token
100 |
101 | # load prompt
102 | if prompt is not None:
103 | prompt_path = prompt
104 | if not os.path.exists(prompt_path):
105 | prompt_path = "./presets/" + prompt + ".npz"
106 | if not os.path.exists(prompt_path):
107 | prompt_path = "./customs/" + prompt + ".npz"
108 | if not os.path.exists(prompt_path):
109 | raise ValueError(f"Cannot find prompt {prompt}")
110 | prompt_data = np.load(prompt_path)
111 | audio_prompts = prompt_data['audio_tokens']
112 | text_prompts = prompt_data['text_tokens']
113 | lang_pr = prompt_data['lang_code']
114 | lang_pr = code2lang[int(lang_pr)]
115 |
116 | # numpy to tensor
117 | audio_prompts = torch.tensor(audio_prompts).type(torch.int32).to(device)
118 | text_prompts = torch.tensor(text_prompts).type(torch.int32)
119 | else:
120 | audio_prompts = torch.zeros([1, 0, NUM_QUANTIZERS]).type(torch.int32).to(device)
121 | text_prompts = torch.zeros([1, 0]).type(torch.int32)
122 | lang_pr = lang if lang != 'mix' else 'en'
123 |
124 | enroll_x_lens = text_prompts.shape[-1]
125 | logging.info(f"synthesize text: {text}")
126 | phone_tokens, langs = text_tokenizer.tokenize(text=f"_{text}".strip())
127 | text_tokens, text_tokens_lens = text_collater(
128 | [
129 | phone_tokens
130 | ]
131 | )
132 | text_tokens = torch.cat([text_prompts, text_tokens], dim=-1)
133 | text_tokens_lens += enroll_x_lens
134 | # accent control
135 | lang = lang if accent == "no-accent" else token2lang[langdropdown2token[accent]]
136 | encoded_frames = model.inference(
137 | text_tokens.to(device),
138 | text_tokens_lens.to(device),
139 | audio_prompts,
140 | enroll_x_lens=enroll_x_lens,
141 | top_k=-100,
142 | temperature=1,
143 | prompt_language=lang_pr,
144 | text_language=langs if accent == "no-accent" else lang,
145 | )
146 | # Decode with Vocos
147 | frames = encoded_frames.permute(2,0,1)
148 | features = vocos.codes_to_features(frames)
149 | samples = vocos.decode(features, bandwidth_id=torch.tensor([2], device=device))
150 |
151 | return samples.squeeze().cpu().numpy()
152 |
153 | @torch.no_grad()
154 | def generate_audio_from_long_text(text, prompt=None, language='auto', accent='no-accent', mode='sliding-window'):
155 | """
156 | For long audio generation, two modes are available.
157 |     fixed-prompt: This mode keeps using the same prompt the user has provided, and generates audio sentence by sentence.
158 |     sliding-window: This mode uses the last generated sentence as the prompt for the next one, but may not preserve the speaker identity as reliably.
159 | """
160 | global model, codec, vocos, text_tokenizer, text_collater
161 | if prompt is None or prompt == "":
162 | mode = 'sliding-window' # If no prompt is given, use sliding-window mode
163 | sentences = split_text_into_sentences(text)
164 | # detect language
165 | if language == "auto":
166 | language = langid.classify(text)[0]
167 |
168 | # if initial prompt is given, encode it
169 | if prompt is not None and prompt != "":
170 | prompt_path = prompt
171 | if not os.path.exists(prompt_path):
172 | prompt_path = "./presets/" + prompt + ".npz"
173 | if not os.path.exists(prompt_path):
174 | prompt_path = "./customs/" + prompt + ".npz"
175 | if not os.path.exists(prompt_path):
176 | raise ValueError(f"Cannot find prompt {prompt}")
177 | prompt_data = np.load(prompt_path)
178 | audio_prompts = prompt_data['audio_tokens']
179 | text_prompts = prompt_data['text_tokens']
180 | lang_pr = prompt_data['lang_code']
181 | lang_pr = code2lang[int(lang_pr)]
182 |
183 | # numpy to tensor
184 | audio_prompts = torch.tensor(audio_prompts).type(torch.int32).to(device)
185 | text_prompts = torch.tensor(text_prompts).type(torch.int32)
186 | else:
187 | audio_prompts = torch.zeros([1, 0, NUM_QUANTIZERS]).type(torch.int32).to(device)
188 | text_prompts = torch.zeros([1, 0]).type(torch.int32)
189 | lang_pr = language if language != 'mix' else 'en'
190 | if mode == 'fixed-prompt':
191 | complete_tokens = torch.zeros([1, NUM_QUANTIZERS, 0]).type(torch.LongTensor).to(device)
192 | for text in sentences:
193 | text = text.replace("\n", "").strip(" ")
194 | if text == "":
195 | continue
196 | lang_token = lang2token[language]
197 | lang = token2lang[lang_token]
198 | text = lang_token + text + lang_token
199 |
200 | enroll_x_lens = text_prompts.shape[-1]
201 | logging.info(f"synthesize text: {text}")
202 | phone_tokens, langs = text_tokenizer.tokenize(text=f"_{text}".strip())
203 | text_tokens, text_tokens_lens = text_collater(
204 | [
205 | phone_tokens
206 | ]
207 | )
208 | text_tokens = torch.cat([text_prompts, text_tokens], dim=-1)
209 | text_tokens_lens += enroll_x_lens
210 | # accent control
211 | lang = lang if accent == "no-accent" else token2lang[langdropdown2token[accent]]
212 | encoded_frames = model.inference(
213 | text_tokens.to(device),
214 | text_tokens_lens.to(device),
215 | audio_prompts,
216 | enroll_x_lens=enroll_x_lens,
217 | top_k=-100,
218 | temperature=1,
219 | prompt_language=lang_pr,
220 | text_language=langs if accent == "no-accent" else lang,
221 | )
222 | complete_tokens = torch.cat([complete_tokens, encoded_frames.transpose(2, 1)], dim=-1)
223 | # Decode with Vocos
224 | frames = complete_tokens.permute(1,0,2)
225 | features = vocos.codes_to_features(frames)
226 | samples = vocos.decode(features, bandwidth_id=torch.tensor([2], device=device))
227 | return samples.squeeze().cpu().numpy()
228 | elif mode == "sliding-window":
229 | complete_tokens = torch.zeros([1, NUM_QUANTIZERS, 0]).type(torch.LongTensor).to(device)
230 | original_audio_prompts = audio_prompts
231 | original_text_prompts = text_prompts
232 | for text in sentences:
233 | text = text.replace("\n", "").strip(" ")
234 | if text == "":
235 | continue
236 | lang_token = lang2token[language]
237 | lang = token2lang[lang_token]
238 | text = lang_token + text + lang_token
239 |
240 | enroll_x_lens = text_prompts.shape[-1]
241 | logging.info(f"synthesize text: {text}")
242 | phone_tokens, langs = text_tokenizer.tokenize(text=f"_{text}".strip())
243 | text_tokens, text_tokens_lens = text_collater(
244 | [
245 | phone_tokens
246 | ]
247 | )
248 | text_tokens = torch.cat([text_prompts, text_tokens], dim=-1)
249 | text_tokens_lens += enroll_x_lens
250 | # accent control
251 | lang = lang if accent == "no-accent" else token2lang[langdropdown2token[accent]]
252 | encoded_frames = model.inference(
253 | text_tokens.to(device),
254 | text_tokens_lens.to(device),
255 | audio_prompts,
256 | enroll_x_lens=enroll_x_lens,
257 | top_k=-100,
258 | temperature=1,
259 | prompt_language=lang_pr,
260 | text_language=langs if accent == "no-accent" else lang,
261 | )
262 | complete_tokens = torch.cat([complete_tokens, encoded_frames.transpose(2, 1)], dim=-1)
263 | if torch.rand(1) < 0.5:
264 | audio_prompts = encoded_frames[:, :, -NUM_QUANTIZERS:]
265 | text_prompts = text_tokens[:, enroll_x_lens:]
266 | else:
267 | audio_prompts = original_audio_prompts
268 | text_prompts = original_text_prompts
269 | # Decode with Vocos
270 | frames = complete_tokens.permute(1,0,2)
271 | features = vocos.codes_to_features(frames)
272 | samples = vocos.decode(features, bandwidth_id=torch.tensor([2], device=device))
273 | return samples.squeeze().cpu().numpy()
274 | else:
275 | raise ValueError(f"No such mode {mode}")
--------------------------------------------------------------------------------
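
A minimal end-to-end sketch of the generation API above; the preset name, output paths, and the use of soundfile for writing 24 kHz audio are assumptions for illustration, not taken from the repo:

    import soundfile as sf  # any writer that accepts a float waveform at 24 kHz works
    from utils.generation import preload_models, generate_audio, generate_audio_from_long_text

    preload_models()  # downloads vallex-checkpoint.pt into ./checkpoints/ on first run, then loads VALL-E, EnCodec and Vocos

    # Short utterance with a prompt looked up under ./presets/ or ./customs/
    audio = generate_audio("Hello, this is a test.", prompt="paimon", language="auto")
    sf.write("short.wav", audio, 24000)

    # Long text, split into sentences and synthesized with the same prompt throughout
    long_text = "First sentence. Second sentence. Third sentence."
    audio = generate_audio_from_long_text(long_text, prompt="paimon", mode="fixed-prompt")
    sf.write("long.wav", audio, 24000)
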
/utils/prompt_making.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torchaudio
4 | import logging
5 | import langid
6 | import whisper
7 | langid.set_languages(['en', 'zh', 'ja'])
8 |
9 | import numpy as np
10 | from data.tokenizer import (
11 | AudioTokenizer,
12 | tokenize_audio,
13 | )
14 | from data.collation import get_text_token_collater
15 | from utils.g2p import PhonemeBpeTokenizer
16 |
17 | from macros import *
18 |
19 | text_tokenizer = PhonemeBpeTokenizer(tokenizer_path="./utils/g2p/bpe_69.json")
20 | text_collater = get_text_token_collater()
21 |
22 | device = torch.device("cpu")
23 | if torch.cuda.is_available():
24 | device = torch.device("cuda", 0)
25 |
26 | codec = AudioTokenizer(device)
27 |
28 | if not os.path.exists("./whisper/"): os.mkdir("./whisper/")
29 | whisper_model = None
30 |
31 | @torch.no_grad()
32 | def transcribe_one(model, audio_path):
33 | # load audio and pad/trim it to fit 30 seconds
34 | audio = whisper.load_audio(audio_path)
35 | audio = whisper.pad_or_trim(audio)
36 |
37 | # make log-Mel spectrogram and move to the same device as the model
38 | mel = whisper.log_mel_spectrogram(audio).to(model.device)
39 |
40 | # detect the spoken language
41 | _, probs = model.detect_language(mel)
42 | print(f"Detected language: {max(probs, key=probs.get)}")
43 | lang = max(probs, key=probs.get)
44 | # decode the audio
45 | options = whisper.DecodingOptions(temperature=1.0, best_of=5, fp16=False if device == torch.device("cpu") else True, sample_len=150)
46 | result = whisper.decode(model, mel, options)
47 |
48 | # print the recognized text
49 | print(result.text)
50 |
51 | text_pr = result.text
52 | if text_pr.strip(" ")[-1] not in "?!.,。,?!。、":
53 | text_pr += "."
54 | return lang, text_pr
55 |
56 | def make_prompt(name, audio_prompt_path, transcript=None):
57 | global model, text_collater, text_tokenizer, codec
58 | wav_pr, sr = torchaudio.load(audio_prompt_path)
59 | # check length
60 | if wav_pr.size(-1) / sr > 15:
61 |         raise ValueError(f"Prompt too long, expected length below 15 seconds, got {wav_pr.size(-1) / sr} seconds.")
62 | if wav_pr.size(0) == 2:
63 | wav_pr = wav_pr.mean(0, keepdim=True)
64 | text_pr, lang_pr = make_transcript(name, wav_pr, sr, transcript)
65 |
66 | # tokenize audio
67 | encoded_frames = tokenize_audio(codec, (wav_pr, sr))
68 | audio_tokens = encoded_frames[0][0].transpose(2, 1).cpu().numpy()
69 |
70 | # tokenize text
71 | phonemes, langs = text_tokenizer.tokenize(text=f"{text_pr}".strip())
72 | text_tokens, enroll_x_lens = text_collater(
73 | [
74 | phonemes
75 | ]
76 | )
77 |
78 | message = f"Detected language: {lang_pr}\n Detected text {text_pr}\n"
79 |
80 | # save as npz file
81 | save_path = os.path.join("./customs/", f"{name}.npz")
82 | np.savez(save_path, audio_tokens=audio_tokens, text_tokens=text_tokens, lang_code=lang2code[lang_pr])
83 | logging.info(f"Successful. Prompt saved to {save_path}")
84 |
85 |
86 | def make_transcript(name, wav, sr, transcript=None):
87 |
88 | if not isinstance(wav, torch.FloatTensor):
89 | wav = torch.tensor(wav)
90 | if wav.abs().max() > 1:
91 | wav /= wav.abs().max()
92 | if wav.size(-1) == 2:
93 | wav = wav.mean(-1, keepdim=False)
94 | if wav.ndim == 1:
95 | wav = wav.unsqueeze(0)
96 |     assert wav.ndim == 2 and wav.size(0) == 1
97 | if transcript is None or transcript == "":
98 | logging.info("Transcript not given, using Whisper...")
99 | global whisper_model
100 | if whisper_model is None:
101 | whisper_model = whisper.load_model("medium", download_root=os.path.join(os.getcwd(), "whisper"))
102 | whisper_model.to(device)
103 | torchaudio.save(f"./prompts/{name}.wav", wav, sr)
104 | lang, text = transcribe_one(whisper_model, f"./prompts/{name}.wav")
105 | lang_token = lang2token[lang]
106 | text = lang_token + text + lang_token
107 | os.remove(f"./prompts/{name}.wav")
108 | whisper_model.cpu()
109 | else:
110 | text = transcript
111 | lang, _ = langid.classify(text)
112 | lang_token = lang2token[lang]
113 | text = lang_token + text + lang_token
114 |
115 | torch.cuda.empty_cache()
116 | return text, lang
--------------------------------------------------------------------------------
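
A minimal sketch of registering a custom voice with make_prompt and then using it through utils.generation; the audio path and prompt name are hypothetical:

    from utils.prompt_making import make_prompt
    from utils.generation import preload_models, generate_audio

    # The recording must be shorter than 15 seconds (see the length check above).
    # With an explicit transcript, Whisper is not used at all.
    make_prompt(name="my_voice", audio_prompt_path="my_voice.wav",
                transcript="This is what is being said in the recording.")
    # Without a transcript, Whisper (medium) is downloaded to ./whisper/ and used to transcribe:
    # make_prompt(name="my_voice", audio_prompt_path="my_voice.wav")

    preload_models()
    audio = generate_audio("Some new text in the cloned voice.", prompt="my_voice")
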
/utils/sentence_cutter.py:
--------------------------------------------------------------------------------
1 | import nltk
2 | import jieba
3 | import sudachipy
4 | import langid
5 | langid.set_languages(['en', 'zh', 'ja'])
6 |
7 | def split_text_into_sentences(text):
8 | if langid.classify(text)[0] == "en":
9 | sentences = nltk.tokenize.sent_tokenize(text)
10 |
11 | return sentences
12 | elif langid.classify(text)[0] == "zh":
13 | sentences = []
14 | segs = jieba.cut(text, cut_all=False)
15 | segs = list(segs)
16 | start = 0
17 | for i, seg in enumerate(segs):
18 | if seg in ["。", "!", "?", "……"]:
19 | sentences.append("".join(segs[start:i + 1]))
20 | start = i + 1
21 | if start < len(segs):
22 | sentences.append("".join(segs[start:]))
23 |
24 | return sentences
25 | elif langid.classify(text)[0] == "ja":
26 | sentences = []
27 | tokenizer = sudachipy.Dictionary().create()
28 | tokens = tokenizer.tokenize(text)
29 | current_sentence = ""
30 |
31 | for token in tokens:
32 | current_sentence += token.surface()
33 | if token.part_of_speech()[0] == "補助記号" and token.part_of_speech()[1] == "句点":
34 | sentences.append(current_sentence)
35 | current_sentence = ""
36 |
37 | if current_sentence:
38 | sentences.append(current_sentence)
39 |
40 | return sentences
41 |
42 |     raise RuntimeError("Unreachable: langid is restricted to en, zh and ja.")
43 |
44 | long_text = """
45 | This is a very long paragraph, so most TTS models are unable to handle it. Hence, we have to split it into several sentences. With the help of NLTK, we can split it into sentences. However, the punctuation is not preserved, so we have to add it back. How are we going to write this code? Let's see.
46 | """
47 |
48 | long_text = """
49 | 现在我们要来尝试一下中文分句。因为很不幸的是,NLTK不支持中文分句。幸运的是,我们可以使用jieba来分句。但是,jieba分句后,标点符号会丢失,所以我们要手动添加回去。我现在正在想办法把这个例句写的更长更复杂一点,来测试jieba分句的性能。嗯......省略号,感觉不太好,因为省略号不是句号,所以jieba不会把它当作句子的结尾。会这样吗?我们来试试看。
50 | """
51 |
52 | long_text = """
53 | これなら、英語と中国語の分句もできる。でも、日本語はどうする?まつわ、ChatGPTに僕と教えてください。ちょーと待ってください。あ、出来た!
54 | """
--------------------------------------------------------------------------------
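
A minimal usage sketch (not from the repo): split_text_into_sentences dispatches on the language detected by langid (en -> NLTK, zh -> jieba, ja -> SudachiPy):

    from utils.sentence_cutter import split_text_into_sentences

    for sentence in split_text_into_sentences("First sentence. Second one! And a third?"):
        print(sentence)
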
/utils/symbol_table.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Mobvoi Inc. (authors: Fangjun Kuang)
2 | #
3 | # See ../../../LICENSE for clarification regarding multiple authors
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | from dataclasses import dataclass
18 | from dataclasses import field
19 | from typing import Dict
20 | from typing import Generic
21 | from typing import List
22 | from typing import Optional
23 | from typing import TypeVar
24 | from typing import Union
25 |
26 | Symbol = TypeVar('Symbol')
27 |
28 |
29 | # Disable __repr__ otherwise it could freeze e.g. Jupyter.
30 | @dataclass(repr=False)
31 | class SymbolTable(Generic[Symbol]):
32 |     '''SymbolTable that maps symbol IDs, found on the FSA arcs, to
33 | actual objects. These objects can be arbitrary Python objects
34 | that can serve as keys in a dictionary (i.e. they need to be
35 | hashable and immutable).
36 |
37 |     The SymbolTable can only be written to/read from disk if the
38 | symbols are strings.
39 | '''
40 | _id2sym: Dict[int, Symbol] = field(default_factory=dict)
41 | '''Map an integer to a symbol.
42 | '''
43 |
44 | _sym2id: Dict[Symbol, int] = field(default_factory=dict)
45 | '''Map a symbol to an integer.
46 | '''
47 |
48 | _next_available_id: int = 1
49 | '''A helper internal field that helps adding new symbols
50 | to the table efficiently.
51 | '''
52 |
53 | eps: Symbol = ''
54 | '''Null symbol, always mapped to index 0.
55 | '''
56 |
57 | def __post_init__(self):
58 | for idx, sym in self._id2sym.items():
59 | assert self._sym2id[sym] == idx
60 | assert idx >= 0
61 |
62 | for sym, idx in self._sym2id.items():
63 | assert idx >= 0
64 | assert self._id2sym[idx] == sym
65 |
66 | if 0 not in self._id2sym:
67 | self._id2sym[0] = self.eps
68 | self._sym2id[self.eps] = 0
69 | else:
70 | assert self._id2sym[0] == self.eps
71 | assert self._sym2id[self.eps] == 0
72 |
73 | self._next_available_id = max(self._id2sym) + 1
74 |
75 | @staticmethod
76 | def from_str(s: str) -> 'SymbolTable':
77 | '''Build a symbol table from a string.
78 |
79 | The string consists of lines. Every line has two fields separated
80 | by space(s), tab(s) or both. The first field is the symbol and the
81 | second the integer id of the symbol.
82 |
83 | Args:
84 | s:
85 | The input string with the format described above.
86 | Returns:
87 | An instance of :class:`SymbolTable`.
88 | '''
89 | id2sym: Dict[int, str] = dict()
90 | sym2id: Dict[str, int] = dict()
91 |
92 | for line in s.split('\n'):
93 | fields = line.split()
94 | if len(fields) == 0:
95 | continue # skip empty lines
96 | assert len(fields) == 2, \
97 | f'Expect a line with 2 fields. Given: {len(fields)}'
98 | sym, idx = fields[0], int(fields[1])
99 | assert sym not in sym2id, f'Duplicated symbol {sym}'
100 | assert idx not in id2sym, f'Duplicated id {idx}'
101 | id2sym[idx] = sym
102 | sym2id[sym] = idx
103 |
104 | eps = id2sym.get(0, '')
105 |
106 | return SymbolTable(_id2sym=id2sym, _sym2id=sym2id, eps=eps)
107 |
108 | @staticmethod
109 | def from_file(filename: str) -> 'SymbolTable':
110 | '''Build a symbol table from file.
111 |
112 | Every line in the symbol table file has two fields separated by
113 | space(s), tab(s) or both. The following is an example file:
114 |
115 | .. code-block::
116 |
117 | 0
118 | a 1
119 | b 2
120 | c 3
121 |
122 | Args:
123 | filename:
124 | Name of the symbol table file. Its format is documented above.
125 |
126 | Returns:
127 | An instance of :class:`SymbolTable`.
128 |
129 | '''
130 | with open(filename, 'r', encoding='utf-8') as f:
131 | return SymbolTable.from_str(f.read().strip())
132 |
133 | def to_str(self) -> str:
134 | '''
135 | Returns:
136 | Return a string representation of this object. You can pass
137 | it to the method ``from_str`` to recreate an identical object.
138 | '''
139 | s = ''
140 | for idx, symbol in sorted(self._id2sym.items()):
141 | s += f'{symbol} {idx}\n'
142 | return s
143 |
144 | def to_file(self, filename: str):
145 | '''Serialize the SymbolTable to a file.
146 |
147 | Every line in the symbol table file has two fields separated by
148 | space(s), tab(s) or both. The following is an example file:
149 |
150 | .. code-block::
151 |
152 | 0
153 | a 1
154 | b 2
155 | c 3
156 |
157 | Args:
158 | filename:
159 | Name of the symbol table file. Its format is documented above.
160 | '''
161 | with open(filename, 'w', encoding='utf-8') as f:
162 | for idx, symbol in sorted(self._id2sym.items()):
163 | print(symbol, idx, file=f)
164 |
165 | def add(self, symbol: Symbol, index: Optional[int] = None) -> int:
166 | '''Add a new symbol to the SymbolTable.
167 |
168 | Args:
169 | symbol:
170 | The symbol to be added.
171 | index:
172 | Optional int id to which the symbol should be assigned.
173 | If it is not available, a ValueError will be raised.
174 |
175 | Returns:
176 | The int id to which the symbol has been assigned.
177 | '''
178 | # Already in the table? Return its ID.
179 | if symbol in self._sym2id:
180 | return self._sym2id[symbol]
181 | # Specific ID not provided - use next available.
182 | if index is None:
183 | index = self._next_available_id
184 | # Specific ID provided but not available.
185 | if index in self._id2sym:
186 | raise ValueError(f"Cannot assign id '{index}' to '{symbol}' - "
187 | f"already occupied by {self._id2sym[index]}")
188 | self._sym2id[symbol] = index
189 | self._id2sym[index] = symbol
190 |
191 | # Update next available ID if needed
192 | if self._next_available_id <= index:
193 | self._next_available_id = index + 1
194 |
195 | return index
196 |
197 | def get(self, k: Union[int, Symbol]) -> Union[Symbol, int]:
198 | '''Get a symbol for an id or get an id for a symbol
199 |
200 | Args:
201 | k:
202 | If it is an id, it tries to find the symbol corresponding
203 | to the id; if it is a symbol, it tries to find the id
204 | corresponding to the symbol.
205 |
206 | Returns:
207 | An id or a symbol depending on the given `k`.
208 | '''
209 | if isinstance(k, int):
210 | return self._id2sym[k]
211 | else:
212 | return self._sym2id[k]
213 |
214 | def merge(self, other: 'SymbolTable') -> 'SymbolTable':
215 | '''Create a union of two SymbolTables.
216 | Raises an AssertionError if the same IDs are occupied by
217 | different symbols.
218 |
219 | Args:
220 | other:
221 | A symbol table to merge with ``self``.
222 |
223 | Returns:
224 | A new symbol table.
225 | '''
226 | self._check_compatible(other)
227 |
228 | id2sym = {**self._id2sym, **other._id2sym}
229 | sym2id = {**self._sym2id, **other._sym2id}
230 |
231 | return SymbolTable(_id2sym=id2sym, _sym2id=sym2id, eps=self.eps)
232 |
233 | def _check_compatible(self, other: 'SymbolTable') -> None:
234 | # Epsilon compatibility
235 | assert self.eps == other.eps, f'Mismatched epsilon symbol: ' \
236 | f'{self.eps} != {other.eps}'
237 | # IDs compatibility
238 | common_ids = set(self._id2sym).intersection(other._id2sym)
239 | for idx in common_ids:
240 | assert self[idx] == other[idx], f'ID conflict for id: {idx}, ' \
241 | f'self[idx] = "{self[idx]}", ' \
242 | f'other[idx] = "{other[idx]}"'
243 | # Symbols compatibility
244 | common_symbols = set(self._sym2id).intersection(other._sym2id)
245 | for sym in common_symbols:
246 |             assert self[sym] == other[sym], f'ID conflict for symbol: {sym}, ' \
247 | f'self[sym] = "{self[sym]}", ' \
248 | f'other[sym] = "{other[sym]}"'
249 |
250 | def __getitem__(self, item: Union[int, Symbol]) -> Union[Symbol, int]:
251 | return self.get(item)
252 |
253 | def __contains__(self, item: Union[int, Symbol]) -> bool:
254 | if isinstance(item, int):
255 | return item in self._id2sym
256 | else:
257 | return item in self._sym2id
258 |
259 | def __len__(self) -> int:
260 | return len(self._id2sym)
261 |
262 | def __eq__(self, other: 'SymbolTable') -> bool:
263 | if len(self) != len(other):
264 | return False
265 |
266 | for s in self.symbols:
267 | if self[s] != other[s]:
268 | return False
269 |
270 | return True
271 |
272 | @property
273 | def ids(self) -> List[int]:
274 | '''Returns a list of integer IDs corresponding to the symbols.
275 | '''
276 | ans = list(self._id2sym.keys())
277 | ans.sort()
278 | return ans
279 |
280 | @property
281 | def symbols(self) -> List[Symbol]:
282 | '''Returns a list of symbols (e.g., strings) corresponding to
283 | the integer IDs.
284 | '''
285 | ans = list(self._sym2id.keys())
286 | ans.sort()
287 | return ans
288 |
--------------------------------------------------------------------------------
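
A minimal sketch of the SymbolTable API defined above; the symbol strings are illustrative:

    from utils.symbol_table import SymbolTable

    table = SymbolTable.from_str('a 1\nb 2')
    assert table['a'] == 1 and table[2] == 'b'   # __getitem__ works in both directions

    table.add('c')                               # auto-assigns the next free id (3)
    merged = table.merge(SymbolTable.from_str('a 1\nd 4'))
    print(merged.symbols)                        # ['', 'a', 'b', 'c', 'd']; '' is the epsilon at id 0
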