├── LICENSE ├── README-ZH.md ├── README.md ├── customs ├── make_custom_dataset.py └── ph.txt ├── data ├── __init__.py ├── collation.py ├── datamodule.py ├── dataset.py ├── fbank.py ├── input_strategies.py └── tokenizer.py ├── descriptions.py ├── examples.py ├── images └── vallex_framework.jpg ├── launch-ui.py ├── macros.py ├── model-card.md ├── models ├── __init__.py ├── macros.py ├── transformer.py ├── vallex.py └── visualizer.py ├── modules ├── __init__.py ├── activation.py ├── embedding.py ├── optim.py ├── scaling.py ├── scheduler.py └── transformer.py ├── nltk_data └── tokenizers │ └── punkt │ ├── .DS_Store │ ├── PY3 │ ├── README │ ├── czech.pickle │ ├── danish.pickle │ ├── dutch.pickle │ ├── english.pickle │ ├── estonian.pickle │ ├── finnish.pickle │ ├── french.pickle │ ├── german.pickle │ ├── greek.pickle │ ├── italian.pickle │ ├── malayalam.pickle │ ├── norwegian.pickle │ ├── polish.pickle │ ├── portuguese.pickle │ ├── russian.pickle │ ├── slovene.pickle │ ├── spanish.pickle │ ├── swedish.pickle │ └── turkish.pickle │ ├── README │ ├── czech.pickle │ ├── danish.pickle │ ├── dutch.pickle │ ├── english.pickle │ ├── estonian.pickle │ ├── finnish.pickle │ ├── french.pickle │ ├── german.pickle │ ├── greek.pickle │ ├── italian.pickle │ ├── malayalam.pickle │ ├── norwegian.pickle │ ├── polish.pickle │ ├── portuguese.pickle │ ├── russian.pickle │ ├── slovene.pickle │ ├── spanish.pickle │ ├── swedish.pickle │ └── turkish.pickle ├── presets ├── acou_1.npz ├── acou_2.npz ├── acou_3.npz ├── acou_4.npz ├── alan.npz ├── amused.npz ├── anger.npz ├── babara.npz ├── bronya.npz ├── cafe.npz ├── dingzhen.npz ├── disgust.npz ├── emo_amused.npz ├── emo_anger.npz ├── emo_neutral.npz ├── emo_sleepy.npz ├── emotion_sleepiness.npz ├── en2zh_tts_1.npz ├── en2zh_tts_2.npz ├── en2zh_tts_3.npz ├── en2zh_tts_4.npz ├── esta.npz ├── fuxuan.npz ├── librispeech_1.npz ├── librispeech_2.npz ├── librispeech_3.npz ├── librispeech_4.npz ├── neutral.npz ├── paimon.npz ├── rosalia.npz ├── seel.npz ├── sleepiness.npz ├── vctk_1.npz ├── vctk_2.npz ├── vctk_3.npz ├── vctk_4.npz ├── yaesakura.npz ├── zh2en_tts_1.npz ├── zh2en_tts_2.npz ├── zh2en_tts_3.npz └── zh2en_tts_4.npz ├── prompts ├── en-1.wav ├── en-2.wav ├── ja-1.wav ├── ja-2.ogg ├── ph.txt ├── zh-1.wav └── zh-2.wav ├── requirements.txt ├── test.py ├── train.py ├── train_utils ├── __pycache__ │ └── utils.cpython-310.pyc ├── icefall │ └── utils.py ├── lhotse │ └── utils.py ├── model.py └── utils.py └── utils ├── __init__.py ├── download.py ├── g2p ├── __init__.py ├── bpe_1024.json ├── bpe_69.json ├── cleaners.py ├── english.py ├── japanese.py ├── mandarin.py └── symbols.py ├── generation.py ├── prompt_making.py ├── sentence_cutter.py └── symbol_table.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Songting 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README-ZH.md: -------------------------------------------------------------------------------- 1 | # VALL-E X: 多语言文本到语音合成与语音克隆 🔊 2 | [![Discord](https://img.shields.io/badge/Discord-%235865F2.svg?style=for-the-badge&logo=discord&logoColor=white)](https://discord.gg/qCBRmAnTxg) 3 |
4 | [English](README.md) | 中文 5 |
6 | 微软[VALL-E X](https://arxiv.org/pdf/2303.03926) 零样本语音合成模型的开源实现.
7 | **预训练模型现已向公众开放,供研究或应用使用。** 8 | ![vallex-framework](/images/vallex_framework.jpg "VALL-E X framework") 9 | 10 | VALL-E X 是一个强大而创新的多语言文本转语音(TTS)模型,最初由微软发布。虽然微软最初在他们的研究论文中提出了该概念,但并未发布任何代码或预训练模型。我们认识到了这项技术的潜力和价值,复现并训练了一个开源可用的VALL-E X模型。我们很乐意与社区分享我们的预训练模型,让每个人都能体验到次世代TTS的威力。 🎧 11 |
12 | 更多细节请查看 [model card](./model-card.md). 13 | 14 | ## 📖 目录 15 | * [🚀 更新日志](#-更新日志) 16 | * [📢 功能特点](#-功能特点) 17 | * [💻 本地安装](#-本地安装) 18 | * [🎧 在线Demo](#-在线Demo) 19 | * [🐍 使用方法](#-Python中的使用方法) 20 | * [❓ FAQ](#-faq) 21 | * [🧠 TODO](#-todo) 22 | 23 | ## 🚀 Updates 24 | **2023.09.10** 25 | - 支持AR decoder的batch decoding以实现更稳定的生成结果 26 | 27 | **2023.08.30** 28 | - 将EnCodec解码器替换成了Vocos解码器,提升了音质。 (感谢[@v0xie](https://github.com/v0xie)) 29 | 30 | **2023.08.23** 31 | - 加入了长文本生成功能 32 | 33 | **2023.08.20** 34 | - 加入了中文版README 35 | 36 | **2023.08.14** 37 | - 预训练模型权重已发布,从[这里](https://drive.google.com/file/d/10gdQWvP-K_e1undkvv0p2b7SU6I4Egyl/view?usp=sharing)下载。 38 | 39 | ## 💻 本地安装 40 | ### 使用pip安装,推荐使用Python 3.10,CUDA 11.7 ~ 12.0,PyTorch 2.0+ 41 | ```commandline 42 | git clone https://github.com/Plachtaa/VALL-E-X.git 43 | cd VALL-E-X 44 | pip install -r requirements.txt 45 | ``` 46 | 47 | > 注意:如果需要制作prompt,需要安装 ffmpeg 并将其所在文件夹加入到环境变量PATH中 48 | 49 | 第一次运行程序时,会自动下载相应的模型。如果下载失败并报错,请按照以下步骤手动下载模型。 50 | 51 | (请注意目录和文件夹的大小写) 52 | 53 | 1.检查安装目录下是否存在`checkpoints`文件夹,如果没有,在安装目录下手动创建`checkpoints`文件夹(`./checkpoints/`)。 54 | 55 | 2.检查`checkpoints`文件夹中是否有`vallex-checkpoint.pt`文件。如果没有,请从[这里](https://huggingface.co/Plachta/VALL-E-X/resolve/main/vallex-checkpoint.pt) 56 | 手动下载`vallex-checkpoint.pt`文件并放到`checkpoints`文件夹里。 57 | 58 | 3.检查安装目录下是否存在`whisper`文件夹,如果没有,在安装目录下手动创建`whisper`文件夹(`./whisper/`)。 59 | 60 | 4.检查`whisper`文件夹中是否有`medium.pt`文件。如果没有,请从[这里](https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt) 61 | 手动下载`medium.pt`文件并放到`whisper`文件夹里。 62 | 63 | ## 🎧 在线Demo 64 | 如果你不想在本地安装,你可以在线体验VALL-E X的功能,点击下面的任意一个链接即可开始体验。 65 |
66 | [![Open in Spaces](https://img.shields.io/badge/🤗-Open%20in%20Spaces-blue.svg)](https://huggingface.co/spaces/Plachta/VALL-E-X) 67 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1yyD_sz531QntLKowMHo-XxorsFBCfKul?usp=sharing) 68 | 69 | 70 | ## 📢 功能特点 71 | 72 | VALL-E X 配备有一系列尖端功能: 73 | 74 | 1. **多语言 TTS**: 可使用三种语言 - 英语、中文和日语 - 进行自然、富有表现力的语音合成。 75 | 76 | 2. **零样本语音克隆**: 仅需录制任意说话人的短短的 3~10 秒录音,VALL-E X 就能生成个性化、高质量的语音,完美还原他们的声音。 77 | 78 |
79 |
查看示例
80 | 81 | [prompt.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/a7baa51d-a53a-41cc-a03d-6970f25fcca7) 82 | 83 | 84 | [output.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/b895601a-d126-4138-beff-061aabdc7985) 85 | 86 |
87 | 88 | 3. **语音情感控制**: VALL-E X 可以合成与给定说话人录音相同情感的语音,为音频增添更多表现力。 89 | 90 |
91 |
查看示例
92 | 93 | https://github.com/Plachtaa/VALL-E-X/assets/112609742/56fa9988-925e-4757-82c5-83ecb0df6266 94 | 95 | 96 | https://github.com/Plachtaa/VALL-E-X/assets/112609742/699c47a3-d502-4801-8364-bd89bcc0b8f1 97 | 98 |
99 | 100 | 4. **零样本跨语言语音合成**: VALL-E X 可以合成与给定说话人母语不同的另一种语言,在不影响口音和流利度的同时,保留该说话人的音色与情感。以下是一个使用日语母语者进行英文与中文合成的样例: 🇯🇵 🗣 101 | 102 |
103 |
查看示例
104 | 105 | [jp-prompt.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/ea6e2ee4-139a-41b4-837e-0bd04dda6e19) 106 | 107 | 108 | [en-output.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/db8f9782-923f-425e-ba94-e8c1bd48f207) 109 | 110 | 111 | [zh-output.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/15829d79-e448-44d3-8965-fafa7a3f8c28) 112 | 113 |
114 | 115 | 5. **口音控制**: VALL-E X 允许您控制所合成音频的口音,比如说中文带英语口音或反之。 🇨🇳 💬 116 | 117 |
118 |
查看示例
119 | 120 | [en-prompt.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/f688d7f6-70ef-46ec-b1cc-355c31e78b3b) 121 | 122 | 123 | [zh-accent-output.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/be59c7ca-b45b-44ca-a30d-4d800c950ccc) 124 | 125 | 126 | [en-accent-output.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/8b4f4f9b-f299-4ea4-a548-137437b71738) 127 | 128 |
129 | 130 | 6. **声学环境保留**: 当给定说话人的录音在不同的声学环境下录制时,VALL-E X 可以保留该声学环境,使合成语音听起来更加自然。 131 | 132 |
133 |
查看示例
134 | 135 | [noise-prompt.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/68986d88-abd0-4d1d-96e4-4f893eb9259e) 136 | 137 | 138 | [noise-output.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/96c4c612-4516-4683-8804-501b70938608) 139 | 140 |
141 | 142 | 143 | 你可以访问我们的[demo页面](https://plachtaa.github.io/) 来浏览更多示例! 144 | 145 | ## 💻 Python中的使用方法 146 | 147 |
148 |

🪑 基本使用

149 | 150 | ```python 151 | from utils.generation import SAMPLE_RATE, generate_audio, preload_models 152 | from scipy.io.wavfile import write as write_wav 153 | from IPython.display import Audio 154 | 155 | # download and load all models 156 | preload_models() 157 | 158 | # generate audio from text 159 | text_prompt = """ 160 | Hello, my name is Nose. And uh, and I like hamburger. Hahaha... But I also have other interests such as playing tactic toast. 161 | """ 162 | audio_array = generate_audio(text_prompt) 163 | 164 | # save audio to disk 165 | write_wav("vallex_generation.wav", SAMPLE_RATE, audio_array) 166 | 167 | # play text in notebook 168 | Audio(audio_array, rate=SAMPLE_RATE) 169 | ``` 170 | 171 | [hamburger.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/578d7bbe-cda9-483e-898c-29646edc8f2e) 172 | 173 |
174 | 175 |
176 |

🌎 多语言

177 |
178 | 该VALL-E X实现支持三种语言:英语、中文和日语。您可以通过设置`language`参数来指定语言。默认情况下,该模型将自动检测语言。 179 |
180 | 181 | ```python 182 | 183 | text_prompt = """ 184 | チュソクは私のお気に入りの祭りです。 私は数日間休んで、友人や家族との時間を過ごすことができます。 185 | """ 186 | audio_array = generate_audio(text_prompt) 187 | ``` 188 | 189 | [vallex_japanese.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/ee57a688-3e83-4be5-b0fe-019d16eec51c) 190 | 191 | *注意:即使在一句话中混合多种语言的情况下,VALL-E X也能完美地控制口音,但是您需要手动标记各个句子对应的语言以便于我们的G2P工具识别它们。* 192 | ```python 193 | text_prompt = """ 194 | [EN]The Thirty Years' War was a devastating conflict that had a profound impact on Europe.[EN] 195 | [ZH]这是历史的开始。 如果您想听更多,请继续。[ZH] 196 | """ 197 | audio_array = generate_audio(text_prompt, language='mix') 198 | ``` 199 | 200 | [vallex_codeswitch.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/d8667abf-bd08-499f-a383-a861d852f98a) 201 | 202 |
203 | 204 |
205 |

📼 预设音色

206 | 207 | 我们提供十几种说话人音色可直接VALL-E X使用! 在[这里](/presets)浏览所有可用音色。 208 | 209 | > VALL-E X 尝试匹配给定预设音色的音调、音高、情感和韵律。该模型还尝试保留音乐、环境噪声等。 210 | ```python 211 | text_prompt = """ 212 | I am an innocent boy with a smoky voice. It is a great honor for me to speak at the United Nations today. 213 | """ 214 | audio_array = generate_audio(text_prompt, prompt="dingzhen") 215 | ``` 216 | 217 | [smoky.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/d3f55732-b1cd-420f-87d6-eab60db14dc5) 218 | 219 |
220 | 221 |
222 |

🎙声音克隆

223 | 224 | VALL-E X 支持声音克隆!你可以使用任何人,角色,甚至是你自己的声音,来制作一个音频提示。在你使用该音频提示时,VALL-E X 将会使用与其相似的声音来合成文本。 225 |
226 | 你需要提供一段3~10秒长的语音,以及该语音对应的文本,来制作音频提示。你也可以将文本留空,让[Whisper](https://github.com/openai/whisper)模型为你生成文本。 227 | > VALL-E X 尝试匹配给定音频提示的音调、音高、情感和韵律。该模型还尝试保留音乐、环境噪声等。 228 | 229 | ```python 230 | from utils.prompt_making import make_prompt 231 | 232 | ### Use given transcript 233 | make_prompt(name="paimon", audio_prompt_path="paimon_prompt.wav", 234 | transcript="Just, what was that? Paimon thought we were gonna get eaten.") 235 | 236 | ### Alternatively, use whisper 237 | make_prompt(name="paimon", audio_prompt_path="paimon_prompt.wav") 238 | ``` 239 | 来尝试一下刚刚做好的音频提示吧! 240 | ```python 241 | from utils.generation import SAMPLE_RATE, generate_audio, preload_models 242 | from scipy.io.wavfile import write as write_wav 243 | 244 | # download and load all models 245 | preload_models() 246 | 247 | text_prompt = """ 248 | Hey, Traveler, Listen to this, This machine has taken my voice, and now it can talk just like me! 249 | """ 250 | audio_array = generate_audio(text_prompt, prompt="paimon") 251 | 252 | write_wav("paimon_cloned.wav", SAMPLE_RATE, audio_array) 253 | 254 | ``` 255 | 256 | [paimon_prompt.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/e7922859-9d12-4e2a-8651-e156e4280311) 257 | 258 | 259 | [paimon_cloned.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/60d3b7e9-5ead-4024-b499-a897ce5f3d5e) 260 | 261 | 262 |
263 | 264 | 265 |
266 |

🎢用户界面

267 | 268 | 如果你不擅长代码,我们还为VALL-E X创建了一个用户友好的图形界面。它可以让您轻松地与模型进行交互,使语音克隆和多语言语音合成变得轻而易举。 269 |
270 | 使用以下命令启动用户界面: 271 | ```commandline 272 | python -X utf8 launch-ui.py 273 | ``` 274 |
275 | 276 | ## 🛠️ 硬件要求及推理速度 277 | 278 | VALL-E X 可以在CPU或GPU上运行 (`pytorch 2.0+`, CUDA 11.7 ~ CUDA 12.0). 279 | 280 | 若使用GPU运行,你需要至少6GB的显存。 281 | 282 | ## ⚙️ Details 283 | 284 | VALL-E X 与 [Bark](https://github.com/suno-ai/bark), [VALL-E](https://arxiv.org/abs/2301.02111) and [AudioLM](https://arxiv.org/abs/2209.03143)类似, 使用GPT风格的模型以自回归方式预测量化音频token,并由[EnCodec](https://github.com/facebookresearch/encodec)解码. 285 |
286 | 与 [Bark](https://github.com/suno-ai/bark) 相比: 287 | - ✔ **轻量**: 3️⃣ ✖ 更小, 288 | - ✔ **快速**: 4️⃣ ✖ 更快, 289 | - ✔ **中文&日文的更高质量** 290 | - ✔ **跨语言合成时没有外国口音** 291 | - ✔ **开放且易于操作的声音克隆** 292 | - ❌ **支持的语言较少** 293 | - ❌ **没有用于合成音乐及特殊音效的token** 294 | 295 | ### 支持的语言 296 | 297 | | 语言 | 状态 | 298 | |---------| :---: | 299 | | 英语 (en) | ✅ | 300 | | 日语 (ja) | ✅ | 301 | | 中文 (zh) | ✅ | 302 | 303 | ## ❓ FAQ 304 | 305 | #### 在哪里可以下载checkpoint? 306 | * 当您第一次运行程序时,我们使用`wget`将模型下载到`./checkpoints/`目录里。 307 | * 如果第一次运行时下载失败,请从[这里](https://huggingface.co/Plachta/VALL-E-X/resolve/main/vallex-checkpoint.pt)手动下载模型,并将文件放在`./checkpoints/`里。 308 | 309 | #### 需要多少显存? 310 | * 6GB 显存(GPU VRAM) - 几乎所有NVIDIA GPU都满足要求. 311 | 312 | #### 为什么模型无法生成长文本? 313 | * 当序列长度增加时,Transformer的计算复杂度呈二次方增长。因此,所有训练音频都保持在22秒以下。请确保音频提示(audio prompt)和生成的音频的总长度小于22秒以确保可接受的性能。 314 | 315 | #### 更多... 316 | 317 | ## 🧠 待办事项 318 | - [x] 添加中文 README 319 | - [x] 长文本生成 320 | - [x] 用Vocos解码器替换Encodec解码器 321 | - [ ] 微调以实现更好的语音自适应 322 | - [ ] 给非python用户的`.bat`脚本 323 | - [ ] 更多... 324 | 325 | ## 🙏 感谢 326 | - [VALL-E X paper](https://arxiv.org/pdf/2303.03926) for the brilliant idea 327 | - [lifeiteng's vall-e](https://github.com/lifeiteng/vall-e) for related training code 328 | - [bark](https://github.com/suno-ai/bark) for the amazing pioneering work in neuro-codec TTS model 329 | 330 | ## ⭐️ 表示出你的支持 331 | 332 | 如果您觉得VALL-E X有趣且有用,请在GitHub上给我们一颗星! ⭐️ 它鼓励我们不断改进模型并添加令人兴奋的功能。 333 | 334 | ## 📜 License 335 | 336 | VALL-E X 使用 [MIT License](./LICENSE). 337 | 338 | --- 339 | 340 | 有问题或需要帮助? 可以随便 [open an issue](https://github.com/Plachtaa/VALL-E-X/issues/new) 或加入我们的 [Discord](https://discord.gg/qCBRmAnTxg) 341 | 342 | Happy voice cloning! 🎤 -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # VALL-E X: Multilingual Text-to-Speech Synthesis and Voice Cloning 🔊 2 | 3 | ## Fork README 4 | This repository eliminates the cumbersome dependencies of VALL-E-X and allows for fine tuning on custom data sets. 5 | Please refer to the original README as the basic operation has not been changed at all from the original. 6 | 7 | ## Current Accomplishments 8 | The training code worked. 9 | It was possible to train on custom datasets. 10 | 11 | ## How to create and use CustomDataset 12 | ```python 13 | from customs.make_custom_dataset import create_dataset 14 | 15 | ''' 16 | How should the data_dir be created? 17 | Place the necessary audio files in data_dir. 18 | Transcription, tokenization, etc. of the audio files are done by the create_dataset function. 19 | 20 | data_dir 21 | ├── bpe_69.json 22 | ├── utt1.wav 23 | ├── utt2.wav 24 | ├── utt3.wav 25 | ...... 26 | └── utt{n}.wav 27 | ''' 28 | 29 | data_dir = "your data_dir" 30 | create_dataset(data_dir, dataloader_process_only=True) 31 | ``` 32 | 33 | # When Training 34 | When training, please specify data_dir for training data and data_dir for validation data as "--train_dir" and "--valid_dir" as arguments on the command line. 35 | 36 | 37 | ## Original README 38 | [![Discord](https://img.shields.io/badge/Discord-%235865F2.svg?style=for-the-badge&logo=discord&logoColor=white)](https://discord.gg/qCBRmAnTxg) 39 |
40 | English | [中文](README-ZH.md) 41 |
42 | An open source implementation of Microsoft's [VALL-E X](https://arxiv.org/pdf/2303.03926) zero-shot TTS model.
43 | **We release our trained model to the public for research or application usage.** 44 | 45 | ![vallex-framework](/images/vallex_framework.jpg "VALL-E X framework") 46 | 47 | VALL-E X is an amazing multilingual text-to-speech (TTS) model proposed by Microsoft. While Microsoft initially proposed the concept in their research paper, they did not release any code or pretrained models. Recognizing the potential and value of this technology, our team took on the challenge of reproducing the results and training our own model. We are glad to share our trained VALL-E X model with the community, allowing everyone to experience the power of next-generation TTS! 🎧 48 |
49 |
50 | More details about the model are presented in [model card](./model-card.md). 51 | 52 | ## 📖 Quick Index 53 | * [🚀 Updates](#-updates) 54 | * [📢 Features](#-features) 55 | * [💻 Installation](#-installation) 56 | * [🎧 Demos](#-demos) 57 | * [🐍 Usage](#-usage-in-python) 58 | * [❓ FAQ](#-faq) 59 | * [🧠 TODO](#-todo) 60 | 61 | ## 🚀 Updates 62 | **2023.09.10** 63 | - Added AR decoder batch decoding for more stable generation result. 64 | 65 | **2023.08.30** 66 | - Replaced EnCodec decoder with Vocos decoder, improved audio quality. (Thanks to [@v0xie](https://github.com/v0xie)) 67 | 68 | **2023.08.23** 69 | - Added long text generation. 70 | 71 | **2023.08.20** 72 | - Added [Chinese README](README-ZH.md). 73 | 74 | **2023.08.14** 75 | - Pretrained VALL-E X checkpoint is now released. Download it [here](https://drive.google.com/file/d/10gdQWvP-K_e1undkvv0p2b7SU6I4Egyl/view?usp=sharing) 76 | 77 | ## 💻 Installation 78 | ### Install with pip, recommended with Python 3.10, CUDA 11.7 ~ 12.0, PyTorch 2.0+ 79 | ```commandline 80 | git clone https://github.com/Plachtaa/VALL-E-X.git 81 | cd VALL-E-X 82 | pip install -r requirements.txt 83 | ``` 84 | 85 | > Note: If you want to make prompt, you need to install ffmpeg and add its folder to the environment variable PATH. 86 | 87 | When you run the program for the first time, it will automatically download the corresponding model. 88 | 89 | If the download fails and reports an error, please follow the steps below to manually download the model. 90 | 91 | (Please pay attention to the capitalization of folders) 92 | 93 | 1. Check whether there is a `checkpoints` folder in the installation directory. 94 | If not, manually create a `checkpoints` folder (`./checkpoints/`) in the installation directory. 95 | 96 | 2. Check whether there is a `vallex-checkpoint.pt` file in the `checkpoints` folder. 97 | If not, please manually download the `vallex-checkpoint.pt` file from [here](https://huggingface.co/Plachta/VALL-E-X/resolve/main/vallex-checkpoint.pt) and put it in the `checkpoints` folder. 98 | 99 | 3. Check whether there is a `whisper` folder in the installation directory. 100 | If not, manually create a `whisper` folder (`./whisper/`) in the installation directory. 101 | 102 | 4. Check whether there is a `medium.pt` file in the `whisper` folder. 103 | If not, please manually download the `medium.pt` file from [here](https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt) and put it in the `whisper` folder. 104 | 105 | ## 🎧 Demos 106 | Not ready to set up the environment on your local machine just yet? No problem! We've got you covered with our online demos. You can try out VALL-E X directly on Hugging Face or Google Colab, experiencing the model's capabilities hassle-free! 107 |
108 | [![Open in Spaces](https://img.shields.io/badge/🤗-Open%20in%20Spaces-blue.svg)](https://huggingface.co/spaces/Plachta/VALL-E-X) 109 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1yyD_sz531QntLKowMHo-XxorsFBCfKul?usp=sharing) 110 | 111 | 112 | ## 📢 Features 113 | 114 | VALL-E X comes packed with cutting-edge functionalities: 115 | 116 | 1. **Multilingual TTS**: Speak in three languages - English, Chinese, and Japanese - with natural and expressive speech synthesis. 117 | 118 | 2. **Zero-shot Voice Cloning**: Enroll a short 3~10 seconds recording of an unseen speaker, and watch VALL-E X create personalized, high-quality speech that sounds just like them! 119 | 120 |
121 |
see example
122 | 123 | [prompt.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/a7baa51d-a53a-41cc-a03d-6970f25fcca7) 124 | 125 | 126 | [output.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/b895601a-d126-4138-beff-061aabdc7985) 127 | 128 |
129 | 130 | 3. **Speech Emotion Control**: Experience the power of emotions! VALL-E X can synthesize speech with the same emotion as the acoustic prompt provided, adding an extra layer of expressiveness to your audio. 131 | 132 |
133 |
see example
134 | 135 | https://github.com/Plachtaa/VALL-E-X/assets/112609742/56fa9988-925e-4757-82c5-83ecb0df6266 136 | 137 | 138 | https://github.com/Plachtaa/VALL-E-X/assets/112609742/699c47a3-d502-4801-8364-bd89bcc0b8f1 139 | 140 |
141 | 142 | 4. **Zero-shot Cross-Lingual Speech Synthesis**: Take monolingual speakers on a linguistic journey! VALL-E X can produce personalized speech in another language without compromising on fluency or accent. Below is a Japanese speaker talk in Chinese & English. 🇯🇵 🗣 143 | 144 |
145 |
see example
146 | 147 | [jp-prompt.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/ea6e2ee4-139a-41b4-837e-0bd04dda6e19) 148 | 149 | 150 | [en-output.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/db8f9782-923f-425e-ba94-e8c1bd48f207) 151 | 152 | 153 | [zh-output.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/15829d79-e448-44d3-8965-fafa7a3f8c28) 154 | 155 |
156 | 157 | 5. **Accent Control**: Get creative with accents! VALL-E X allows you to experiment with different accents, like speaking Chinese with an English accent or vice versa. 🇨🇳 💬 158 | 159 |
160 |
see example
161 | 162 | [en-prompt.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/f688d7f6-70ef-46ec-b1cc-355c31e78b3b) 163 | 164 | 165 | [zh-accent-output.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/be59c7ca-b45b-44ca-a30d-4d800c950ccc) 166 | 167 | 168 | [en-accent-output.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/8b4f4f9b-f299-4ea4-a548-137437b71738) 169 | 170 |
171 | 172 | 6. **Acoustic Environment Maintenance**: No need for perfectly clean audio prompts! VALL-E X adapts to the acoustic environment of the input, making speech generation feel natural and immersive. 173 | 174 |
175 |
see example
176 | 177 | [noise-prompt.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/68986d88-abd0-4d1d-96e4-4f893eb9259e) 178 | 179 | 180 | [noise-output.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/96c4c612-4516-4683-8804-501b70938608) 181 | 182 |
183 | 184 | 185 | Explore our [demo page](https://plachtaa.github.io/) for a lot more examples! 186 | 187 | ## 🐍 Usage in Python 188 | 189 |
190 |

🪑 Basics

191 | 192 | ```python 193 | from utils.generation import SAMPLE_RATE, generate_audio, preload_models 194 | from scipy.io.wavfile import write as write_wav 195 | from IPython.display import Audio 196 | 197 | # download and load all models 198 | preload_models() 199 | 200 | # generate audio from text 201 | text_prompt = """ 202 | Hello, my name is Nose. And uh, and I like hamburger. Hahaha... But I also have other interests such as playing tactic toast. 203 | """ 204 | audio_array = generate_audio(text_prompt) 205 | 206 | # save audio to disk 207 | write_wav("vallex_generation.wav", SAMPLE_RATE, audio_array) 208 | 209 | # play text in notebook 210 | Audio(audio_array, rate=SAMPLE_RATE) 211 | ``` 212 | 213 | [hamburger.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/578d7bbe-cda9-483e-898c-29646edc8f2e) 214 | 215 |
216 | 217 |
218 |

🌎 Foreign Language

219 |
220 | This VALL-E X implementation also supports Chinese and Japanese. All three languages have equally awesome performance! 221 |
222 | 223 | ```python 224 | 225 | text_prompt = """ 226 | チュソクは私のお気に入りの祭りです。 私は数日間休んで、友人や家族との時間を過ごすことができます。 227 | """ 228 | audio_array = generate_audio(text_prompt) 229 | ``` 230 | 231 | [vallex_japanese.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/ee57a688-3e83-4be5-b0fe-019d16eec51c) 232 | 233 | *Note: VALL-E X controls accent perfectly even when synthesizing code-switch text. However, you need to manually denote language of respective sentences (since our g2p tool is rule-base)* 234 | ```python 235 | text_prompt = """ 236 | [EN]The Thirty Years' War was a devastating conflict that had a profound impact on Europe.[EN] 237 | [ZH]这是历史的开始。 如果您想听更多,请继续。[ZH] 238 | """ 239 | audio_array = generate_audio(text_prompt, language='mix') 240 | ``` 241 | 242 | [vallex_codeswitch.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/d8667abf-bd08-499f-a383-a861d852f98a) 243 | 244 |
245 | 246 |
247 |

📼 Voice Presets

248 | 249 | VALL-E X provides tens of speaker voices which you can directly use for inference! Browse all voices in the [code](/presets). 250 | 251 | > VALL-E X tries to match the tone, pitch, emotion and prosody of a given preset. The model also attempts to preserve music, ambient noise, etc. 252 | 253 | ```python 254 | text_prompt = """ 255 | I am an innocent boy with a smoky voice. It is a great honor for me to speak at the United Nations today. 256 | """ 257 | audio_array = generate_audio(text_prompt, prompt="dingzhen") 258 | ``` 259 | 260 | [smoky.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/d3f55732-b1cd-420f-87d6-eab60db14dc5) 261 | 262 |
263 | 264 |
265 |

🎙Voice Cloning

266 | 267 | VALL-E X supports voice cloning! You can make a voice prompt with any person, character or even your own voice, and use it like other voice presets.
268 | To make a voice prompt, you need to provide a speech of 3~10 seconds long, as well as the transcript of the speech. 269 | You can also leave the transcript blank to let the [Whisper](https://github.com/openai/whisper) model to generate the transcript. 270 | > VALL-E X tries to match the tone, pitch, emotion and prosody of a given prompt. The model also attempts to preserve music, ambient noise, etc. 271 | 272 | ```python 273 | from utils.prompt_making import make_prompt 274 | 275 | ### Use given transcript 276 | make_prompt(name="paimon", audio_prompt_path="paimon_prompt.wav", 277 | transcript="Just, what was that? Paimon thought we were gonna get eaten.") 278 | 279 | ### Alternatively, use whisper 280 | make_prompt(name="paimon", audio_prompt_path="paimon_prompt.wav") 281 | ``` 282 | Now let's try out the prompt we've just made! 283 | ```python 284 | from utils.generation import SAMPLE_RATE, generate_audio, preload_models 285 | from scipy.io.wavfile import write as write_wav 286 | 287 | # download and load all models 288 | preload_models() 289 | 290 | text_prompt = """ 291 | Hey, Traveler, Listen to this, This machine has taken my voice, and now it can talk just like me! 292 | """ 293 | audio_array = generate_audio(text_prompt, prompt="paimon") 294 | 295 | write_wav("paimon_cloned.wav", SAMPLE_RATE, audio_array) 296 | 297 | ``` 298 | 299 | [paimon_prompt.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/e7922859-9d12-4e2a-8651-e156e4280311) 300 | 301 | 302 | [paimon_cloned.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/60d3b7e9-5ead-4024-b499-a897ce5f3d5e) 303 | 304 | 305 |
306 | 307 | 308 |
309 |

🎢User Interface

310 | 311 | Not comfortable with codes? No problem! We've also created a user-friendly graphical interface for VALL-E X. It allows you to interact with the model effortlessly, making voice cloning and multilingual speech synthesis a breeze. 312 |
313 | You can launch the UI by the following command: 314 | ```commandline 315 | python -X utf8 launch-ui.py 316 | ``` 317 |
318 | 319 | ## 🛠️ Hardware and Inference Speed 320 | 321 | VALL-E X works well on both CPU and GPU (`pytorch 2.0+`, CUDA 11.7 and CUDA 12.0). 322 | 323 | A GPU VRAM of 6GB is enough for running VALL-E X without offloading. 324 | 325 | ## ⚙️ Details 326 | 327 | VALL-E X is similar to [Bark](https://github.com/suno-ai/bark), [VALL-E](https://arxiv.org/abs/2301.02111) and [AudioLM](https://arxiv.org/abs/2209.03143), which generates audio in GPT-style by predicting audio tokens quantized by [EnCodec](https://github.com/facebookresearch/encodec). 328 |
329 | Comparing to [Bark](https://github.com/suno-ai/bark): 330 | - ✔ **Light-weighted**: 3️⃣ ✖ smaller, 331 | - ✔ **Efficient**: 4️⃣ ✖ faster, 332 | - ✔ **Better quality on Chinese & Japanese** 333 | - ✔ **Cross-lingual speech without foreign accent** 334 | - ✔ **Easy voice-cloning** 335 | - ❌ **Less languages** 336 | - ❌ **No special tokens for music / sound effects** 337 | 338 | ### Supported Languages 339 | 340 | | Language | Status | 341 | | --- | :---: | 342 | | English (en) | ✅ | 343 | | Japanese (ja) | ✅ | 344 | | Chinese, simplified (zh) | ✅ | 345 | 346 | ## ❓ FAQ 347 | 348 | #### Where can I download the model checkpoint? 349 | * We use `wget` to download the model to directory `./checkpoints/` when you run the program for the first time. 350 | * If the download fails on the first run, please manually download from [this link](https://huggingface.co/Plachta/VALL-E-X/resolve/main/vallex-checkpoint.pt), and put the file under directory `./checkpoints/`. 351 | 352 | #### How much VRAM do I need? 353 | * 6GB GPU VRAM - Almost all NVIDIA GPUs satisfy the requirement. 354 | 355 | #### Why the model fails to generate long text? 356 | * Transformer's computation complexity increases quadratically while the sequence length increases. Hence, all training 357 | are kept under 22 seconds. Please make sure the total length of audio prompt and generated audio is less than 22 seconds 358 | to ensure acceptable performance. 359 | 360 | 361 | #### MORE TO BE ADDED... 362 | 363 | ## 🧠 TODO 364 | - [x] Add Chinese README 365 | - [x] Long text generation 366 | - [x] Replace Encodec decoder with Vocos decoder 367 | - [ ] Fine-tuning for better voice adaptation 368 | - [ ] `.bat` scripts for non-python users 369 | - [ ] To be added... 370 | 371 | ## 🙏 Appreciation 372 | - [VALL-E X paper](https://arxiv.org/pdf/2303.03926) for the brilliant idea 373 | - [lifeiteng's vall-e](https://github.com/lifeiteng/vall-e) for related training code 374 | - [bark](https://github.com/suno-ai/bark) for the amazing pioneering work in neuro-codec TTS model 375 | 376 | ## ⭐️ Show Your Support 377 | 378 | If you find VALL-E X interesting and useful, give us a star on GitHub! ⭐️ It encourages us to keep improving the model and adding exciting features. 379 | 380 | ## 📜 License 381 | 382 | VALL-E X is licensed under the [MIT License](./LICENSE). 383 | 384 | --- 385 | 386 | Have questions or need assistance? Feel free to [open an issue](https://github.com/Plachtaa/VALL-E-X/issues/new) or join our [Discord](https://discord.gg/qCBRmAnTxg) 387 | 388 | Happy voice cloning! 
🎤 389 |
--------------------------------------------------------------------------------
/customs/make_custom_dataset.py:
--------------------------------------------------------------------------------
1 | import h5py
2 | import glob
3 | import torch
4 | import numpy as np
5 | import os
6 | import torchaudio
7 | import soundfile as sf
8 | from utils.g2p.symbols import symbols
9 | from utils.g2p import PhonemeBpeTokenizer
10 | from utils.prompt_making import make_prompt, make_transcript
11 | from data.collation import get_text_token_collater
12 | from data.dataset import create_dataloader
13 | 
14 | # Mappings from symbol to numeric ID and vice versa:
15 | _symbol_to_id = {s: i for i, s in enumerate(symbols)}
16 | _id_to_symbol = {i: s for i, s in enumerate(symbols)}
17 | from data.tokenizer import (
18 |     AudioTokenizer,
19 |     tokenize_audio,
20 | )
21 | 
22 | tokenizer_path = "./utils/g2p/bpe_69.json"
23 | tokenizer = PhonemeBpeTokenizer(tokenizer_path)
24 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
25 | 
26 | def make_prompts(name, audio_prompt_path, transcript=None):
27 |     text_tokenizer = PhonemeBpeTokenizer(tokenizer_path="./utils/g2p/bpe_69.json")
28 |     text_collater = get_text_token_collater()
29 |     codec = AudioTokenizer(device)
30 |     wav_pr, sr = torchaudio.load(audio_prompt_path)
31 |     # check length
32 |     if wav_pr.size(-1) / sr > 15:
33 |         raise ValueError(f"Prompt too long, expected length below 15 seconds, got {wav_pr.size(-1) / sr} seconds.")
34 |     if wav_pr.size(0) == 2:
35 |         wav_pr = wav_pr.mean(0, keepdim=True)
36 |     text_pr, lang_pr = make_transcript(name, wav_pr, sr, transcript)
37 | 
38 |     # tokenize audio
39 |     encoded_frames = tokenize_audio(codec, (wav_pr, sr))
40 |     audio_tokens = encoded_frames[0][0].transpose(2, 1).cpu().numpy()
41 | 
42 |     # tokenize text
43 |     phonemes, langs = text_tokenizer.tokenize(text=f"{text_pr}".strip())
44 |     text_tokens, enroll_x_lens = text_collater(
45 |         [
46 |             phonemes
47 |         ]
48 |     )
49 | 
50 |     return audio_tokens, text_tokens, langs, text_pr
51 | 
52 | def create_dataset(data_dir, dataloader_process_only):
53 |     if dataloader_process_only:
54 |         h5_output_path = f"{data_dir}/audio_sum.hdf5"
55 |         ann_output_path = f"{data_dir}/audio_ann_sum.txt"
56 |         #audio_folder = os.path.join(data_dir, 'audio')
57 |         audio_paths = glob.glob(f"{data_dir}/*.wav") # Change this to match your audio file extension
58 | 
59 |         # Create or open an HDF5 file
60 |         with h5py.File(h5_output_path, 'w') as h5_file:
61 |             # Loop through each audio and text file, assuming they have the same stem
62 |             for audio_path in audio_paths:
63 |                 stem = os.path.splitext(os.path.basename(audio_path))[0]
64 |                 audio_tokens, text_tokens, langs, text = make_prompts(name=stem, audio_prompt_path=audio_path)
65 | 
66 |                 text_tokens = text_tokens.squeeze(0)
67 |                 # Create a group for each stem
68 |                 grp = h5_file.create_group(stem)
69 |                 # Add audio and text tokens as datasets to the group
70 |                 grp.create_dataset('audio', data=audio_tokens)
71 |                 #grp.create_dataset('text', data=text_tokens)
72 | 
73 |                 with open(ann_output_path, 'a', encoding='utf-8') as ann_file:
74 |                     try:
75 |                         audio, sample_rate = sf.read(audio_path)
76 |                         duration = len(audio) / sample_rate
77 |                         ann_file.write(f'{stem}|{duration}|{langs[0]}|{text}\n') # append a newline after each utterance
78 |                         print(f"Successfully wrote to {ann_output_path}")
79 |                     except Exception as e:
80 |                         print(f"An error occurred: {e}")
81 |     else:
82 |         dataloader = create_dataloader(data_dir=data_dir)
83 |         return dataloader
--------------------------------------------------------------------------------
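A minimal usage sketch for the file above, tying it to the training workflow described in the fork README. The directory names and the loop are illustrative assumptions; only the `create_dataset` call and the `--train_dir`/`--valid_dir` arguments come from this repository's documentation, and `train.py` may require additional arguments that are not shown here.

```python
# Hypothetical preprocessing step (paths are placeholders): build audio_sum.hdf5 and
# audio_ann_sum.txt for a training set and a validation set. Each directory must already
# contain bpe_69.json alongside the .wav files, as shown in the fork README's layout.
from customs.make_custom_dataset import create_dataset

for data_dir in ("./custom_data/train", "./custom_data/valid"):
    create_dataset(data_dir, dataloader_process_only=True)
```

Training would then be launched from the command line, passing the two directories as the fork README describes:

```commandline
python train.py --train_dir ./custom_data/train --valid_dir ./custom_data/valid
```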
/customs/ph.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/customs/ph.txt -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | # from .datamodule import * 2 | # from .tokenizer import * 3 | from .collation import * 4 | -------------------------------------------------------------------------------- /data/collation.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import List, Tuple 3 | 4 | import numpy as np 5 | import torch 6 | 7 | from utils import SymbolTable 8 | 9 | 10 | class TextTokenCollater: 11 | """Collate list of text tokens 12 | 13 | Map sentences to integers. Sentences are padded to equal length. 14 | Beginning and end-of-sequence symbols can be added. 15 | 16 | Example: 17 | >>> token_collater = TextTokenCollater(text_tokens) 18 | >>> tokens_batch, tokens_lens = token_collater(text) 19 | 20 | Returns: 21 | tokens_batch: IntTensor of shape (B, L) 22 | B: batch dimension, number of input sentences 23 | L: length of the longest sentence 24 | tokens_lens: IntTensor of shape (B,) 25 | Length of each sentence after adding and 26 | but before padding. 27 | """ 28 | 29 | def __init__( 30 | self, 31 | text_tokens: List[str], 32 | add_eos: bool = True, 33 | add_bos: bool = True, 34 | pad_symbol: str = "", 35 | bos_symbol: str = "", 36 | eos_symbol: str = "", 37 | ): 38 | self.pad_symbol = pad_symbol 39 | 40 | self.add_eos = add_eos 41 | self.add_bos = add_bos 42 | 43 | self.bos_symbol = bos_symbol 44 | self.eos_symbol = eos_symbol 45 | 46 | unique_tokens = ( 47 | [pad_symbol] 48 | + ([bos_symbol] if add_bos else []) 49 | + ([eos_symbol] if add_eos else []) 50 | + sorted(text_tokens) 51 | ) 52 | 53 | self.token2idx = {token: idx for idx, token in enumerate(unique_tokens)} 54 | self.idx2token = [token for token in unique_tokens] 55 | 56 | def index( 57 | self, tokens_list: List[str] 58 | ) -> Tuple[torch.Tensor, torch.Tensor]: 59 | seqs, seq_lens = [], [] 60 | for tokens in tokens_list: 61 | assert ( 62 | all([True if s in self.token2idx else False for s in tokens]) 63 | is True 64 | ) 65 | seq = ( 66 | ([self.bos_symbol] if self.add_bos else []) 67 | + list(tokens) 68 | + ([self.eos_symbol] if self.add_eos else []) 69 | ) 70 | seqs.append(seq) 71 | seq_lens.append(len(seq)) 72 | 73 | max_len = max(seq_lens) 74 | for k, (seq, seq_len) in enumerate(zip(seqs, seq_lens)): 75 | seq.extend([self.pad_symbol] * (max_len - seq_len)) 76 | 77 | tokens = torch.from_numpy( 78 | np.array( 79 | [[self.token2idx[token] for token in seq] for seq in seqs], 80 | dtype=np.int64, 81 | ) 82 | ) 83 | tokens_lens = torch.IntTensor(seq_lens) 84 | 85 | return tokens, tokens_lens 86 | 87 | def __call__(self, texts: List[str]) -> Tuple[torch.Tensor, torch.Tensor]: 88 | tokens_seqs = [[p for p in text] for text in texts] 89 | max_len = len(max(tokens_seqs, key=len)) 90 | 91 | seqs = [ 92 | ([self.bos_symbol] if self.add_bos else []) 93 | + list(seq) 94 | + ([self.eos_symbol] if self.add_eos else []) 95 | + [self.pad_symbol] * (max_len - len(seq)) 96 | for seq in tokens_seqs 97 | ] 98 | 99 | tokens_batch = torch.from_numpy( 100 | np.array( 101 | [seq for seq in seqs], 102 | dtype=np.int64, 103 | ) 104 | ) 105 | 106 | tokens_lens = torch.IntTensor( 107 | [ 
108 | len(seq) + int(self.add_eos) + int(self.add_bos) 109 | for seq in tokens_seqs 110 | ] 111 | ) 112 | 113 | return tokens_batch, tokens_lens 114 | 115 | 116 | def get_text_token_collater() -> TextTokenCollater: 117 | collater = TextTokenCollater( 118 | ['0'], add_bos=False, add_eos=False 119 | ) 120 | return collater 121 | -------------------------------------------------------------------------------- /data/dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 (authors: Feiteng Li) 2 | # 3 | # See ../../../../LICENSE for clarification regarding multiple authors 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """ 18 | modified from lhoste.dataset.speech_synthesis.py 19 | """ 20 | 21 | import torch 22 | import math 23 | import h5py 24 | from tokenizers import Tokenizer 25 | from typing import Union, List 26 | import numpy as np 27 | from tqdm import tqdm 28 | from utils.g2p import PhonemeBpeTokenizer 29 | from data.collation import get_text_token_collater 30 | text_collater = get_text_token_collater() 31 | 32 | _pad = '_' 33 | _punctuation = ',.!?-~…' 34 | _letters = 'NQabdefghijklmnopstuvwxyzɑæʃʑçɯɪɔɛɹðəɫɥɸʊɾʒθβŋɦ⁼ʰ`^#*=ˈˌ→↓↑ ' 35 | symbols = [_pad] + list(_punctuation) + list(_letters) 36 | 37 | language_dict = { 38 | 'en': 0, 39 | 'zh': 1, 40 | 'ja': 2, 41 | } 42 | def seq2phone(tokens: Union[List, np.ndarray]): 43 | """ 44 | Convert tokenized phoneme ID sequence back to phoneme string 45 | :param tokens: phoneme tokens 46 | :return: recovered phoneme sequence 47 | """ 48 | phones = "".join([symbols[i] for i in tokens]) 49 | return phones 50 | 51 | class DynamicBatchSampler(torch.utils.data.Sampler): 52 | def __init__(self, sampler, num_tokens_fn, num_buckets=100, min_size=0, max_size=1000, 53 | max_tokens=None, max_sentences=None, drop_last=False): 54 | """ 55 | 56 | :param sampler: 57 | :param num_tokens_fn: 根据idx返回样本的长度的函数 58 | :param num_buckets: 利用桶原理将相似长度的样本放在一个batchsize中,桶的数量 59 | :param min_size: 最小长度的样本, 小于这个值的样本会被过滤掉。 依据这个值来创建样桶 60 | :param max_size: 最大长度的样本 61 | :param max_sentences: batch_size, 但是这里可以通过max_sentences 和 max_tokens 共同控制最终的大小 62 | """ 63 | super(DynamicBatchSampler, self).__init__(sampler) 64 | self.sampler = sampler 65 | self.num_tokens_fn = num_tokens_fn 66 | self.num_buckets = num_buckets 67 | 68 | self.min_size = min_size 69 | self.max_size = max_size 70 | 71 | assert max_size <= max_tokens, "max_size should be smaller than max tokens" 72 | assert max_tokens is not None or max_sentences is not None, \ 73 | "max_tokens and max_sentences should not be null at the same time, please specify one parameter at least" 74 | self.max_tokens = max_tokens if max_tokens is not None else float('Inf') 75 | self.max_sentences = max_sentences if max_sentences is not None else float('Inf') 76 | self.drop_last = drop_last 77 | 78 | def set_epoch(self, epoch): 79 | self.sampler.set_epoch(epoch) 80 | def is_batch_full(self, num_tokens, batch): 81 | if len(batch) == 0: 
82 | return False 83 | if len(batch) == self.max_sentences: 84 | return True 85 | if num_tokens > self.max_tokens: 86 | return True 87 | return False 88 | 89 | def __iter__(self): 90 | buckets = [[] for _ in range(self.num_buckets)] 91 | sample_len = [0] * self.num_buckets 92 | 93 | for idx in self.sampler: 94 | idx_length = self.num_tokens_fn(idx) 95 | if not (self.min_size <= idx_length <= self.max_size): 96 | print("sentence at index {} of size {} exceeds max_tokens, the sentence is ignored".format(idx, idx_length)) 97 | continue 98 | 99 | index_buckets = math.floor((idx_length - self.min_size) / (self.max_size - self.min_size + 1) 100 | * self.num_buckets) 101 | sample_len[index_buckets] = max(sample_len[index_buckets], idx_length) 102 | 103 | num_tokens = (len(buckets[index_buckets]) + 1) * sample_len[index_buckets] 104 | if self.is_batch_full(num_tokens, buckets[index_buckets]): 105 | # yield this batch 106 | yield buckets[index_buckets] 107 | buckets[index_buckets] = [] 108 | sample_len[index_buckets] = 0 109 | 110 | buckets[index_buckets].append(idx) 111 | 112 | # process left-over 113 | leftover_batch = [] 114 | leftover_sample_len = 0 115 | leftover = [idx for bucket in buckets for idx in bucket] 116 | for idx in leftover: 117 | idx_length = self.num_tokens_fn(idx) 118 | leftover_sample_len = max(leftover_sample_len, idx_length) 119 | num_tokens = (len(leftover_batch) + 1) * leftover_sample_len 120 | if self.is_batch_full(num_tokens, leftover_batch): 121 | yield leftover_batch 122 | leftover_batch = [] 123 | leftover_sample_len = 0 124 | leftover_batch.append(idx) 125 | 126 | if len(leftover_batch) > 0 and not self.drop_last: 127 | yield leftover_batch 128 | 129 | def __len__(self): 130 | # we do not know the exactly batch size, so do not call len(dataloader) 131 | pass 132 | 133 | 134 | class AudioDataset(torch.utils.data.Dataset): 135 | def __init__(self, h5_path, ann_path, tokenizer_path): 136 | self.h5_path = h5_path 137 | with open(ann_path, 'r', encoding='utf-8') as f: 138 | lines = f.readlines() 139 | ls = [l.split("|") for l in lines] 140 | ls_T = list(zip(*ls)) 141 | #del ls_T[-1] 142 | self.h5_paths, self.durations, self.langs, self.texts = \ 143 | list(ls_T[0]), list(ls_T[1]), list(ls_T[2]), list(ls_T[3]) 144 | self.durations = [float(dur) for dur in self.durations] 145 | self.tokenizer = PhonemeBpeTokenizer(tokenizer_path) 146 | self._archive = None 147 | 148 | def __len__(self): 149 | return len(self.h5_paths) 150 | 151 | def get_dur(self, idx): 152 | return self.durations[idx] 153 | 154 | @property 155 | def archive(self): 156 | if self._archive is None: # lazy loading here! 
157 | self._archive = h5py.File(self.h5_path, "r") 158 | return self._archive 159 | def __getitem__(self, idx): 160 | archive = self.archive 161 | h5_path = self.h5_paths[idx] 162 | sub = archive[h5_path] 163 | audio_tokens = sub['audio'][()] 164 | #phone_tokens = sub['text'][()] 165 | dur = self.durations[idx] 166 | lang = self.langs[idx] 167 | text = self.texts[idx] 168 | # tokenization should be done within dataloader 169 | #phones = seq2phone(phone_tokens) 170 | #phones = phones.replace(" ", "_") 171 | phonemes, langs = self.tokenizer.tokenize(text=f"{text}".strip()) 172 | cptpho_tokens, enroll_x_lens = text_collater([phonemes]) 173 | cptpho_tokens = cptpho_tokens.squeeze(0) 174 | text_token_lens = enroll_x_lens[0] 175 | ''' 176 | if not len(phones): 177 | cptpho_tokens = self.tokenizer.encode(text).ids 178 | else: 179 | cptpho_tokens = self.tokenizer.encode(phones).ids 180 | ''' 181 | assert len(cptpho_tokens) 182 | return { 183 | 'utt_id': h5_path, 184 | 'text': text, 185 | 'audio': None, 186 | 'audio_lens': None, 187 | 'audio_features': audio_tokens, 188 | 'audio_features_lens': audio_tokens.shape[1], 189 | 'text_tokens': np.array(cptpho_tokens), 190 | 'text_tokens_lens': text_token_lens, 191 | 'language': language_dict[lang], 192 | } 193 | 194 | def collate(batch): 195 | utt_id_s = [b['utt_id'] for b in batch] 196 | text_s = [b['text'] for b in batch] 197 | 198 | audio_s = [b['audio'] for b in batch] 199 | audio_lens_s = [b['audio_lens'] for b in batch] 200 | 201 | audio_features_lens_s = [b['audio_features_lens'] for b in batch] 202 | # create an empty tensor with maximum audio feature length 203 | audio_features_s = torch.zeros([len(batch), max(audio_features_lens_s), 8], dtype=torch.int64) - 1 # audio pad with -1 204 | 205 | text_tokens_lens_s = [b['text_tokens_lens'] for b in batch] 206 | # create an empty tensor with maximum text tokens length 207 | text_tokens_s = torch.zeros([len(batch), max(text_tokens_lens_s)], dtype=torch.int64) + 3 # [PAD] token id 3 208 | 209 | language_s = [b['language'] for b in batch] 210 | 211 | for i, b in enumerate(batch): 212 | audio_features = b['audio_features'] 213 | audio_features_lens = b['audio_features_lens'] 214 | audio_features_s[i, :audio_features_lens, :] = torch.LongTensor(audio_features) 215 | 216 | text_tokens = b['text_tokens'] 217 | text_tokens_lens = b['text_tokens_lens'] 218 | text_tokens_s[i, :text_tokens_lens] = torch.LongTensor(text_tokens) 219 | 220 | batch = { 221 | 'utt_id': utt_id_s, 222 | 'text': text_s, 223 | 'audio': audio_s, 224 | 'audio_lens': audio_lens_s, 225 | 'audio_features': audio_features_s, 226 | 'audio_features_lens': torch.LongTensor(np.array(audio_features_lens_s)), 227 | 'text_tokens': text_tokens_s, 228 | 'text_tokens_lens': torch.LongTensor(np.array(text_tokens_lens_s)), 229 | 'languages': torch.LongTensor(np.array(language_s)), 230 | } 231 | return batch 232 | 233 | def create_dataloader(data_dir="/root/valle/egs/mix", n_gpus=1, rank=0, num_workers=0, num_buckets=10, max_duration=120): 234 | train_dataset = AudioDataset(h5_path=f"{data_dir}/audio_sum.hdf5", 235 | ann_path=f"{data_dir}/audio_ann_sum.txt", 236 | tokenizer_path=f"{data_dir}/bpe_69.json") 237 | ran_sampler = torch.utils.data.distributed.DistributedSampler( 238 | train_dataset, 239 | num_replicas=n_gpus, 240 | rank=rank, 241 | shuffle=True, 242 | ) 243 | dynamic_sampler = DynamicBatchSampler(ran_sampler, train_dataset.get_dur, num_buckets=num_buckets, max_size=20, 244 | max_tokens=max_duration) 245 | 246 | 247 | train_loader = 
torch.utils.data.DataLoader(train_dataset, num_workers=num_workers, collate_fn=collate, 248 | batch_sampler=dynamic_sampler) 249 | 250 | return train_loader 251 | -------------------------------------------------------------------------------- /data/fbank.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 (authors: Feiteng Li) 2 | # 3 | # See ../../../../LICENSE for clarification regarding multiple authors 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | 18 | from dataclasses import asdict, dataclass 19 | from typing import Any, Dict, Optional, Union 20 | 21 | import numpy as np 22 | import torch 23 | # from lhotse.features.base import FeatureExtractor 24 | # from lhotse.utils import EPSILON, Seconds, compute_num_frames 25 | from librosa.filters import mel as librosa_mel_fn 26 | 27 | 28 | @dataclass 29 | class BigVGANFbankConfig: 30 | # Spectogram-related part 31 | # Note that frame_length and frame_shift will be converted to milliseconds before torchaudio/Kaldi sees them 32 | frame_length: Seconds = 1024 / 24000.0 33 | frame_shift: Seconds = 256 / 24000.0 34 | remove_dc_offset: bool = True 35 | round_to_power_of_two: bool = True 36 | 37 | # Fbank-related part 38 | low_freq: float = 0.0 39 | high_freq: float = 12000.0 40 | num_mel_bins: int = 100 41 | use_energy: bool = False 42 | 43 | def to_dict(self) -> Dict[str, Any]: 44 | return asdict(self) 45 | 46 | @staticmethod 47 | def from_dict(data: Dict[str, Any]) -> "BigVGANFbankConfig": 48 | return BigVGANFbankConfig(**data) 49 | 50 | 51 | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): 52 | return torch.log(torch.clamp(x, min=clip_val) * C) 53 | 54 | 55 | def spectral_normalize_torch(magnitudes): 56 | output = dynamic_range_compression_torch(magnitudes) 57 | return output 58 | 59 | 60 | # https://github.com/NVIDIA/BigVGAN 61 | # bigvgan_24khz_100band https://drive.google.com/drive/folders/1EpxX6AsxjCbbk0mmAhE0td6eYiABr8Oz 62 | class BigVGANFbank(FeatureExtractor): 63 | name = "fbank" 64 | config_type = BigVGANFbankConfig 65 | 66 | def __init__(self, config: Optional[Any] = None): 67 | super(BigVGANFbank, self).__init__(config) 68 | sampling_rate = 24000 69 | self.mel_basis = torch.from_numpy( 70 | librosa_mel_fn( 71 | sampling_rate, 72 | 1024, 73 | self.config.num_mel_bins, 74 | self.config.low_freq, 75 | self.config.high_freq, 76 | ).astype(np.float32) 77 | ) 78 | self.hann_window = torch.hann_window(1024) 79 | 80 | def _feature_fn(self, samples, **kwargs): 81 | win_length, n_fft = 1024, 1024 82 | hop_size = 256 83 | if True: 84 | sampling_rate = 24000 85 | duration = round(samples.shape[-1] / sampling_rate, ndigits=12) 86 | expected_num_frames = compute_num_frames( 87 | duration=duration, 88 | frame_shift=self.frame_shift, 89 | sampling_rate=sampling_rate, 90 | ) 91 | pad_size = ( 92 | (expected_num_frames - 1) * hop_size 93 | + win_length 94 | - samples.shape[-1] 95 | ) 96 | assert pad_size >= 0 97 | 98 | y = 
torch.nn.functional.pad( 99 | samples, 100 | (0, pad_size), 101 | mode="constant", 102 | ) 103 | else: 104 | y = torch.nn.functional.pad( 105 | samples, 106 | (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), 107 | mode="reflect", 108 | ) 109 | 110 | y = y.squeeze(1) 111 | 112 | # complex tensor as default, then use view_as_real for future pytorch compatibility 113 | spec = torch.stft( 114 | y, 115 | n_fft, 116 | hop_length=hop_size, 117 | win_length=win_length, 118 | window=self.hann_window, 119 | center=False, 120 | pad_mode="reflect", 121 | normalized=False, 122 | onesided=True, 123 | return_complex=True, 124 | ) 125 | spec = torch.view_as_real(spec) 126 | spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9)) 127 | 128 | spec = torch.matmul(self.mel_basis, spec) 129 | spec = spectral_normalize_torch(spec) 130 | 131 | return spec.transpose(2, 1).squeeze(0) 132 | 133 | def extract( 134 | self, samples: Union[np.ndarray, torch.Tensor], sampling_rate: int 135 | ) -> np.ndarray: 136 | assert sampling_rate == 24000 137 | params = asdict(self.config) 138 | params.update({"sample_frequency": sampling_rate, "snip_edges": False}) 139 | params["frame_shift"] *= 1000.0 140 | params["frame_length"] *= 1000.0 141 | if not isinstance(samples, torch.Tensor): 142 | samples = torch.from_numpy(samples) 143 | # Torchaudio Kaldi feature extractors expect the channel dimension to be first. 144 | if len(samples.shape) == 1: 145 | samples = samples.unsqueeze(0) 146 | features = self._feature_fn(samples, **params).to(torch.float32) 147 | return features.numpy() 148 | 149 | @property 150 | def frame_shift(self) -> Seconds: 151 | return self.config.frame_shift 152 | 153 | def feature_dim(self, sampling_rate: int) -> int: 154 | return self.config.num_mel_bins 155 | 156 | @staticmethod 157 | def mix( 158 | features_a: np.ndarray, 159 | features_b: np.ndarray, 160 | energy_scaling_factor_b: float, 161 | ) -> np.ndarray: 162 | return np.log( 163 | np.maximum( 164 | # protection against log(0); max with EPSILON is adequate since these are energies (always >= 0) 165 | EPSILON, 166 | np.exp(features_a) 167 | + energy_scaling_factor_b * np.exp(features_b), 168 | ) 169 | ) 170 | 171 | @staticmethod 172 | def compute_energy(features: np.ndarray) -> float: 173 | return float(np.sum(np.exp(features))) 174 | 175 | 176 | def get_fbank_extractor() -> BigVGANFbank: 177 | return BigVGANFbank(BigVGANFbankConfig()) 178 | 179 | 180 | if __name__ == "__main__": 181 | extractor = BigVGANFbank(BigVGANFbankConfig()) 182 | 183 | samples = torch.from_numpy(np.random.random([1000]).astype(np.float32)) 184 | samples = torch.clip(samples, -1.0, 1.0) 185 | fbank = extractor.extract(samples, 24000.0) 186 | print(f"fbank {fbank.shape}") 187 | 188 | from scipy.io.wavfile import read 189 | 190 | MAX_WAV_VALUE = 32768.0 191 | 192 | sampling_rate, samples = read( 193 | "egs/libritts/prompts/5639_40744_000000_000002.wav" 194 | ) 195 | print(f"samples: [{samples.min()}, {samples.max()}]") 196 | fbank = extractor.extract(samples.astype(np.float32) / MAX_WAV_VALUE, 24000) 197 | print(f"fbank {fbank.shape}") 198 | 199 | import matplotlib.pyplot as plt 200 | 201 | _ = plt.figure(figsize=(18, 10)) 202 | plt.imshow( 203 | X=fbank.transpose(1, 0), 204 | cmap=plt.get_cmap("jet"), 205 | aspect="auto", 206 | interpolation="nearest", 207 | ) 208 | plt.gca().invert_yaxis() 209 | plt.savefig("egs/libritts/prompts/5639_40744_000000_000002.png") 210 | plt.close() 211 | 212 | print("fbank test PASS!") 213 | 
-------------------------------------------------------------------------------- /data/input_strategies.py: -------------------------------------------------------------------------------- 1 | import random 2 | from collections import defaultdict 3 | from concurrent.futures import ThreadPoolExecutor 4 | from typing import Tuple, Type 5 | 6 | # from lhotse import CutSet 7 | # from lhotse.dataset.collation import collate_features 8 | # from lhotse.dataset.input_strategies import ( 9 | # ExecutorType, 10 | # PrecomputedFeatures, 11 | # _get_executor, 12 | # ) 13 | # from lhotse.utils import fastcopy 14 | 15 | 16 | class PromptedFeatures: 17 | def __init__(self, prompts, features): 18 | self.prompts = prompts 19 | self.features = features 20 | 21 | def to(self, device): 22 | return PromptedFeatures( 23 | self.prompts.to(device), self.features.to(device) 24 | ) 25 | 26 | def sum(self): 27 | return self.features.sum() 28 | 29 | @property 30 | def ndim(self): 31 | return self.features.ndim 32 | 33 | @property 34 | def data(self): 35 | return (self.prompts, self.features) 36 | 37 | 38 | # class PromptedPrecomputedFeatures(PrecomputedFeatures): 39 | # """ 40 | # :class:`InputStrategy` that reads pre-computed features, whose manifests 41 | # are attached to cuts, from disk. 42 | # 43 | # It automatically pads the feature matrices with pre or post feature. 44 | # 45 | # .. automethod:: __call__ 46 | # """ 47 | # 48 | # def __init__( 49 | # self, 50 | # dataset: str, 51 | # cuts: CutSet, 52 | # num_workers: int = 0, 53 | # executor_type: Type[ExecutorType] = ThreadPoolExecutor, 54 | # ) -> None: 55 | # super(PromptedPrecomputedFeatures, self).__init__( 56 | # num_workers, executor_type 57 | # ) 58 | # 59 | # self.utt2neighbors = defaultdict(lambda: []) 60 | # 61 | # if dataset.lower() == "libritts": 62 | # # 909_131041_000013_000002 63 | # # 909_131041_000013_000003 64 | # speaker2utts = defaultdict(lambda: []) 65 | # 66 | # utt2cut = {} 67 | # for cut in cuts: 68 | # speaker = cut.supervisions[0].speaker 69 | # speaker2utts[speaker].append(cut.id) 70 | # utt2cut[cut.id] = cut 71 | # 72 | # for spk in speaker2utts: 73 | # uttids = sorted(speaker2utts[spk]) 74 | # # Using the property of sorted keys to find previous utterance 75 | # # The keys has structure speaker_book_x_y e.g. 
1089_134691_000004_000001 76 | # if len(uttids) == 1: 77 | # self.utt2neighbors[uttids[0]].append(utt2cut[uttids[0]]) 78 | # continue 79 | # 80 | # utt2prevutt = dict(zip(uttids, [uttids[1]] + uttids[:-1])) 81 | # utt2postutt = dict(zip(uttids[:-1], uttids[1:])) 82 | # 83 | # for utt in utt2prevutt: 84 | # self.utt2neighbors[utt].append(utt2cut[utt2prevutt[utt]]) 85 | # 86 | # for utt in utt2postutt: 87 | # self.utt2neighbors[utt].append(utt2cut[utt2postutt[utt]]) 88 | # elif dataset.lower() == "ljspeech": 89 | # utt2cut = {} 90 | # uttids = [] 91 | # for cut in cuts: 92 | # uttids.append(cut.id) 93 | # utt2cut[cut.id] = cut 94 | # 95 | # if len(uttids) == 1: 96 | # self.utt2neighbors[uttids[0]].append(utt2cut[uttids[0]]) 97 | # else: 98 | # # Using the property of sorted keys to find previous utterance 99 | # # The keys has structure: LJ001-0010 100 | # utt2prevutt = dict(zip(uttids, [uttids[1]] + uttids[:-1])) 101 | # utt2postutt = dict(zip(uttids[:-1], uttids[1:])) 102 | # 103 | # for utt in utt2postutt: 104 | # postutt = utt2postutt[utt] 105 | # if utt[:5] == postutt[:5]: 106 | # self.utt2neighbors[utt].append(utt2cut[postutt]) 107 | # 108 | # for utt in utt2prevutt: 109 | # prevutt = utt2prevutt[utt] 110 | # if utt[:5] == prevutt[:5] or not self.utt2neighbors[utt]: 111 | # self.utt2neighbors[utt].append(utt2cut[prevutt]) 112 | # else: 113 | # raise ValueError 114 | # 115 | # def __call__( 116 | # self, cuts: CutSet 117 | # ) -> Tuple[PromptedFeatures, PromptedFeatures]: 118 | # """ 119 | # Reads the pre-computed features from disk/other storage. 120 | # The returned shape is``(B, T, F) => (batch_size, num_frames, num_features)``. 121 | # 122 | # :return: a tensor with collated features, and a tensor of ``num_frames`` of each cut before padding. 123 | # """ 124 | # features, features_lens = collate_features( 125 | # cuts, 126 | # executor=_get_executor( 127 | # self.num_workers, executor_type=self._executor_type 128 | # ), 129 | # ) 130 | # 131 | # prompts_cuts = [] 132 | # for k, cut in enumerate(cuts): 133 | # prompts_cut = random.choice(self.utt2neighbors[cut.id]) 134 | # prompts_cuts.append(fastcopy(prompts_cut, id=f"{cut.id}-{str(k)}")) 135 | # 136 | # mini_duration = min([cut.duration for cut in prompts_cuts] + [3.0]) 137 | # # prompts_cuts = CutSet.from_cuts(prompts_cuts).truncate( 138 | # # max_duration=mini_duration, 139 | # # offset_type="random", 140 | # # preserve_id=True, 141 | # # ) 142 | # prompts_cuts = CutSet( 143 | # cuts={k: cut for k, cut in enumerate(prompts_cuts)} 144 | # ).truncate( 145 | # max_duration=mini_duration, 146 | # offset_type="random", 147 | # preserve_id=False, 148 | # ) 149 | # 150 | # prompts, prompts_lens = collate_features( 151 | # prompts_cuts, 152 | # executor=_get_executor( 153 | # self.num_workers, executor_type=self._executor_type 154 | # ), 155 | # ) 156 | # 157 | # return PromptedFeatures(prompts, features), PromptedFeatures( 158 | # prompts_lens, features_lens 159 | # ) 160 | -------------------------------------------------------------------------------- /data/tokenizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2023 (authors: Feiteng Li) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import re 17 | from dataclasses import asdict, dataclass 18 | from typing import Any, Dict, List, Optional, Pattern, Union 19 | 20 | import numpy as np 21 | import torch 22 | import torchaudio 23 | from encodec import EncodecModel 24 | from encodec.utils import convert_audio 25 | from phonemizer.backend import EspeakBackend 26 | from phonemizer.backend.espeak.language_switch import LanguageSwitch 27 | from phonemizer.backend.espeak.words_mismatch import WordMismatch 28 | from phonemizer.punctuation import Punctuation 29 | from phonemizer.separator import Separator 30 | 31 | try: 32 | from pypinyin import Style, pinyin 33 | from pypinyin.style._utils import get_finals, get_initials 34 | except Exception: 35 | pass 36 | 37 | 38 | class PypinyinBackend: 39 | """PypinyinBackend for Chinese. Most codes is referenced from espnet. 40 | There are two types pinyin or initials_finals, one is 41 | just like "ni1 hao3", the other is like "n i1 h ao3". 42 | """ 43 | 44 | def __init__( 45 | self, 46 | backend="initials_finals", 47 | punctuation_marks: Union[str, Pattern] = Punctuation.default_marks(), 48 | ) -> None: 49 | self.backend = backend 50 | self.punctuation_marks = punctuation_marks 51 | 52 | def phonemize( 53 | self, text: List[str], separator: Separator, strip=True, njobs=1 54 | ) -> List[str]: 55 | assert isinstance(text, List) 56 | phonemized = [] 57 | for _text in text: 58 | _text = re.sub(" +", " ", _text.strip()) 59 | _text = _text.replace(" ", separator.word) 60 | phones = [] 61 | if self.backend == "pypinyin": 62 | for n, py in enumerate( 63 | pinyin( 64 | _text, style=Style.TONE3, neutral_tone_with_five=True 65 | ) 66 | ): 67 | if all([c in self.punctuation_marks for c in py[0]]): 68 | if len(phones): 69 | assert phones[-1] == separator.syllable 70 | phones.pop(-1) 71 | 72 | phones.extend(list(py[0])) 73 | else: 74 | phones.extend([py[0], separator.syllable]) 75 | elif self.backend == "pypinyin_initials_finals": 76 | for n, py in enumerate( 77 | pinyin( 78 | _text, style=Style.TONE3, neutral_tone_with_five=True 79 | ) 80 | ): 81 | if all([c in self.punctuation_marks for c in py[0]]): 82 | if len(phones): 83 | assert phones[-1] == separator.syllable 84 | phones.pop(-1) 85 | phones.extend(list(py[0])) 86 | else: 87 | if py[0][-1].isalnum(): 88 | initial = get_initials(py[0], strict=False) 89 | if py[0][-1].isdigit(): 90 | final = ( 91 | get_finals(py[0][:-1], strict=False) 92 | + py[0][-1] 93 | ) 94 | else: 95 | final = get_finals(py[0], strict=False) 96 | phones.extend( 97 | [ 98 | initial, 99 | separator.phone, 100 | final, 101 | separator.syllable, 102 | ] 103 | ) 104 | else: 105 | assert ValueError 106 | else: 107 | raise NotImplementedError 108 | phonemized.append( 109 | "".join(phones).rstrip(f"{separator.word}{separator.syllable}") 110 | ) 111 | return phonemized 112 | 113 | 114 | class TextTokenizer: 115 | """Phonemize Text.""" 116 | 117 | def __init__( 118 | self, 119 | language="en-us", 120 | backend="espeak", 121 | separator=Separator(word="_", syllable="-", phone="|"), 122 | preserve_punctuation=True, 123 | punctuation_marks: 
Union[str, Pattern] = Punctuation.default_marks(), 124 | with_stress: bool = False, 125 | tie: Union[bool, str] = False, 126 | language_switch: LanguageSwitch = "keep-flags", 127 | words_mismatch: WordMismatch = "ignore", 128 | ) -> None: 129 | if backend == "espeak": 130 | phonemizer = EspeakBackend( 131 | language, 132 | punctuation_marks=punctuation_marks, 133 | preserve_punctuation=preserve_punctuation, 134 | with_stress=with_stress, 135 | tie=tie, 136 | language_switch=language_switch, 137 | words_mismatch=words_mismatch, 138 | ) 139 | elif backend in ["pypinyin", "pypinyin_initials_finals"]: 140 | phonemizer = PypinyinBackend( 141 | backend=backend, 142 | punctuation_marks=punctuation_marks + separator.word, 143 | ) 144 | else: 145 | raise NotImplementedError(f"{backend}") 146 | 147 | self.backend = phonemizer 148 | self.separator = separator 149 | 150 | def to_list(self, phonemized: str) -> List[str]: 151 | fields = [] 152 | for word in phonemized.split(self.separator.word): 153 | # "ɐ m|iː|n?" ɹ|ɪ|z|ɜː|v; h|ɪ|z. 154 | pp = re.findall(r"\w+|[^\w\s]", word, re.UNICODE) 155 | fields.extend( 156 | [p for p in pp if p != self.separator.phone] 157 | + [self.separator.word] 158 | ) 159 | assert len("".join(fields[:-1])) == len(phonemized) - phonemized.count( 160 | self.separator.phone 161 | ) 162 | return fields[:-1] 163 | 164 | def __call__(self, text, strip=True) -> List[List[str]]: 165 | if isinstance(text, str): 166 | text = [text] 167 | 168 | phonemized = self.backend.phonemize( 169 | text, separator=self.separator, strip=strip, njobs=1 170 | ) 171 | return [self.to_list(p) for p in phonemized] 172 | 173 | 174 | def tokenize_text(tokenizer: TextTokenizer, text: str) -> List[str]: 175 | phonemes = tokenizer([text.strip()]) 176 | return phonemes[0] # k2symbols 177 | 178 | 179 | def remove_encodec_weight_norm(model): 180 | from encodec.modules import SConv1d 181 | from encodec.modules.seanet import SConvTranspose1d, SEANetResnetBlock 182 | from torch.nn.utils import remove_weight_norm 183 | 184 | encoder = model.encoder.model 185 | for key in encoder._modules: 186 | if isinstance(encoder._modules[key], SEANetResnetBlock): 187 | remove_weight_norm(encoder._modules[key].shortcut.conv.conv) 188 | block_modules = encoder._modules[key].block._modules 189 | for skey in block_modules: 190 | if isinstance(block_modules[skey], SConv1d): 191 | remove_weight_norm(block_modules[skey].conv.conv) 192 | elif isinstance(encoder._modules[key], SConv1d): 193 | remove_weight_norm(encoder._modules[key].conv.conv) 194 | 195 | decoder = model.decoder.model 196 | for key in decoder._modules: 197 | if isinstance(decoder._modules[key], SEANetResnetBlock): 198 | remove_weight_norm(decoder._modules[key].shortcut.conv.conv) 199 | block_modules = decoder._modules[key].block._modules 200 | for skey in block_modules: 201 | if isinstance(block_modules[skey], SConv1d): 202 | remove_weight_norm(block_modules[skey].conv.conv) 203 | elif isinstance(decoder._modules[key], SConvTranspose1d): 204 | remove_weight_norm(decoder._modules[key].convtr.convtr) 205 | elif isinstance(decoder._modules[key], SConv1d): 206 | remove_weight_norm(decoder._modules[key].conv.conv) 207 | 208 | 209 | class AudioTokenizer: 210 | """EnCodec audio.""" 211 | 212 | def __init__( 213 | self, 214 | device: Any = None, 215 | ) -> None: 216 | # Instantiate a pretrained EnCodec model 217 | model = EncodecModel.encodec_model_24khz() 218 | model.set_target_bandwidth(6.0) 219 | remove_encodec_weight_norm(model) 220 | 221 | if not device: 222 | device = 
torch.device("cpu") 223 | if torch.cuda.is_available(): 224 | device = torch.device("cuda:0") 225 | 226 | self._device = device 227 | 228 | self.codec = model.to(device) 229 | self.sample_rate = model.sample_rate 230 | self.channels = model.channels 231 | 232 | @property 233 | def device(self): 234 | return self._device 235 | 236 | def encode(self, wav: torch.Tensor) -> torch.Tensor: 237 | return self.codec.encode(wav.to(self.device)) 238 | 239 | def decode(self, frames: torch.Tensor) -> torch.Tensor: 240 | return self.codec.decode(frames) 241 | 242 | 243 | def tokenize_audio(tokenizer: AudioTokenizer, audio): 244 | # Load and pre-process the audio waveform 245 | if isinstance(audio, str): 246 | wav, sr = torchaudio.load(audio) 247 | else: 248 | wav, sr = audio 249 | wav = convert_audio(wav, sr, tokenizer.sample_rate, tokenizer.channels) 250 | wav = wav.unsqueeze(0) 251 | 252 | # Extract discrete codes from EnCodec 253 | with torch.no_grad(): 254 | encoded_frames = tokenizer.encode(wav) 255 | return encoded_frames 256 | 257 | 258 | # @dataclass 259 | # class AudioTokenConfig: 260 | # frame_shift: Seconds = 320.0 / 24000 261 | # num_quantizers: int = 8 262 | # 263 | # def to_dict(self) -> Dict[str, Any]: 264 | # return asdict(self) 265 | # 266 | # @staticmethod 267 | # def from_dict(data: Dict[str, Any]) -> "AudioTokenConfig": 268 | # return AudioTokenConfig(**data) 269 | # 270 | # 271 | # class AudioTokenExtractor(FeatureExtractor): 272 | # name = "encodec" 273 | # config_type = AudioTokenConfig 274 | # 275 | # def __init__(self, config: Optional[Any] = None): 276 | # super(AudioTokenExtractor, self).__init__(config) 277 | # self.tokenizer = AudioTokenizer() 278 | # 279 | # def extract( 280 | # self, samples: Union[np.ndarray, torch.Tensor], sampling_rate: int 281 | # ) -> np.ndarray: 282 | # if not isinstance(samples, torch.Tensor): 283 | # samples = torch.from_numpy(samples) 284 | # if sampling_rate != self.tokenizer.sample_rate: 285 | # samples = convert_audio( 286 | # samples, 287 | # sampling_rate, 288 | # self.tokenizer.sample_rate, 289 | # self.tokenizer.channels, 290 | # ) 291 | # if len(samples.shape) == 2: 292 | # samples = samples.unsqueeze(0) 293 | # else: 294 | # raise ValueError() 295 | # 296 | # device = self.tokenizer.device 297 | # encoded_frames = self.tokenizer.encode(samples.detach().to(device)) 298 | # codes = encoded_frames[0][0] # [B, n_q, T] 299 | # if True: 300 | # duration = round(samples.shape[-1] / sampling_rate, ndigits=12) 301 | # expected_num_frames = compute_num_frames( 302 | # duration=duration, 303 | # frame_shift=self.frame_shift, 304 | # sampling_rate=sampling_rate, 305 | # ) 306 | # assert abs(codes.shape[-1] - expected_num_frames) <= 1 307 | # codes = codes[..., :expected_num_frames] 308 | # return codes.cpu().squeeze(0).permute(1, 0).numpy() 309 | # 310 | # @property 311 | # def frame_shift(self) -> Seconds: 312 | # return self.config.frame_shift 313 | # 314 | # def feature_dim(self, sampling_rate: int) -> int: 315 | # return self.config.num_quantizers 316 | # 317 | # def pad_tensor_list(self, tensor_list, device, padding_value=0): 318 | # # 计算每个张量的长度 319 | # lengths = [tensor.shape[0] for tensor in tensor_list] 320 | # # 使用pad_sequence函数进行填充 321 | # tensor_list = [torch.Tensor(t).to(device) for t in tensor_list] 322 | # padded_tensor = torch.nn.utils.rnn.pad_sequence( 323 | # tensor_list, batch_first=True, padding_value=padding_value 324 | # ) 325 | # return padded_tensor, lengths 326 | # 327 | # def extract_batch(self, samples, sampling_rate, 
lengths) -> np.ndarray: 328 | # samples = [wav.squeeze() for wav in samples] 329 | # device = self.tokenizer.device 330 | # samples, lengths = self.pad_tensor_list(samples, device) 331 | # samples = samples.unsqueeze(1) 332 | # 333 | # if not isinstance(samples, torch.Tensor): 334 | # samples = torch.from_numpy(samples) 335 | # if len(samples.shape) != 3: 336 | # raise ValueError() 337 | # if sampling_rate != self.tokenizer.sample_rate: 338 | # samples = [ 339 | # convert_audio( 340 | # wav, 341 | # sampling_rate, 342 | # self.tokenizer.sample_rate, 343 | # self.tokenizer.channels, 344 | # ) 345 | # for wav in samples 346 | # ] 347 | # # Extract discrete codes from EnCodec 348 | # with torch.no_grad(): 349 | # encoded_frames = self.tokenizer.encode(samples.detach().to(device)) 350 | # encoded_frames = encoded_frames[0][0] # [B, n_q, T] 351 | # batch_codes = [] 352 | # for b, length in enumerate(lengths): 353 | # codes = encoded_frames[b] 354 | # duration = round(length / sampling_rate, ndigits=12) 355 | # expected_num_frames = compute_num_frames( 356 | # duration=duration, 357 | # frame_shift=self.frame_shift, 358 | # sampling_rate=sampling_rate, 359 | # ) 360 | # batch_codes.append(codes[..., :expected_num_frames]) 361 | # return [codes.cpu().permute(1, 0).numpy() for codes in batch_codes] 362 | 363 | 364 | if __name__ == "__main__": 365 | model = EncodecModel.encodec_model_24khz() 366 | model.set_target_bandwidth(6.0) 367 | 368 | samples = torch.from_numpy(np.random.random([4, 1, 1600])).type( 369 | torch.float32 370 | ) 371 | codes_raw = model.encode(samples) 372 | 373 | remove_encodec_weight_norm(model) 374 | codes_norm = model.encode(samples) 375 | 376 | assert torch.allclose(codes_raw[0][0], codes_norm[0][0]) 377 | -------------------------------------------------------------------------------- /descriptions.py: -------------------------------------------------------------------------------- 1 | top_md = """ 2 | # VALL-E X 3 | VALL-E X can synthesize high-quality personalized speech with only a 3-second enrolled recording of 4 | an unseen speaker as an acoustic prompt, even in another language for a monolingual speaker.
5 | This implementation supports zero-shot, monolingual and cross-lingual text-to-speech in three languages (English, Chinese, and Japanese).
6 | See this [demo](https://plachtaa.github.io/) page for more details. 7 | """ 8 | 9 | infer_from_audio_md = """ 10 | Upload a speech clip of 3~10 seconds as the audio prompt, and type in the text you'd like to synthesize.
11 | The model will synthesize speech for the given text in the same voice as your audio prompt.
12 | The model also tends to preserve the emotion and acoustic environment of the provided speech.
13 | For faster inference, please use **"Make prompt"** to get a `.npz` file as the encoded audio prompt, then use it with **"Infer from prompt"**. 14 | """ 15 | 16 | make_prompt_md = """ 17 | Upload a speech clip of 3~10 seconds as the audio prompt.
18 | Get a `.npz` file as the encoded audio prompt. Use it with **"Infer from prompt"**. 19 | """ 20 | 21 | infer_from_prompt_md = """ 22 | Faster than **"Infer from audio"**.
23 | You need to run **"Make prompt"** first, then upload the encoded prompt (a `.npz` file). 24 | """ 25 | 26 | long_text_md = """ 27 | Very long text is split into individual sentences, and each sentence is synthesized separately.
28 | Please make a prompt or use a preset prompt to infer long text. 29 | """ 30 | 31 | long_text_example = "Just a few years ago, there were no legions of deep learning scientists developing intelligent products and services at major companies and startups. When we entered the field, machine learning did not command headlines in daily newspapers. Our parents had no idea what machine learning was, let alone why we might prefer it to a career in medicine or law. Machine learning was a blue skies academic discipline whose industrial significance was limited to a narrow set of real-world applications, including speech recognition and computer vision. Moreover, many of these applications required so much domain knowledge that they were often regarded as entirely separate areas for which machine learning was one small component. At that time, neural networks—the predecessors of the deep learning methods that we focus on in this book—were generally regarded as outmoded." -------------------------------------------------------------------------------- /examples.py: -------------------------------------------------------------------------------- 1 | infer_from_audio_examples = [ 2 | ["This is how this machine has taken my voice.", 'English', 'no-accent', "prompts/en-2.wav", None, "Wow, look at that! That's no ordinary Teddy bear!"], 3 | ["我喜欢抽电子烟,尤其是锐刻五代。", '中文', 'no-accent', "prompts/zh-1.wav", None, "今天我很荣幸,"], 4 | ["私の声を真似するのはそんなに面白いですか?", '日本語', 'no-accent', "prompts/ja-2.ogg", None, "初めまして、朝武よしのです。"], 5 | ["你可以听得出来我有多困。", '中文', 'no-accent', "prompts/en-1.wav", None, ""], 6 | ["この文は、クロスリンガル合成の例です。", '日本語', 'no-accent', "prompts/zh-2.wav", None, ""], 7 | ["Actually, I can't speak English, but this machine helped me do it.", 'English', 'no-accent', "prompts/ja-1.wav", None, ""], 8 | ] 9 | 10 | make_npz_prompt_examples = [ 11 | ["Gem-trader", "prompts/en-2.wav", None, "Wow, look at that! 
That's no ordinary Teddy bear!"], 12 | ["Ding Zhen", "prompts/zh-1.wav", None, "今天我很荣幸,"], 13 | ["Yoshino", "prompts/ja-2.ogg", None, "初めまして、朝武よしのです。"], 14 | ["Sleepy-woman", "prompts/en-1.wav", None, ""], 15 | ["Yae", "prompts/zh-2.wav", None, ""], 16 | ["Cafe", "prompts/ja-1.wav", None, ""], 17 | ] 18 | 19 | infer_from_prompt_examples = [ 20 | ["A prompt contains voice, prosody and emotion information of a certain speaker.", "English", "no-accent", "vctk_1", None], 21 | ["This prompt is made with an audio of three seconds.", "English", "no-accent", "librispeech_1", None], 22 | ["This prompt is made with Chinese speech", "English", "no-accent", "seel", None], 23 | ] 24 | 25 | -------------------------------------------------------------------------------- /images/vallex_framework.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/images/vallex_framework.jpg -------------------------------------------------------------------------------- /macros.py: -------------------------------------------------------------------------------- 1 | NUM_LAYERS = 12 2 | NUM_HEAD = 16 3 | N_DIM = 1024 4 | PREFIX_MODE = 1 5 | NUM_QUANTIZERS = 8 6 | SAMPLE_RATE = 24000 7 | 8 | lang2token = { 9 | 'zh': "[ZH]", 10 | 'ja': "[JA]", 11 | "en": "[EN]", 12 | 'mix': "", 13 | } 14 | 15 | lang2code = { 16 | 'zh': 0, 17 | 'ja': 1, 18 | "en": 2, 19 | } 20 | 21 | token2lang = { 22 | '[ZH]': "zh", 23 | '[JA]': "ja", 24 | "[EN]": "en", 25 | "": "mix" 26 | } 27 | 28 | code2lang = { 29 | 0: 'zh', 30 | 1: 'ja', 31 | 2: "en", 32 | } 33 | 34 | langdropdown2token = { 35 | 'English': "[EN]", 36 | '中文': "[ZH]", 37 | '日本語': "[JA]", 38 | 'Mix': "", 39 | } -------------------------------------------------------------------------------- /model-card.md: -------------------------------------------------------------------------------- 1 | # Model Card: VALL-E X 2 | 3 | **Author**: [Songting](https://github.com/Plachtaa).
4 |
5 | This is the official codebase for running open-sourced VALL-E X. 6 | 7 | The following is additional information about the models released here. 8 | 9 | ## Model Details 10 | 11 | VALL-E X is a series of two transformer models that turn text into audio. 12 | 13 | ### Phoneme to acoustic tokens 14 | - Input: IPAs converted from input text by a rule-based G2P tool. 15 | - Output: tokens from the first codebook of the [EnCodec Codec](https://github.com/facebookresearch/encodec) from facebook 16 | 17 | ### Coarse to fine tokens 18 | - Input: IPAs converted from input text by a rule-based G2P tool & the first codebook from EnCodec 19 | - Output: 8 codebooks from EnCodec 20 | 21 | ### Architecture 22 | | Model | Parameters | Attention | Output Vocab size | 23 | |:------------------------:|:----------:|------------|:-----------------:| 24 | | G2P tool | - | - | 69 | 25 | | Phoneme to coarse tokens | 150 M | Causal | 1x 1,024 | 26 | | Coarse to fine tokens | 150 M | Non-causal | 7x 1,024 | 27 | 28 | ### Release date 29 | August 2023 30 | 31 | ## Broader Implications 32 | We anticipate that this model's text to audio capabilities can be used to improve accessbility tools in a variety of languages. 33 | Straightforward improvements will allow models to run faster than realtime, rendering them useful for applications such as virtual assistants. -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch.nn as nn 4 | # from icefall.utils import AttributeDict, str2bool 5 | 6 | from .macros import ( 7 | NUM_AUDIO_TOKENS, 8 | NUM_MEL_BINS, 9 | NUM_SPEAKER_CLASSES, 10 | NUM_TEXT_TOKENS, 11 | SPEAKER_EMBEDDING_DIM, 12 | ) 13 | from .transformer import Transformer 14 | from .vallex import VALLE, VALLF 15 | from .visualizer import visualize 16 | 17 | 18 | def add_model_arguments(parser: argparse.ArgumentParser): 19 | parser.add_argument( 20 | "--model-name", 21 | type=str, 22 | default="VALL-E", 23 | help="VALL-E, VALL-F, Transformer.", 24 | ) 25 | parser.add_argument( 26 | "--decoder-dim", 27 | type=int, 28 | default=1024, 29 | help="Embedding dimension in the decoder model.", 30 | ) 31 | parser.add_argument( 32 | "--nhead", 33 | type=int, 34 | default=16, 35 | help="Number of attention heads in the Decoder layers.", 36 | ) 37 | parser.add_argument( 38 | "--num-decoder-layers", 39 | type=int, 40 | default=12, 41 | help="Number of Decoder layers.", 42 | ) 43 | parser.add_argument( 44 | "--scale-factor", 45 | type=float, 46 | default=1.0, 47 | help="Model scale factor which will be assigned different meanings in different models.", 48 | ) 49 | parser.add_argument( 50 | "--norm-first", 51 | type=bool, 52 | default=True, 53 | help="Pre or Post Normalization.", 54 | ) 55 | parser.add_argument( 56 | "--add-prenet", 57 | type=bool, 58 | default=False, 59 | help="Whether add PreNet after Inputs.", 60 | ) 61 | 62 | # VALL-E & F 63 | parser.add_argument( 64 | "--prefix-mode", 65 | type=int, 66 | default=1, 67 | help="The mode for how to prefix VALL-E NAR Decoder, " 68 | "0: no prefix, 1: 0 to random, 2: random to random, 4: chunk of pre or post utterance.", 69 | ) 70 | parser.add_argument( 71 | "--share-embedding", 72 | type=bool, 73 | default=True, 74 | help="Share the parameters of the output projection layer with the parameters of the acoustic embedding.", 75 | ) 76 | parser.add_argument( 77 | "--prepend-bos", 78 | type=bool, 79 | default=False, 80 
| help="Whether prepend to the acoustic tokens -> AR Decoder inputs.", 81 | ) 82 | parser.add_argument( 83 | "--num-quantizers", 84 | type=int, 85 | default=8, 86 | help="Number of Audio/Semantic quantization layers.", 87 | ) 88 | 89 | # Transformer 90 | parser.add_argument( 91 | "--scaling-xformers", 92 | type=bool, 93 | default=False, 94 | help="Apply Reworked Conformer scaling on Transformers.", 95 | ) 96 | 97 | 98 | def get_model(params) -> nn.Module: 99 | if params.model_name.lower() in ["vall-f", "vallf"]: 100 | model = VALLF( 101 | params.decoder_dim, 102 | params.nhead, 103 | params.num_decoder_layers, 104 | norm_first=params.norm_first, 105 | add_prenet=params.add_prenet, 106 | prefix_mode=params.prefix_mode, 107 | share_embedding=params.share_embedding, 108 | nar_scale_factor=params.scale_factor, 109 | prepend_bos=params.prepend_bos, 110 | num_quantizers=params.num_quantizers, 111 | ) 112 | elif params.model_name.lower() in ["vall-e", "valle"]: 113 | model = VALLE( 114 | params.decoder_dim, 115 | params.nhead, 116 | params.num_decoder_layers, 117 | norm_first=params.norm_first, 118 | add_prenet=params.add_prenet, 119 | prefix_mode=params.prefix_mode, 120 | share_embedding=params.share_embedding, 121 | nar_scale_factor=params.scale_factor, 122 | prepend_bos=params.prepend_bos, 123 | num_quantizers=params.num_quantizers, 124 | ) 125 | else: 126 | assert params.model_name in ["Transformer"] 127 | model = Transformer( 128 | params.decoder_dim, 129 | params.nhead, 130 | params.num_decoder_layers, 131 | norm_first=params.norm_first, 132 | add_prenet=params.add_prenet, 133 | scaling_xformers=params.scaling_xformers, 134 | ) 135 | 136 | return model 137 | -------------------------------------------------------------------------------- /models/macros.py: -------------------------------------------------------------------------------- 1 | # Text 2 | NUM_TEXT_TOKENS = 2048 3 | 4 | # Audio 5 | NUM_AUDIO_TOKENS = 1024 # EnCodec RVQ bins 6 | NUM_MEL_BINS = 100 # BigVGAN bigvgan_24khz_100band 7 | 8 | 9 | # Speaker 10 | NUM_SPEAKER_CLASSES = 4096 11 | SPEAKER_EMBEDDING_DIM = 64 12 | -------------------------------------------------------------------------------- /models/transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 (authors: Feiteng Li) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from functools import partial 16 | from typing import Any, Dict, List, Tuple, Union 17 | 18 | import torch 19 | import torch.nn as nn 20 | import torch.nn.functional as F 21 | # from icefall.utils import make_pad_mask 22 | # from torchmetrics.classification import BinaryAccuracy 23 | 24 | from models.vallex import Transpose 25 | from modules.embedding import SinePositionalEmbedding, TokenEmbedding 26 | from modules.scaling import BalancedDoubleSwish, ScaledLinear 27 | from modules.transformer import ( 28 | BalancedBasicNorm, 29 | IdentityNorm, 30 | TransformerDecoderLayer, 31 | TransformerEncoder, 32 | TransformerEncoderLayer, 33 | ) 34 | 35 | from .macros import NUM_MEL_BINS, NUM_TEXT_TOKENS 36 | from .visualizer import visualize 37 | 38 | IdentityNorm = IdentityNorm 39 | 40 | 41 | class Transformer(nn.Module): 42 | """It implements seq2seq Transformer TTS for debug(No StopPredictor and SpeakerEmbeding) 43 | Neural Speech Synthesis with Transformer Network 44 | https://arxiv.org/abs/1809.08895 45 | """ 46 | 47 | def __init__( 48 | self, 49 | d_model: int, 50 | nhead: int, 51 | num_layers: int, 52 | norm_first: bool = True, 53 | add_prenet: bool = False, 54 | scaling_xformers: bool = False, 55 | ): 56 | """ 57 | Args: 58 | d_model: 59 | The number of expected features in the input (required). 60 | nhead: 61 | The number of heads in the multiheadattention models (required). 62 | num_layers: 63 | The number of sub-decoder-layers in the decoder (required). 64 | """ 65 | super().__init__() 66 | self.text_embedding = TokenEmbedding(d_model, NUM_TEXT_TOKENS) # W_x 67 | 68 | if add_prenet: 69 | self.encoder_prenet = nn.Sequential( 70 | Transpose(), 71 | nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"), 72 | nn.BatchNorm1d(d_model), 73 | nn.ReLU(), 74 | nn.Dropout(0.5), 75 | nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"), 76 | nn.BatchNorm1d(d_model), 77 | nn.ReLU(), 78 | nn.Dropout(0.5), 79 | nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"), 80 | nn.BatchNorm1d(d_model), 81 | nn.ReLU(), 82 | nn.Dropout(0.5), 83 | Transpose(), 84 | nn.Linear(d_model, d_model), 85 | ) 86 | 87 | self.decoder_prenet = nn.Sequential( 88 | nn.Linear(NUM_MEL_BINS, 256), 89 | nn.ReLU(), 90 | nn.Dropout(0.5), 91 | nn.Linear(256, 256), 92 | nn.ReLU(), 93 | nn.Dropout(0.5), 94 | nn.Linear(256, d_model), 95 | ) 96 | 97 | assert scaling_xformers is False # TODO: update this block 98 | else: 99 | self.encoder_prenet = nn.Identity() 100 | if scaling_xformers: 101 | self.decoder_prenet = ScaledLinear(NUM_MEL_BINS, d_model) 102 | else: 103 | self.decoder_prenet = nn.Linear(NUM_MEL_BINS, d_model) 104 | 105 | self.encoder_position = SinePositionalEmbedding( 106 | d_model, 107 | dropout=0.1, 108 | scale=False, 109 | ) 110 | self.decoder_position = SinePositionalEmbedding( 111 | d_model, dropout=0.1, scale=False 112 | ) 113 | 114 | if scaling_xformers: 115 | self.encoder = TransformerEncoder( 116 | TransformerEncoderLayer( 117 | d_model, 118 | nhead, 119 | dim_feedforward=d_model * 4, 120 | dropout=0.1, 121 | batch_first=True, 122 | norm_first=norm_first, 123 | linear1_self_attention_cls=ScaledLinear, 124 | linear2_self_attention_cls=partial( 125 | ScaledLinear, initial_scale=0.01 126 | ), 127 | linear1_feedforward_cls=ScaledLinear, 128 | linear2_feedforward_cls=partial( 129 | ScaledLinear, initial_scale=0.01 130 | ), 131 | activation=partial( 132 | BalancedDoubleSwish, 133 | channel_dim=-1, 134 | max_abs=10.0, 135 | min_prob=0.25, 136 | ), 137 | layer_norm_cls=IdentityNorm, 138 | ), 139 | 
num_layers=num_layers, 140 | norm=BalancedBasicNorm(d_model) if norm_first else None, 141 | ) 142 | 143 | self.decoder = nn.TransformerDecoder( 144 | TransformerDecoderLayer( 145 | d_model, 146 | nhead, 147 | dim_feedforward=d_model * 4, 148 | dropout=0.1, 149 | batch_first=True, 150 | norm_first=norm_first, 151 | linear1_self_attention_cls=ScaledLinear, 152 | linear2_self_attention_cls=partial( 153 | ScaledLinear, initial_scale=0.01 154 | ), 155 | linear1_feedforward_cls=ScaledLinear, 156 | linear2_feedforward_cls=partial( 157 | ScaledLinear, initial_scale=0.01 158 | ), 159 | activation=partial( 160 | BalancedDoubleSwish, 161 | channel_dim=-1, 162 | max_abs=10.0, 163 | min_prob=0.25, 164 | ), 165 | layer_norm_cls=IdentityNorm, 166 | ), 167 | num_layers=num_layers, 168 | norm=BalancedBasicNorm(d_model) if norm_first else None, 169 | ) 170 | 171 | self.predict_layer = ScaledLinear(d_model, NUM_MEL_BINS) 172 | self.stop_layer = nn.Linear(d_model, 1) 173 | else: 174 | self.encoder = nn.TransformerEncoder( 175 | nn.TransformerEncoderLayer( 176 | d_model, 177 | nhead, 178 | dim_feedforward=d_model * 4, 179 | activation=F.relu, 180 | dropout=0.1, 181 | batch_first=True, 182 | norm_first=norm_first, 183 | ), 184 | num_layers=num_layers, 185 | norm=nn.LayerNorm(d_model) if norm_first else None, 186 | ) 187 | 188 | self.decoder = nn.TransformerDecoder( 189 | nn.TransformerDecoderLayer( 190 | d_model, 191 | nhead, 192 | dim_feedforward=d_model * 4, 193 | activation=F.relu, 194 | dropout=0.1, 195 | batch_first=True, 196 | norm_first=norm_first, 197 | ), 198 | num_layers=num_layers, 199 | norm=nn.LayerNorm(d_model) if norm_first else None, 200 | ) 201 | 202 | self.predict_layer = nn.Linear(d_model, NUM_MEL_BINS) 203 | self.stop_layer = nn.Linear(d_model, 1) 204 | 205 | self.stop_accuracy_metric = BinaryAccuracy( 206 | threshold=0.5, multidim_average="global" 207 | ) 208 | 209 | # self.apply(self._init_weights) 210 | 211 | # def _init_weights(self, module): 212 | # if isinstance(module, (nn.Linear)): 213 | # module.weight.data.normal_(mean=0.0, std=0.02) 214 | # if isinstance(module, nn.Linear) and module.bias is not None: 215 | # module.bias.data.zero_() 216 | # elif isinstance(module, nn.LayerNorm): 217 | # module.bias.data.zero_() 218 | # module.weight.data.fill_(1.0) 219 | # elif isinstance(module, nn.Embedding): 220 | # module.weight.data.normal_(mean=0.0, std=0.02) 221 | 222 | def forward( 223 | self, 224 | x: torch.Tensor, 225 | x_lens: torch.Tensor, 226 | y: torch.Tensor, 227 | y_lens: torch.Tensor, 228 | reduction: str = "sum", 229 | train_stage: int = 0, 230 | **kwargs, 231 | ) -> Tuple[torch.Tensor, Union[torch.Tensor, None]]: 232 | """ 233 | Args: 234 | x: 235 | A 2-D tensor of shape (N, S). 236 | x_lens: 237 | A 1-D tensor of shape (N,). It contains the number of tokens in `x` 238 | before padding. 239 | y: 240 | A 3-D tensor of shape (N, T, 8). 241 | y_lens: 242 | A 1-D tensor of shape (N,). It contains the number of tokens in `x` 243 | before padding. 244 | train_stage: 245 | Not used in this model. 246 | Returns: 247 | Return the predicted audio code matrix, cross-entropy loss and Top-10 accuracy. 
248 | """ 249 | del train_stage 250 | 251 | assert x.ndim == 2, x.shape 252 | assert x_lens.ndim == 1, x_lens.shape 253 | assert y.ndim == 3, y.shape 254 | assert y_lens.ndim == 1, y_lens.shape 255 | 256 | assert torch.all(x_lens > 0) 257 | 258 | # NOTE: x has been padded in TextTokenCollater 259 | x_mask = make_pad_mask(x_lens).to(x.device) 260 | 261 | x = self.text_embedding(x) 262 | x = self.encoder_prenet(x) 263 | x = self.encoder_position(x) 264 | x = self.encoder(x, src_key_padding_mask=x_mask) 265 | 266 | total_loss, metrics = 0.0, {} 267 | 268 | y_mask = make_pad_mask(y_lens).to(y.device) 269 | y_mask_float = y_mask.type(torch.float32) 270 | data_mask = 1.0 - y_mask_float.unsqueeze(-1) 271 | 272 | # Training 273 | # AR Decoder 274 | def pad_y(y): 275 | y = F.pad(y, (0, 0, 1, 0, 0, 0), value=0).detach() 276 | # inputs, targets 277 | return y[:, :-1], y[:, 1:] 278 | 279 | y, targets = pad_y(y * data_mask) # mask padding as zeros 280 | 281 | y_emb = self.decoder_prenet(y) 282 | y_pos = self.decoder_position(y_emb) 283 | 284 | y_len = y_lens.max() 285 | tgt_mask = torch.triu( 286 | torch.ones(y_len, y_len, device=y.device, dtype=torch.bool), 287 | diagonal=1, 288 | ) 289 | y_dec = self.decoder( 290 | y_pos, 291 | x, 292 | tgt_mask=tgt_mask, 293 | memory_key_padding_mask=x_mask, 294 | ) 295 | 296 | predict = self.predict_layer(y_dec) 297 | # loss 298 | total_loss = F.mse_loss(predict, targets, reduction=reduction) 299 | 300 | logits = self.stop_layer(y_dec).squeeze(-1) 301 | stop_loss = F.binary_cross_entropy_with_logits( 302 | logits, 303 | y_mask_float.detach(), 304 | weight=1.0 + y_mask_float.detach() * 4.0, 305 | reduction=reduction, 306 | ) 307 | metrics["stop_loss"] = stop_loss.detach() 308 | 309 | stop_accuracy = self.stop_accuracy_metric( 310 | (torch.sigmoid(logits) >= 0.5).type(torch.int64), 311 | y_mask.type(torch.int64), 312 | ) 313 | # icefall MetricsTracker.norm_items() 314 | metrics["stop_accuracy"] = stop_accuracy.item() * y_lens.sum().type( 315 | torch.float32 316 | ) 317 | 318 | return ((x, predict), total_loss + 100.0 * stop_loss, metrics) 319 | 320 | def inference( 321 | self, 322 | x: torch.Tensor, 323 | x_lens: torch.Tensor, 324 | y: Any = None, 325 | **kwargs, 326 | ) -> torch.Tensor: 327 | """ 328 | Args: 329 | x: 330 | A 2-D tensor of shape (1, S). 331 | x_lens: 332 | A 1-D tensor of shape (1,). It contains the number of tokens in `x` 333 | before padding. 334 | Returns: 335 | Return the predicted audio code matrix and cross-entropy loss. 
336 | """ 337 | assert x.ndim == 2, x.shape 338 | assert x_lens.ndim == 1, x_lens.shape 339 | 340 | assert torch.all(x_lens > 0) 341 | 342 | x_mask = make_pad_mask(x_lens).to(x.device) 343 | 344 | x = self.text_embedding(x) 345 | x = self.encoder_prenet(x) 346 | x = self.encoder_position(x) 347 | x = self.encoder(x, src_key_padding_mask=x_mask) 348 | 349 | x_mask = make_pad_mask(x_lens).to(x.device) 350 | 351 | # AR Decoder 352 | # TODO: Managing decoder steps avoid repetitive computation 353 | y = torch.zeros( 354 | [x.shape[0], 1, NUM_MEL_BINS], dtype=torch.float32, device=x.device 355 | ) 356 | while True: 357 | y_emb = self.decoder_prenet(y) 358 | y_pos = self.decoder_position(y_emb) 359 | 360 | tgt_mask = torch.triu( 361 | torch.ones( 362 | y.shape[1], y.shape[1], device=y.device, dtype=torch.bool 363 | ), 364 | diagonal=1, 365 | ) 366 | 367 | y_dec = self.decoder( 368 | y_pos, 369 | x, 370 | tgt_mask=tgt_mask, 371 | memory_mask=None, 372 | memory_key_padding_mask=x_mask, 373 | ) 374 | predict = self.predict_layer(y_dec[:, -1:]) 375 | 376 | logits = self.stop_layer(y_dec[:, -1:]) > 0 # sigmoid(0.0) = 0.5 377 | if y.shape[1] > x_lens.max() * 10 or all(logits.cpu().numpy()): 378 | print( 379 | f"TransformerTTS EOS [Text {x_lens[0]} -> Audio {y.shape[1]}]" 380 | ) 381 | break 382 | 383 | y = torch.concat([y, predict], dim=1) 384 | 385 | return y[:, 1:] 386 | 387 | def visualize( 388 | self, 389 | predicts: Tuple[torch.Tensor], 390 | batch: Dict[str, Union[List, torch.Tensor]], 391 | output_dir: str, 392 | limit: int = 4, 393 | ) -> None: 394 | visualize(predicts, batch, output_dir, limit=limit) 395 | -------------------------------------------------------------------------------- /models/visualizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2023 (authors: Feiteng Li) 3 | # 4 | # See ../../../../LICENSE for clarification regarding multiple authors 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 
17 | 18 | 19 | from typing import Dict, List, Tuple, Union 20 | 21 | import matplotlib.pyplot as plt 22 | import numpy as np 23 | import torch 24 | 25 | 26 | def visualize( 27 | predicts: Tuple[torch.Tensor], 28 | batch: Dict[str, Union[List, torch.Tensor]], 29 | output_dir: str, 30 | limit: int = 4, 31 | ) -> None: 32 | text_tokens = batch["text_tokens"].to("cpu").detach().numpy() 33 | text_tokens_lens = batch["text_tokens_lens"].to("cpu").detach().numpy() 34 | audio_features = batch["audio_features"].to("cpu").detach().numpy() 35 | audio_features_lens = ( 36 | batch["audio_features_lens"].to("cpu").detach().numpy() 37 | ) 38 | assert text_tokens.ndim == 2 39 | 40 | utt_ids, texts = batch["utt_id"], batch["text"] 41 | 42 | encoder_outputs = predicts[0].to("cpu").type(torch.float32).detach().numpy() 43 | decoder_outputs = predicts[1] 44 | if isinstance(decoder_outputs, list): 45 | decoder_outputs = decoder_outputs[-1] 46 | decoder_outputs = ( 47 | decoder_outputs.to("cpu").type(torch.float32).detach().numpy() 48 | ) 49 | 50 | vmin, vmax = 0, 1024 # Encodec 51 | if decoder_outputs.dtype == np.float32: 52 | vmin, vmax = -6, 0 # Fbank 53 | 54 | num_figures = 3 55 | for b, (utt_id, text) in enumerate(zip(utt_ids[:limit], texts[:limit])): 56 | _ = plt.figure(figsize=(14, 8 * num_figures)) 57 | 58 | S = text_tokens_lens[b] 59 | T = audio_features_lens[b] 60 | 61 | # encoder 62 | plt.subplot(num_figures, 1, 1) 63 | plt.title(f"Text: {text}") 64 | plt.imshow( 65 | X=np.transpose(encoder_outputs[b]), 66 | cmap=plt.get_cmap("jet"), 67 | aspect="auto", 68 | interpolation="nearest", 69 | ) 70 | plt.gca().invert_yaxis() 71 | plt.axvline(x=S - 0.4, linewidth=2, color="r") 72 | plt.xlabel("Encoder Output") 73 | plt.colorbar() 74 | 75 | # decoder 76 | plt.subplot(num_figures, 1, 2) 77 | plt.imshow( 78 | X=np.transpose(decoder_outputs[b]), 79 | cmap=plt.get_cmap("jet"), 80 | aspect="auto", 81 | interpolation="nearest", 82 | vmin=vmin, 83 | vmax=vmax, 84 | ) 85 | plt.gca().invert_yaxis() 86 | plt.axvline(x=T - 0.4, linewidth=2, color="r") 87 | plt.xlabel("Decoder Output") 88 | plt.colorbar() 89 | 90 | # target 91 | plt.subplot(num_figures, 1, 3) 92 | plt.imshow( 93 | X=np.transpose(audio_features[b]), 94 | cmap=plt.get_cmap("jet"), 95 | aspect="auto", 96 | interpolation="nearest", 97 | vmin=vmin, 98 | vmax=vmax, 99 | ) 100 | plt.gca().invert_yaxis() 101 | plt.axvline(x=T - 0.4, linewidth=2, color="r") 102 | plt.xlabel("Decoder Target") 103 | plt.colorbar() 104 | 105 | plt.savefig(f"{output_dir}/{utt_id}.png") 106 | plt.close() 107 | -------------------------------------------------------------------------------- /modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/modules/__init__.py -------------------------------------------------------------------------------- /modules/embedding.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 (authors: Feiteng Li) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import math 16 | 17 | import torch 18 | import torch.nn as nn 19 | 20 | 21 | class TokenEmbedding(nn.Module): 22 | def __init__( 23 | self, 24 | dim_model: int, 25 | vocab_size: int, 26 | dropout: float = 0.0, 27 | ): 28 | super().__init__() 29 | 30 | self.vocab_size = vocab_size 31 | self.dim_model = dim_model 32 | 33 | self.dropout = torch.nn.Dropout(p=dropout) 34 | self.word_embeddings = nn.Embedding(self.vocab_size, self.dim_model) 35 | 36 | @property 37 | def weight(self) -> torch.Tensor: 38 | return self.word_embeddings.weight 39 | 40 | def embedding(self, index: int) -> torch.Tensor: 41 | return self.word_embeddings.weight[index : index + 1] 42 | 43 | def forward(self, x: torch.Tensor): 44 | X = self.word_embeddings(x) 45 | X = self.dropout(X) 46 | 47 | return X 48 | 49 | 50 | class SinePositionalEmbedding(nn.Module): 51 | def __init__( 52 | self, 53 | dim_model: int, 54 | dropout: float = 0.0, 55 | scale: bool = False, 56 | alpha: bool = False, 57 | ): 58 | super().__init__() 59 | self.dim_model = dim_model 60 | self.x_scale = math.sqrt(dim_model) if scale else 1.0 61 | self.alpha = nn.Parameter(torch.ones(1), requires_grad=alpha) 62 | self.dropout = torch.nn.Dropout(p=dropout) 63 | 64 | self.reverse = False 65 | self.pe = None 66 | self.extend_pe(torch.tensor(0.0).expand(1, 4000)) 67 | 68 | def extend_pe(self, x): 69 | """Reset the positional encodings.""" 70 | if self.pe is not None: 71 | if self.pe.size(1) >= x.size(1): 72 | if self.pe.dtype != x.dtype or self.pe.device != x.device: 73 | self.pe = self.pe.to(dtype=x.dtype, device=x.device) 74 | return 75 | pe = torch.zeros(x.size(1), self.dim_model) 76 | if self.reverse: 77 | position = torch.arange( 78 | x.size(1) - 1, -1, -1.0, dtype=torch.float32 79 | ).unsqueeze(1) 80 | else: 81 | position = torch.arange( 82 | 0, x.size(1), dtype=torch.float32 83 | ).unsqueeze(1) 84 | div_term = torch.exp( 85 | torch.arange(0, self.dim_model, 2, dtype=torch.float32) 86 | * -(math.log(10000.0) / self.dim_model) 87 | ) 88 | pe[:, 0::2] = torch.sin(position * div_term) 89 | pe[:, 1::2] = torch.cos(position * div_term) 90 | pe = pe.unsqueeze(0) 91 | self.pe = pe.to(device=x.device, dtype=x.dtype).detach() 92 | 93 | def forward(self, x: torch.Tensor) -> torch.Tensor: 94 | self.extend_pe(x) 95 | output = x.unsqueeze(-1) if x.ndim == 2 else x 96 | output = output * self.x_scale + self.alpha * self.pe[:, : x.size(1)] 97 | return self.dropout(output) 98 | -------------------------------------------------------------------------------- /modules/scheduler.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2023 (authors: Feiteng Li) 3 | # 4 | # See ../../../../LICENSE for clarification regarding multiple authors 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 
8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | 18 | 19 | import torch 20 | 21 | from modules.optim import Eden 22 | 23 | 24 | def calc_lr(step, dim_embed, warmup_steps): 25 | return dim_embed ** (-0.5) * min( 26 | step ** (-0.5), step * warmup_steps ** (-1.5) 27 | ) 28 | 29 | 30 | class NoamScheduler(torch.optim.lr_scheduler._LRScheduler): 31 | def __init__( 32 | self, 33 | base_lr: float, 34 | optimizer: torch.optim.Optimizer, 35 | dim_embed: int, 36 | warmup_steps: int, 37 | last_epoch: int = -1, 38 | verbose: bool = False, 39 | ) -> None: 40 | 41 | self.dim_embed = dim_embed 42 | self.base_lr = base_lr 43 | self.warmup_steps = warmup_steps 44 | self.num_param_groups = len(optimizer.param_groups) 45 | 46 | super().__init__(optimizer, last_epoch, verbose) 47 | 48 | def get_lr(self) -> float: 49 | lr = self.base_lr * calc_lr( 50 | self._step_count, self.dim_embed, self.warmup_steps 51 | ) 52 | return [lr] * self.num_param_groups 53 | 54 | def set_step(self, step: int): 55 | self._step_count = step 56 | 57 | 58 | def get_scheduler(params, optimizer): 59 | if params.scheduler_name.lower() == "eden": 60 | scheduler = Eden(optimizer, 5000, 4, warmup_batches=params.warmup_steps) 61 | elif params.scheduler_name.lower() == "noam": 62 | scheduler = NoamScheduler( 63 | params.base_lr, 64 | optimizer, 65 | params.decoder_dim, 66 | warmup_steps=params.warmup_steps, 67 | ) 68 | # scheduler.set_step(params.start_batch or params.batch_idx_train) 69 | elif params.scheduler_name.lower() == "cosine": 70 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 71 | params.warmup_steps, 72 | optimizer, 73 | eta_min=params.base_lr, 74 | ) 75 | else: 76 | raise NotImplementedError(f"{params.scheduler_name}") 77 | 78 | return scheduler 79 | -------------------------------------------------------------------------------- /nltk_data/tokenizers/punkt/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/nltk_data/tokenizers/punkt/.DS_Store -------------------------------------------------------------------------------- /nltk_data/tokenizers/punkt/PY3/README: -------------------------------------------------------------------------------- 1 | Pretrained Punkt Models -- Jan Strunk (New version trained after issues 313 and 514 had been corrected) 2 | 3 | Most models were prepared using the test corpora from Kiss and Strunk (2006). Additional models have 4 | been contributed by various people using NLTK for sentence boundary detection. 
5 | 6 | For information about how to use these models, please confer the tokenization HOWTO: 7 | http://nltk.googlecode.com/svn/trunk/doc/howto/tokenize.html 8 | and chapter 3.8 of the NLTK book: 9 | http://nltk.googlecode.com/svn/trunk/doc/book/ch03.html#sec-segmentation 10 | 11 | There are pretrained tokenizers for the following languages: 12 | 13 | File Language Source Contents Size of training corpus(in tokens) Model contributed by 14 | ======================================================================================================================================================================= 15 | czech.pickle Czech Multilingual Corpus 1 (ECI) Lidove Noviny ~345,000 Jan Strunk / Tibor Kiss 16 | Literarni Noviny 17 | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------- 18 | danish.pickle Danish Avisdata CD-Rom Ver. 1.1. 1995 Berlingske Tidende ~550,000 Jan Strunk / Tibor Kiss 19 | (Berlingske Avisdata, Copenhagen) Weekend Avisen 20 | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------- 21 | dutch.pickle Dutch Multilingual Corpus 1 (ECI) De Limburger ~340,000 Jan Strunk / Tibor Kiss 22 | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------- 23 | english.pickle English Penn Treebank (LDC) Wall Street Journal ~469,000 Jan Strunk / Tibor Kiss 24 | (American) 25 | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------- 26 | estonian.pickle Estonian University of Tartu, Estonia Eesti Ekspress ~359,000 Jan Strunk / Tibor Kiss 27 | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------- 28 | finnish.pickle Finnish Finnish Parole Corpus, Finnish Books and major national ~364,000 Jan Strunk / Tibor Kiss 29 | Text Bank (Suomen Kielen newspapers 30 | Tekstipankki) 31 | Finnish Center for IT Science 32 | (CSC) 33 | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------- 34 | french.pickle French Multilingual Corpus 1 (ECI) Le Monde ~370,000 Jan Strunk / Tibor Kiss 35 | (European) 36 | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------- 37 | german.pickle German Neue Zürcher Zeitung AG Neue Zürcher Zeitung ~847,000 Jan Strunk / Tibor Kiss 38 | (Switzerland) CD-ROM 39 | (Uses "ss" 40 | instead of "ß") 41 | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------- 42 | greek.pickle Greek Efstathios Stamatatos To Vima (TO BHMA) ~227,000 Jan Strunk / Tibor Kiss 43 | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------- 44 | italian.pickle Italian Multilingual Corpus 1 (ECI) La Stampa, Il Mattino ~312,000 Jan Strunk / Tibor Kiss 45 | 
----------------------------------------------------------------------------------------------------------------------------------------------------------------------- 46 | norwegian.pickle Norwegian Centre for Humanities Bergens Tidende ~479,000 Jan Strunk / Tibor Kiss 47 | (Bokmål and Information Technologies, 48 | Nynorsk) Bergen 49 | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------- 50 | polish.pickle Polish Polish National Corpus Literature, newspapers, etc. ~1,000,000 Krzysztof Langner 51 | (http://www.nkjp.pl/) 52 | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------- 53 | portuguese.pickle Portuguese CETENFolha Corpus Folha de São Paulo ~321,000 Jan Strunk / Tibor Kiss 54 | (Brazilian) (Linguateca) 55 | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------- 56 | slovene.pickle Slovene TRACTOR Delo ~354,000 Jan Strunk / Tibor Kiss 57 | Slovene Academy for Arts 58 | and Sciences 59 | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------- 60 | spanish.pickle Spanish Multilingual Corpus 1 (ECI) Sur ~353,000 Jan Strunk / Tibor Kiss 61 | (European) 62 | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------- 63 | swedish.pickle Swedish Multilingual Corpus 1 (ECI) Dagens Nyheter ~339,000 Jan Strunk / Tibor Kiss 64 | (and some other texts) 65 | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------- 66 | turkish.pickle Turkish METU Turkish Corpus Milliyet ~333,000 Jan Strunk / Tibor Kiss 67 | (Türkçe Derlem Projesi) 68 | University of Ankara 69 | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------- 70 | 71 | The corpora contained about 400,000 tokens on average and mostly consisted of newspaper text converted to 72 | Unicode using the codecs module. 73 | 74 | Kiss, Tibor and Strunk, Jan (2006): Unsupervised Multilingual Sentence Boundary Detection. 75 | Computational Linguistics 32: 485-525. 
76 | 77 | ---- Training Code ---- 78 | 79 | # import punkt 80 | import nltk.tokenize.punkt 81 | 82 | # Make a new Tokenizer 83 | tokenizer = nltk.tokenize.punkt.PunktSentenceTokenizer() 84 | 85 | # Read in training corpus (one example: Slovene) 86 | import codecs 87 | text = codecs.open("slovene.plain","Ur","iso-8859-2").read() 88 | 89 | # Train tokenizer 90 | tokenizer.train(text) 91 | 92 | # Dump pickled tokenizer 93 | import pickle 94 | out = open("slovene.pickle","wb") 95 | pickle.dump(tokenizer, out) 96 | out.close() 97 | 98 | --------- 99 | -------------------------------------------------------------------------------- /nltk_data/tokenizers/punkt/PY3/czech.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/nltk_data/tokenizers/punkt/PY3/czech.pickle -------------------------------------------------------------------------------- /nltk_data/tokenizers/punkt/PY3/danish.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/nltk_data/tokenizers/punkt/PY3/danish.pickle -------------------------------------------------------------------------------- /nltk_data/tokenizers/punkt/PY3/dutch.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/nltk_data/tokenizers/punkt/PY3/dutch.pickle -------------------------------------------------------------------------------- /nltk_data/tokenizers/punkt/PY3/english.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/nltk_data/tokenizers/punkt/PY3/english.pickle -------------------------------------------------------------------------------- /nltk_data/tokenizers/punkt/PY3/estonian.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/nltk_data/tokenizers/punkt/PY3/estonian.pickle -------------------------------------------------------------------------------- /nltk_data/tokenizers/punkt/PY3/finnish.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/nltk_data/tokenizers/punkt/PY3/finnish.pickle -------------------------------------------------------------------------------- /nltk_data/tokenizers/punkt/PY3/french.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/nltk_data/tokenizers/punkt/PY3/french.pickle -------------------------------------------------------------------------------- /nltk_data/tokenizers/punkt/PY3/german.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/nltk_data/tokenizers/punkt/PY3/german.pickle 
-------------------------------------------------------------------------------- /nltk_data/tokenizers/punkt/PY3/greek.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/nltk_data/tokenizers/punkt/PY3/greek.pickle -------------------------------------------------------------------------------- /nltk_data/tokenizers/punkt/PY3/italian.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/nltk_data/tokenizers/punkt/PY3/italian.pickle -------------------------------------------------------------------------------- /nltk_data/tokenizers/punkt/PY3/malayalam.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/nltk_data/tokenizers/punkt/PY3/malayalam.pickle -------------------------------------------------------------------------------- /nltk_data/tokenizers/punkt/PY3/norwegian.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/nltk_data/tokenizers/punkt/PY3/norwegian.pickle -------------------------------------------------------------------------------- /nltk_data/tokenizers/punkt/PY3/polish.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/nltk_data/tokenizers/punkt/PY3/polish.pickle -------------------------------------------------------------------------------- /nltk_data/tokenizers/punkt/PY3/portuguese.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/nltk_data/tokenizers/punkt/PY3/portuguese.pickle -------------------------------------------------------------------------------- /nltk_data/tokenizers/punkt/PY3/russian.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/nltk_data/tokenizers/punkt/PY3/russian.pickle -------------------------------------------------------------------------------- /nltk_data/tokenizers/punkt/PY3/slovene.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/nltk_data/tokenizers/punkt/PY3/slovene.pickle -------------------------------------------------------------------------------- /nltk_data/tokenizers/punkt/PY3/spanish.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/nltk_data/tokenizers/punkt/PY3/spanish.pickle -------------------------------------------------------------------------------- /nltk_data/tokenizers/punkt/PY3/swedish.pickle: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/nltk_data/tokenizers/punkt/PY3/swedish.pickle -------------------------------------------------------------------------------- /nltk_data/tokenizers/punkt/PY3/turkish.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/nltk_data/tokenizers/punkt/PY3/turkish.pickle -------------------------------------------------------------------------------- /nltk_data/tokenizers/punkt/README: -------------------------------------------------------------------------------- 1 | Pretrained Punkt Models -- Jan Strunk (New version trained after issues 313 and 514 had been corrected) 2 | 3 | Most models were prepared using the test corpora from Kiss and Strunk (2006). Additional models have 4 | been contributed by various people using NLTK for sentence boundary detection. 5 | 6 | For information about how to use these models, please confer the tokenization HOWTO: 7 | http://nltk.googlecode.com/svn/trunk/doc/howto/tokenize.html 8 | and chapter 3.8 of the NLTK book: 9 | http://nltk.googlecode.com/svn/trunk/doc/book/ch03.html#sec-segmentation 10 | 11 | There are pretrained tokenizers for the following languages: 12 | 13 | File Language Source Contents Size of training corpus(in tokens) Model contributed by 14 | ======================================================================================================================================================================= 15 | czech.pickle Czech Multilingual Corpus 1 (ECI) Lidove Noviny ~345,000 Jan Strunk / Tibor Kiss 16 | Literarni Noviny 17 | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------- 18 | danish.pickle Danish Avisdata CD-Rom Ver. 1.1. 
1995 Berlingske Tidende ~550,000 Jan Strunk / Tibor Kiss 19 | (Berlingske Avisdata, Copenhagen) Weekend Avisen 20 | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------- 21 | dutch.pickle Dutch Multilingual Corpus 1 (ECI) De Limburger ~340,000 Jan Strunk / Tibor Kiss 22 | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------- 23 | english.pickle English Penn Treebank (LDC) Wall Street Journal ~469,000 Jan Strunk / Tibor Kiss 24 | (American) 25 | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------- 26 | estonian.pickle Estonian University of Tartu, Estonia Eesti Ekspress ~359,000 Jan Strunk / Tibor Kiss 27 | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------- 28 | finnish.pickle Finnish Finnish Parole Corpus, Finnish Books and major national ~364,000 Jan Strunk / Tibor Kiss 29 | Text Bank (Suomen Kielen newspapers 30 | Tekstipankki) 31 | Finnish Center for IT Science 32 | (CSC) 33 | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------- 34 | french.pickle French Multilingual Corpus 1 (ECI) Le Monde ~370,000 Jan Strunk / Tibor Kiss 35 | (European) 36 | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------- 37 | german.pickle German Neue Zürcher Zeitung AG Neue Zürcher Zeitung ~847,000 Jan Strunk / Tibor Kiss 38 | (Switzerland) CD-ROM 39 | (Uses "ss" 40 | instead of "ß") 41 | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------- 42 | greek.pickle Greek Efstathios Stamatatos To Vima (TO BHMA) ~227,000 Jan Strunk / Tibor Kiss 43 | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------- 44 | italian.pickle Italian Multilingual Corpus 1 (ECI) La Stampa, Il Mattino ~312,000 Jan Strunk / Tibor Kiss 45 | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------- 46 | norwegian.pickle Norwegian Centre for Humanities Bergens Tidende ~479,000 Jan Strunk / Tibor Kiss 47 | (Bokmål and Information Technologies, 48 | Nynorsk) Bergen 49 | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------- 50 | polish.pickle Polish Polish National Corpus Literature, newspapers, etc. 
~1,000,000 Krzysztof Langner 51 | (http://www.nkjp.pl/) 52 | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------- 53 | portuguese.pickle Portuguese CETENFolha Corpus Folha de São Paulo ~321,000 Jan Strunk / Tibor Kiss 54 | (Brazilian) (Linguateca) 55 | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------- 56 | slovene.pickle Slovene TRACTOR Delo ~354,000 Jan Strunk / Tibor Kiss 57 | Slovene Academy for Arts 58 | and Sciences 59 | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------- 60 | spanish.pickle Spanish Multilingual Corpus 1 (ECI) Sur ~353,000 Jan Strunk / Tibor Kiss 61 | (European) 62 | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------- 63 | swedish.pickle Swedish Multilingual Corpus 1 (ECI) Dagens Nyheter ~339,000 Jan Strunk / Tibor Kiss 64 | (and some other texts) 65 | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------- 66 | turkish.pickle Turkish METU Turkish Corpus Milliyet ~333,000 Jan Strunk / Tibor Kiss 67 | (Türkçe Derlem Projesi) 68 | University of Ankara 69 | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------- 70 | 71 | The corpora contained about 400,000 tokens on average and mostly consisted of newspaper text converted to 72 | Unicode using the codecs module. 73 | 74 | Kiss, Tibor and Strunk, Jan (2006): Unsupervised Multilingual Sentence Boundary Detection. 75 | Computational Linguistics 32: 485-525. 
76 | 77 | ---- Training Code ---- 78 | 79 | # import punkt 80 | import nltk.tokenize.punkt 81 | 82 | # Make a new Tokenizer 83 | tokenizer = nltk.tokenize.punkt.PunktSentenceTokenizer() 84 | 85 | # Read in training corpus (one example: Slovene) 86 | import codecs 87 | text = codecs.open("slovene.plain","Ur","iso-8859-2").read() 88 | 89 | # Train tokenizer 90 | tokenizer.train(text) 91 | 92 | # Dump pickled tokenizer 93 | import pickle 94 | out = open("slovene.pickle","wb") 95 | pickle.dump(tokenizer, out) 96 | out.close() 97 | 98 | --------- 99 | -------------------------------------------------------------------------------- /nltk_data/tokenizers/punkt/czech.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/nltk_data/tokenizers/punkt/czech.pickle -------------------------------------------------------------------------------- /nltk_data/tokenizers/punkt/danish.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/nltk_data/tokenizers/punkt/danish.pickle -------------------------------------------------------------------------------- /nltk_data/tokenizers/punkt/dutch.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/nltk_data/tokenizers/punkt/dutch.pickle -------------------------------------------------------------------------------- /nltk_data/tokenizers/punkt/estonian.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/nltk_data/tokenizers/punkt/estonian.pickle -------------------------------------------------------------------------------- /nltk_data/tokenizers/punkt/finnish.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/nltk_data/tokenizers/punkt/finnish.pickle -------------------------------------------------------------------------------- /nltk_data/tokenizers/punkt/french.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/nltk_data/tokenizers/punkt/french.pickle -------------------------------------------------------------------------------- /nltk_data/tokenizers/punkt/german.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/nltk_data/tokenizers/punkt/german.pickle -------------------------------------------------------------------------------- /nltk_data/tokenizers/punkt/italian.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/nltk_data/tokenizers/punkt/italian.pickle -------------------------------------------------------------------------------- 
/nltk_data/tokenizers/punkt/malayalam.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/nltk_data/tokenizers/punkt/malayalam.pickle -------------------------------------------------------------------------------- /nltk_data/tokenizers/punkt/norwegian.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/nltk_data/tokenizers/punkt/norwegian.pickle -------------------------------------------------------------------------------- /nltk_data/tokenizers/punkt/polish.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/nltk_data/tokenizers/punkt/polish.pickle -------------------------------------------------------------------------------- /nltk_data/tokenizers/punkt/portuguese.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/nltk_data/tokenizers/punkt/portuguese.pickle -------------------------------------------------------------------------------- /nltk_data/tokenizers/punkt/russian.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/nltk_data/tokenizers/punkt/russian.pickle -------------------------------------------------------------------------------- /nltk_data/tokenizers/punkt/slovene.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/nltk_data/tokenizers/punkt/slovene.pickle -------------------------------------------------------------------------------- /nltk_data/tokenizers/punkt/spanish.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/nltk_data/tokenizers/punkt/spanish.pickle -------------------------------------------------------------------------------- /nltk_data/tokenizers/punkt/swedish.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/nltk_data/tokenizers/punkt/swedish.pickle -------------------------------------------------------------------------------- /nltk_data/tokenizers/punkt/turkish.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/nltk_data/tokenizers/punkt/turkish.pickle -------------------------------------------------------------------------------- /presets/acou_1.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/presets/acou_1.npz 
-------------------------------------------------------------------------------- /presets/acou_2.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/presets/acou_2.npz -------------------------------------------------------------------------------- /presets/acou_3.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/presets/acou_3.npz -------------------------------------------------------------------------------- /presets/acou_4.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/presets/acou_4.npz -------------------------------------------------------------------------------- /presets/alan.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/presets/alan.npz -------------------------------------------------------------------------------- /presets/amused.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/presets/amused.npz -------------------------------------------------------------------------------- /presets/anger.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/presets/anger.npz -------------------------------------------------------------------------------- /presets/babara.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/presets/babara.npz -------------------------------------------------------------------------------- /presets/bronya.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/presets/bronya.npz -------------------------------------------------------------------------------- /presets/cafe.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/presets/cafe.npz -------------------------------------------------------------------------------- /presets/dingzhen.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/presets/dingzhen.npz -------------------------------------------------------------------------------- /presets/disgust.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/presets/disgust.npz 
-------------------------------------------------------------------------------- /presets/emo_amused.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/presets/emo_amused.npz -------------------------------------------------------------------------------- /presets/emo_anger.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/presets/emo_anger.npz -------------------------------------------------------------------------------- /presets/emo_neutral.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/presets/emo_neutral.npz -------------------------------------------------------------------------------- /presets/emo_sleepy.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/presets/emo_sleepy.npz -------------------------------------------------------------------------------- /presets/emotion_sleepiness.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/presets/emotion_sleepiness.npz -------------------------------------------------------------------------------- /presets/en2zh_tts_1.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/presets/en2zh_tts_1.npz -------------------------------------------------------------------------------- /presets/en2zh_tts_2.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/presets/en2zh_tts_2.npz -------------------------------------------------------------------------------- /presets/en2zh_tts_3.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/presets/en2zh_tts_3.npz -------------------------------------------------------------------------------- /presets/en2zh_tts_4.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/presets/en2zh_tts_4.npz -------------------------------------------------------------------------------- /presets/esta.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/presets/esta.npz -------------------------------------------------------------------------------- /presets/fuxuan.npz: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/presets/fuxuan.npz -------------------------------------------------------------------------------- /presets/librispeech_1.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/presets/librispeech_1.npz -------------------------------------------------------------------------------- /presets/librispeech_2.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/presets/librispeech_2.npz -------------------------------------------------------------------------------- /presets/librispeech_3.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/presets/librispeech_3.npz -------------------------------------------------------------------------------- /presets/librispeech_4.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/presets/librispeech_4.npz -------------------------------------------------------------------------------- /presets/neutral.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/presets/neutral.npz -------------------------------------------------------------------------------- /presets/paimon.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/presets/paimon.npz -------------------------------------------------------------------------------- /presets/rosalia.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/presets/rosalia.npz -------------------------------------------------------------------------------- /presets/seel.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/presets/seel.npz -------------------------------------------------------------------------------- /presets/sleepiness.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/presets/sleepiness.npz -------------------------------------------------------------------------------- /presets/vctk_1.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/presets/vctk_1.npz -------------------------------------------------------------------------------- /presets/vctk_2.npz: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/presets/vctk_2.npz -------------------------------------------------------------------------------- /presets/vctk_3.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/presets/vctk_3.npz -------------------------------------------------------------------------------- /presets/vctk_4.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/presets/vctk_4.npz -------------------------------------------------------------------------------- /presets/yaesakura.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/presets/yaesakura.npz -------------------------------------------------------------------------------- /presets/zh2en_tts_1.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/presets/zh2en_tts_1.npz -------------------------------------------------------------------------------- /presets/zh2en_tts_2.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/presets/zh2en_tts_2.npz -------------------------------------------------------------------------------- /presets/zh2en_tts_3.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/presets/zh2en_tts_3.npz -------------------------------------------------------------------------------- /presets/zh2en_tts_4.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/presets/zh2en_tts_4.npz -------------------------------------------------------------------------------- /prompts/en-1.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/prompts/en-1.wav -------------------------------------------------------------------------------- /prompts/en-2.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/prompts/en-2.wav -------------------------------------------------------------------------------- /prompts/ja-1.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/prompts/ja-1.wav -------------------------------------------------------------------------------- /prompts/ja-2.ogg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/prompts/ja-2.ogg -------------------------------------------------------------------------------- /prompts/ph.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/prompts/ph.txt -------------------------------------------------------------------------------- /prompts/zh-1.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/prompts/zh-1.wav -------------------------------------------------------------------------------- /prompts/zh-2.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/prompts/zh-2.wav -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | soundfile 2 | numpy 3 | torch 4 | torchvision 5 | torchaudio 6 | tokenizers 7 | encodec 8 | langid 9 | wget 10 | unidecode 11 | pyopenjtalk-prebuilt 12 | pypinyin 13 | inflect 14 | cn2an 15 | jieba 16 | eng_to_ipa 17 | openai-whisper 18 | phonemizer==3.2.0 19 | matplotlib 20 | gradio 21 | nltk 22 | sudachipy 23 | sudachidict_core 24 | vocos -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import logging 4 | from data.dataset import create_dataloader 5 | from macros import * 6 | from data.tokenizer import ( 7 | AudioTokenizer, 8 | tokenize_audio, 9 | ) 10 | from data.collation import get_text_token_collater 11 | from models.vallex import VALLE 12 | if torch.cuda.is_available(): 13 | device = torch.device("cuda", 0) 14 | from vocos import Vocos 15 | 16 | def get_model(device): 17 | url = 'https://huggingface.co/Plachta/VALL-E-X/resolve/main/vallex-checkpoint.pt' 18 | 19 | checkpoints_dir = "./checkpoints" 20 | 21 | model_checkpoint_name = "vallex-checkpoint_modified.pt" 22 | if not os.path.exists(checkpoints_dir): os.mkdir(checkpoints_dir) 23 | if not os.path.exists(os.path.join(checkpoints_dir, model_checkpoint_name)): 24 | import wget 25 | print("3") 26 | try: 27 | logging.info( 28 | "Downloading model from https://huggingface.co/Plachta/VALL-E-X/resolve/main/vallex-checkpoint.pt ...") 29 | # download from https://huggingface.co/Plachta/VALL-E-X/resolve/main/vallex-checkpoint.pt to ./checkpoints/vallex-checkpoint.pt 30 | wget.download("https://huggingface.co/Plachta/VALL-E-X/resolve/main/vallex-checkpoint.pt", 31 | out="./checkpoints/vallex-checkpoint.pt", bar=wget.bar_adaptive) 32 | except Exception as e: 33 | logging.info(e) 34 | raise Exception( 35 | "\n Model weights download failed, please go to 'https://huggingface.co/Plachta/VALL-E-X/resolve/main/vallex-checkpoint.pt'" 36 | "\n manually download model weights and put it to {} .".format(os.getcwd() + "\checkpoints")) 37 | # VALL-E 38 | model = VALLE( 39 | N_DIM, 40 | NUM_HEAD, 41 | NUM_LAYERS, 42 | norm_first=True, 43 | add_prenet=False, 44 | prefix_mode=PREFIX_MODE, 
45 | share_embedding=True, 46 | nar_scale_factor=1.0, 47 | prepend_bos=True, 48 | num_quantizers=NUM_QUANTIZERS, 49 | ).to(device) 50 | checkpoint = torch.load(os.path.join(checkpoints_dir, model_checkpoint_name), map_location='cpu') 51 | missing_keys, unexpected_keys = model.load_state_dict( 52 | checkpoint["model"], strict=True 53 | ) 54 | assert not missing_keys 55 | 56 | # Encodec 57 | codec = AudioTokenizer(device) 58 | 59 | vocos = Vocos.from_pretrained('charactr/vocos-encodec-24khz').to(device) 60 | 61 | return model, codec, vocos -------------------------------------------------------------------------------- /train_utils/__pycache__/utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0417keito/VALL-E-X-Trainer-by-CustomData/8ff6a7987b46f72b8a8d8cabbd71979f39318f80/train_utils/__pycache__/utils.cpython-310.pyc -------------------------------------------------------------------------------- /train_utils/lhotse/utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import torch 4 | import uuid 5 | 6 | def fix_random_seed(random_seed: int): 7 | """ 8 | Set the same random seed for the libraries and modules that Lhotse interacts with. 9 | Includes the ``random`` module, numpy, torch, and ``uuid4()`` function defined in this file. 10 | """ 11 | global _lhotse_uuid 12 | random.seed(random_seed) 13 | np.random.seed(random_seed) 14 | torch.random.manual_seed(random_seed) 15 | # Ensure deterministic ID creation 16 | rd = random.Random() 17 | rd.seed(random_seed) 18 | _lhotse_uuid = lambda: uuid.UUID(int=rd.getrandbits(128)) -------------------------------------------------------------------------------- /train_utils/model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import logging 4 | from macros import * 5 | from data.tokenizer import ( 6 | AudioTokenizer, 7 | tokenize_audio, 8 | ) 9 | from models.vallex import VALLE 10 | from vocos import Vocos 11 | 12 | def get_model(device): 13 | url = 'https://huggingface.co/Plachta/VALL-E-X/resolve/main/vallex-checkpoint.pt' 14 | 15 | checkpoints_dir = "./checkpoints" 16 | 17 | model_checkpoint_name = "vallex-checkpoint_modified.pt" 18 | if not os.path.exists(checkpoints_dir): os.mkdir(checkpoints_dir) 19 | if not os.path.exists(os.path.join(checkpoints_dir, model_checkpoint_name)): 20 | import wget 21 | print("3") 22 | try: 23 | logging.info( 24 | "Downloading model from https://huggingface.co/Plachta/VALL-E-X/resolve/main/vallex-checkpoint.pt ...") 25 | # download from https://huggingface.co/Plachta/VALL-E-X/resolve/main/vallex-checkpoint.pt to ./checkpoints/vallex-checkpoint.pt 26 | wget.download("https://huggingface.co/Plachta/VALL-E-X/resolve/main/vallex-checkpoint.pt", 27 | out="./checkpoints/vallex-checkpoint.pt", bar=wget.bar_adaptive) 28 | except Exception as e: 29 | logging.info(e) 30 | raise Exception( 31 | "\n Model weights download failed, please go to 'https://huggingface.co/Plachta/VALL-E-X/resolve/main/vallex-checkpoint.pt'" 32 | "\n manually download model weights and put it to {} .".format(os.getcwd() + "\checkpoints")) 33 | # VALL-E 34 | model = VALLE( 35 | N_DIM, 36 | NUM_HEAD, 37 | NUM_LAYERS, 38 | norm_first=True, 39 | add_prenet=False, 40 | prefix_mode=PREFIX_MODE, 41 | share_embedding=True, 42 | nar_scale_factor=1.0, 43 | prepend_bos=True, 44 | num_quantizers=NUM_QUANTIZERS, 45 | 
).to(device) 46 | checkpoint = torch.load(os.path.join(checkpoints_dir, model_checkpoint_name), map_location='cpu') 47 | missing_keys, unexpected_keys = model.load_state_dict( 48 | checkpoint["model"], strict=True 49 | ) 50 | assert not missing_keys 51 | 52 | # Encodec 53 | codec = AudioTokenizer(device) 54 | 55 | vocos = Vocos.from_pretrained('charactr/vocos-encodec-24khz').to(device) 56 | 57 | return model, codec, vocos -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | # from icefall.utils import make_pad_mask 4 | 5 | from .symbol_table import SymbolTable 6 | 7 | # make_pad_mask = make_pad_mask 8 | SymbolTable = SymbolTable 9 | 10 | 11 | class Transpose(nn.Identity): 12 | """(N, T, D) -> (N, D, T)""" 13 | 14 | def forward(self, input: torch.Tensor) -> torch.Tensor: 15 | return input.transpose(1, 2) 16 | -------------------------------------------------------------------------------- /utils/download.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import requests 3 | 4 | 5 | def download_file_from_google_drive(id, destination): 6 | URL = "https://docs.google.com/uc?export=download&confirm=1" 7 | 8 | session = requests.Session() 9 | 10 | response = session.get(URL, params={"id": id}, stream=True) 11 | token = get_confirm_token(response) 12 | 13 | if token: 14 | params = {"id": id, "confirm": token} 15 | response = session.get(URL, params=params, stream=True) 16 | 17 | save_response_content(response, destination) 18 | 19 | 20 | def get_confirm_token(response): 21 | for key, value in response.cookies.items(): 22 | if key.startswith("download_warning"): 23 | return value 24 | 25 | return None 26 | 27 | 28 | def save_response_content(response, destination): 29 | CHUNK_SIZE = 32768 30 | 31 | with open(destination, "wb") as f: 32 | for chunk in response.iter_content(CHUNK_SIZE): 33 | if chunk: # filter out keep-alive new chunks 34 | f.write(chunk) 35 | 36 | 37 | def main(): 38 | if len(sys.argv) >= 3: 39 | file_id = sys.argv[1] 40 | destination = sys.argv[2] 41 | else: 42 | file_id = "TAKE_ID_FROM_SHAREABLE_LINK" 43 | destination = "DESTINATION_FILE_ON_YOUR_DISK" 44 | print(f"download {file_id} to {destination}") 45 | download_file_from_google_drive(file_id, destination) 46 | 47 | 48 | if __name__ == "__main__": 49 | main() -------------------------------------------------------------------------------- /utils/g2p/__init__.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | import utils.g2p.cleaners 3 | from utils.g2p.symbols import symbols 4 | from tokenizers import Tokenizer 5 | 6 | # Mappings from symbol to numeric ID and vice versa: 7 | _symbol_to_id = {s: i for i, s in enumerate(symbols)} 8 | _id_to_symbol = {i: s for i, s in enumerate(symbols)} 9 | 10 | 11 | class PhonemeBpeTokenizer: 12 | def __init__(self, tokenizer_path = "./utils/g2p/bpe_1024.json"): 13 | self.tokenizer = Tokenizer.from_file(tokenizer_path) 14 | 15 | def tokenize(self, text): 16 | # 1. convert text to phoneme 17 | phonemes, langs = _clean_text(text, ['cje_cleaners']) 18 | # 2. replace blank space " " with "_" 19 | phonemes = phonemes.replace(" ", "_") 20 | # 3.
tokenize phonemes 21 | phoneme_tokens = self.tokenizer.encode(phonemes).ids 22 | assert(len(phoneme_tokens) == len(langs)) 23 | if not len(phoneme_tokens): 24 | raise ValueError("Empty text is given") 25 | return phoneme_tokens, langs 26 | 27 | def text_to_sequence(text, cleaner_names): 28 | '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text. 29 | Args: 30 | text: string to convert to a sequence 31 | cleaner_names: names of the cleaner functions to run the text through 32 | Returns: 33 | List of integers corresponding to the symbols in the text 34 | ''' 35 | sequence = [] 36 | symbol_to_id = {s: i for i, s in enumerate(symbols)} 37 | clean_text = _clean_text(text, cleaner_names) 38 | for symbol in clean_text: 39 | if symbol not in symbol_to_id.keys(): 40 | continue 41 | symbol_id = symbol_to_id[symbol] 42 | sequence += [symbol_id] 43 | return sequence 44 | 45 | 46 | def cleaned_text_to_sequence(cleaned_text): 47 | '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text. 48 | Args: 49 | text: string to convert to a sequence 50 | Returns: 51 | List of integers corresponding to the symbols in the text 52 | ''' 53 | sequence = [_symbol_to_id[symbol] for symbol in cleaned_text if symbol in _symbol_to_id.keys()] 54 | return sequence 55 | 56 | 57 | def sequence_to_text(sequence): 58 | '''Converts a sequence of IDs back to a string''' 59 | result = '' 60 | for symbol_id in sequence: 61 | s = _id_to_symbol[symbol_id] 62 | result += s 63 | return result 64 | 65 | 66 | def _clean_text(text, cleaner_names): 67 | for name in cleaner_names: 68 | cleaner = getattr(utils.g2p.cleaners, name) 69 | if not cleaner: 70 | raise Exception('Unknown cleaner: %s' % name) 71 | text, langs = cleaner(text) 72 | return text, langs 73 | -------------------------------------------------------------------------------- /utils/g2p/bpe_69.json: -------------------------------------------------------------------------------- 1 | { 2 | "version": "1.0", 3 | "truncation": null, 4 | "padding": null, 5 | "added_tokens": [ 6 | { 7 | "id": 0, 8 | "content": "[UNK]", 9 | "single_word": false, 10 | "lstrip": false, 11 | "rstrip": false, 12 | "normalized": false, 13 | "special": true 14 | }, 15 | { 16 | "id": 1, 17 | "content": "[CLS]", 18 | "single_word": false, 19 | "lstrip": false, 20 | "rstrip": false, 21 | "normalized": false, 22 | "special": true 23 | }, 24 | { 25 | "id": 2, 26 | "content": "[SEP]", 27 | "single_word": false, 28 | "lstrip": false, 29 | "rstrip": false, 30 | "normalized": false, 31 | "special": true 32 | }, 33 | { 34 | "id": 3, 35 | "content": "[PAD]", 36 | "single_word": false, 37 | "lstrip": false, 38 | "rstrip": false, 39 | "normalized": false, 40 | "special": true 41 | }, 42 | { 43 | "id": 4, 44 | "content": "[MASK]", 45 | "single_word": false, 46 | "lstrip": false, 47 | "rstrip": false, 48 | "normalized": false, 49 | "special": true 50 | } 51 | ], 52 | "normalizer": null, 53 | "pre_tokenizer": { 54 | "type": "Whitespace" 55 | }, 56 | "post_processor": null, 57 | "decoder": null, 58 | "model": { 59 | "type": "BPE", 60 | "dropout": null, 61 | "unk_token": "[UNK]", 62 | "continuing_subword_prefix": null, 63 | "end_of_word_suffix": null, 64 | "fuse_unk": false, 65 | "byte_fallback": false, 66 | "vocab": { 67 | "[UNK]": 0, 68 | "[CLS]": 1, 69 | "[SEP]": 2, 70 | "[PAD]": 3, 71 | "[MASK]": 4, 72 | "!": 5, 73 | "#": 6, 74 | "*": 7, 75 | ",": 8, 76 | "-": 9, 77 | ".": 10, 78 | "=": 11, 79 | "?": 12, 80 | "N": 13, 81 | "Q": 14, 82 | "^": 15, 83 
| "_": 16, 84 | "`": 17, 85 | "a": 18, 86 | "b": 19, 87 | "d": 20, 88 | "e": 21, 89 | "f": 22, 90 | "g": 23, 91 | "h": 24, 92 | "i": 25, 93 | "j": 26, 94 | "k": 27, 95 | "l": 28, 96 | "m": 29, 97 | "n": 30, 98 | "o": 31, 99 | "p": 32, 100 | "s": 33, 101 | "t": 34, 102 | "u": 35, 103 | "v": 36, 104 | "w": 37, 105 | "x": 38, 106 | "y": 39, 107 | "z": 40, 108 | "~": 41, 109 | "æ": 42, 110 | "ç": 43, 111 | "ð": 44, 112 | "ŋ": 45, 113 | "ɑ": 46, 114 | "ɔ": 47, 115 | "ə": 48, 116 | "ɛ": 49, 117 | "ɥ": 50, 118 | "ɪ": 51, 119 | "ɫ": 52, 120 | "ɯ": 53, 121 | "ɸ": 54, 122 | "ɹ": 55, 123 | "ɾ": 56, 124 | "ʃ": 57, 125 | "ʊ": 58, 126 | "ʑ": 59, 127 | "ʒ": 60, 128 | "ʰ": 61, 129 | "ˈ": 62, 130 | "ˌ": 63, 131 | "θ": 64, 132 | "…": 65, 133 | "⁼": 66, 134 | "↑": 67, 135 | "→": 68, 136 | "↓": 69 137 | }, 138 | "merges": [ 139 | ] 140 | } 141 | } -------------------------------------------------------------------------------- /utils/g2p/cleaners.py: -------------------------------------------------------------------------------- 1 | import re 2 | from utils.g2p.japanese import japanese_to_romaji_with_accent, japanese_to_ipa, japanese_to_ipa2, japanese_to_ipa3 3 | from utils.g2p.mandarin import number_to_chinese, chinese_to_bopomofo, latin_to_bopomofo, chinese_to_romaji, chinese_to_lazy_ipa, chinese_to_ipa, chinese_to_ipa2 4 | from utils.g2p.english import english_to_lazy_ipa, english_to_ipa2, english_to_lazy_ipa2 5 | patterns = [r'\[EN\](.*?)\[EN\]', r'\[ZH\](.*?)\[ZH\]', r'\[JA\](.*?)\[JA\]'] 6 | def japanese_cleaners(text): 7 | text = japanese_to_romaji_with_accent(text) 8 | text = re.sub(r'([A-Za-z])$', r'\1.', text) 9 | return text 10 | 11 | def japanese_cleaners2(text): 12 | return japanese_cleaners(text).replace('ts', 'ʦ').replace('...', '…') 13 | 14 | def chinese_cleaners(text): 15 | '''Pipeline for Chinese text''' 16 | text = number_to_chinese(text) 17 | text = chinese_to_bopomofo(text) 18 | text = latin_to_bopomofo(text) 19 | text = re.sub(r'([ˉˊˇˋ˙])$', r'\1。', text) 20 | return text 21 | 22 | def cje_cleaners(text): 23 | matches = [] 24 | for pattern in patterns: 25 | matches.extend(re.finditer(pattern, text)) 26 | 27 | matches.sort(key=lambda x: x.start()) # Sort matches by their start positions 28 | 29 | outputs = "" 30 | output_langs = [] 31 | 32 | for match in matches: 33 | text_segment = text[match.start():match.end()] 34 | phon = clean_one(text_segment) 35 | if "[EN]" in text_segment: 36 | lang = 'en' 37 | elif "[ZH]" in text_segment: 38 | lang = 'zh' 39 | elif "[JA]" in text_segment: 40 | lang = 'ja' 41 | else: 42 | raise ValueError("If you see this error, please report this bug to issues.") 43 | outputs += phon 44 | output_langs += [lang] * len(phon) 45 | assert len(outputs) == len(output_langs) 46 | return outputs, output_langs 47 | 48 | 49 | def clean_one(text): 50 | if text.find('[ZH]') != -1: 51 | text = re.sub(r'\[ZH\](.*?)\[ZH\]', 52 | lambda x: chinese_to_ipa(x.group(1))+' ', text) 53 | if text.find('[JA]') != -1: 54 | text = re.sub(r'\[JA\](.*?)\[JA\]', 55 | lambda x: japanese_to_ipa2(x.group(1))+' ', text) 56 | if text.find('[EN]') != -1: 57 | text = re.sub(r'\[EN\](.*?)\[EN\]', 58 | lambda x: english_to_ipa2(x.group(1))+' ', text) 59 | text = re.sub(r'\s+$', '', text) 60 | text = re.sub(r'([^\.,!\?\-…~])$', r'\1.', text) 61 | return text 62 | -------------------------------------------------------------------------------- /utils/g2p/english.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 
| 3 | ''' 4 | Cleaners are transformations that run over the input text at both training and eval time. 5 | 6 | Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners" 7 | hyperparameter. Some cleaners are English-specific. You'll typically want to use: 8 | 1. "english_cleaners" for English text 9 | 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using 10 | the Unidecode library (https://pypi.python.org/pypi/Unidecode) 11 | 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update 12 | the symbols in symbols.py to match your data). 13 | ''' 14 | 15 | 16 | # Regular expression matching whitespace: 17 | 18 | 19 | import re 20 | from unidecode import unidecode 21 | import inflect 22 | _inflect = inflect.engine() 23 | _comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])') 24 | _decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)') 25 | _pounds_re = re.compile(r'£([0-9\,]*[0-9]+)') 26 | _dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)') 27 | _ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)') 28 | _number_re = re.compile(r'[0-9]+') 29 | 30 | # List of (regular expression, replacement) pairs for abbreviations: 31 | _abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [ 32 | ('mrs', 'misess'), 33 | ('mr', 'mister'), 34 | ('dr', 'doctor'), 35 | ('st', 'saint'), 36 | ('co', 'company'), 37 | ('jr', 'junior'), 38 | ('maj', 'major'), 39 | ('gen', 'general'), 40 | ('drs', 'doctors'), 41 | ('rev', 'reverend'), 42 | ('lt', 'lieutenant'), 43 | ('hon', 'honorable'), 44 | ('sgt', 'sergeant'), 45 | ('capt', 'captain'), 46 | ('esq', 'esquire'), 47 | ('ltd', 'limited'), 48 | ('col', 'colonel'), 49 | ('ft', 'fort'), 50 | ]] 51 | 52 | 53 | # List of (ipa, lazy ipa) pairs: 54 | _lazy_ipa = [(re.compile('%s' % x[0]), x[1]) for x in [ 55 | ('r', 'ɹ'), 56 | ('æ', 'e'), 57 | ('ɑ', 'a'), 58 | ('ɔ', 'o'), 59 | ('ð', 'z'), 60 | ('θ', 's'), 61 | ('ɛ', 'e'), 62 | ('ɪ', 'i'), 63 | ('ʊ', 'u'), 64 | ('ʒ', 'ʥ'), 65 | ('ʤ', 'ʥ'), 66 | ('ˈ', '↓'), 67 | ]] 68 | 69 | # List of (ipa, lazy ipa2) pairs: 70 | _lazy_ipa2 = [(re.compile('%s' % x[0]), x[1]) for x in [ 71 | ('r', 'ɹ'), 72 | ('ð', 'z'), 73 | ('θ', 's'), 74 | ('ʒ', 'ʑ'), 75 | ('ʤ', 'dʑ'), 76 | ('ˈ', '↓'), 77 | ]] 78 | 79 | # List of (ipa, ipa2) pairs 80 | _ipa_to_ipa2 = [(re.compile('%s' % x[0]), x[1]) for x in [ 81 | ('r', 'ɹ'), 82 | ('ʤ', 'dʒ'), 83 | ('ʧ', 'tʃ') 84 | ]] 85 | 86 | 87 | def expand_abbreviations(text): 88 | for regex, replacement in _abbreviations: 89 | text = re.sub(regex, replacement, text) 90 | return text 91 | 92 | 93 | def collapse_whitespace(text): 94 | return re.sub(r'\s+', ' ', text) 95 | 96 | 97 | def _remove_commas(m): 98 | return m.group(1).replace(',', '') 99 | 100 | 101 | def _expand_decimal_point(m): 102 | return m.group(1).replace('.', ' point ') 103 | 104 | 105 | def _expand_dollars(m): 106 | match = m.group(1) 107 | parts = match.split('.') 108 | if len(parts) > 2: 109 | return match + ' dollars' # Unexpected format 110 | dollars = int(parts[0]) if parts[0] else 0 111 | cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0 112 | if dollars and cents: 113 | dollar_unit = 'dollar' if dollars == 1 else 'dollars' 114 | cent_unit = 'cent' if cents == 1 else 'cents' 115 | return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit) 116 | elif dollars: 117 | dollar_unit = 'dollar' if dollars == 1 else 'dollars' 118 | return '%s %s' % (dollars, dollar_unit) 119 | elif cents: 120 | cent_unit = 'cent' if 
cents == 1 else 'cents' 121 | return '%s %s' % (cents, cent_unit) 122 | else: 123 | return 'zero dollars' 124 | 125 | 126 | def _expand_ordinal(m): 127 | return _inflect.number_to_words(m.group(0)) 128 | 129 | 130 | def _expand_number(m): 131 | num = int(m.group(0)) 132 | if num > 1000 and num < 3000: 133 | if num == 2000: 134 | return 'two thousand' 135 | elif num > 2000 and num < 2010: 136 | return 'two thousand ' + _inflect.number_to_words(num % 100) 137 | elif num % 100 == 0: 138 | return _inflect.number_to_words(num // 100) + ' hundred' 139 | else: 140 | return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ') 141 | else: 142 | return _inflect.number_to_words(num, andword='') 143 | 144 | 145 | def normalize_numbers(text): 146 | text = re.sub(_comma_number_re, _remove_commas, text) 147 | text = re.sub(_pounds_re, r'\1 pounds', text) 148 | text = re.sub(_dollars_re, _expand_dollars, text) 149 | text = re.sub(_decimal_number_re, _expand_decimal_point, text) 150 | text = re.sub(_ordinal_re, _expand_ordinal, text) 151 | text = re.sub(_number_re, _expand_number, text) 152 | return text 153 | 154 | 155 | def mark_dark_l(text): 156 | return re.sub(r'l([^aeiouæɑɔəɛɪʊ ]*(?: |$))', lambda x: 'ɫ'+x.group(1), text) 157 | 158 | 159 | def english_to_ipa(text): 160 | import eng_to_ipa as ipa 161 | text = unidecode(text).lower() 162 | text = expand_abbreviations(text) 163 | text = normalize_numbers(text) 164 | phonemes = ipa.convert(text) 165 | phonemes = collapse_whitespace(phonemes) 166 | return phonemes 167 | 168 | 169 | def english_to_lazy_ipa(text): 170 | text = english_to_ipa(text) 171 | for regex, replacement in _lazy_ipa: 172 | text = re.sub(regex, replacement, text) 173 | return text 174 | 175 | 176 | def english_to_ipa2(text): 177 | text = english_to_ipa(text) 178 | text = mark_dark_l(text) 179 | for regex, replacement in _ipa_to_ipa2: 180 | text = re.sub(regex, replacement, text) 181 | return text.replace('...', '…') 182 | 183 | 184 | def english_to_lazy_ipa2(text): 185 | text = english_to_ipa(text) 186 | for regex, replacement in _lazy_ipa2: 187 | text = re.sub(regex, replacement, text) 188 | return text 189 | -------------------------------------------------------------------------------- /utils/g2p/japanese.py: -------------------------------------------------------------------------------- 1 | import re 2 | from unidecode import unidecode 3 | 4 | 5 | 6 | # Regular expression matching Japanese without punctuation marks: 7 | _japanese_characters = re.compile( 8 | r'[A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]') 9 | 10 | # Regular expression matching non-Japanese characters or punctuation marks: 11 | _japanese_marks = re.compile( 12 | r'[^A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]') 13 | 14 | # List of (symbol, Japanese) pairs for marks: 15 | _symbols_to_japanese = [(re.compile('%s' % x[0]), x[1]) for x in [ 16 | ('%', 'パーセント') 17 | ]] 18 | 19 | # List of (romaji, ipa) pairs for marks: 20 | _romaji_to_ipa = [(re.compile('%s' % x[0]), x[1]) for x in [ 21 | ('ts', 'ʦ'), 22 | ('u', 'ɯ'), 23 | ('j', 'ʥ'), 24 | ('y', 'j'), 25 | ('ni', 'n^i'), 26 | ('nj', 'n^'), 27 | ('hi', 'çi'), 28 | ('hj', 'ç'), 29 | ('f', 'ɸ'), 30 | ('I', 'i*'), 31 | ('U', 'ɯ*'), 32 | ('r', 'ɾ') 33 | ]] 34 | 35 | # List of (romaji, ipa2) pairs for marks: 36 | _romaji_to_ipa2 = [(re.compile('%s' % x[0]), x[1]) for x in [ 37 | ('u', 'ɯ'), 38 | ('ʧ', 'tʃ'), 39 | ('j', 'dʑ'), 40 | ('y', 'j'), 41 | 
('ni', 'n^i'), 42 | ('nj', 'n^'), 43 | ('hi', 'çi'), 44 | ('hj', 'ç'), 45 | ('f', 'ɸ'), 46 | ('I', 'i*'), 47 | ('U', 'ɯ*'), 48 | ('r', 'ɾ') 49 | ]] 50 | 51 | # List of (consonant, sokuon) pairs: 52 | _real_sokuon = [(re.compile('%s' % x[0]), x[1]) for x in [ 53 | (r'Q([↑↓]*[kg])', r'k#\1'), 54 | (r'Q([↑↓]*[tdjʧ])', r't#\1'), 55 | (r'Q([↑↓]*[sʃ])', r's\1'), 56 | (r'Q([↑↓]*[pb])', r'p#\1') 57 | ]] 58 | 59 | # List of (consonant, hatsuon) pairs: 60 | _real_hatsuon = [(re.compile('%s' % x[0]), x[1]) for x in [ 61 | (r'N([↑↓]*[pbm])', r'm\1'), 62 | (r'N([↑↓]*[ʧʥj])', r'n^\1'), 63 | (r'N([↑↓]*[tdn])', r'n\1'), 64 | (r'N([↑↓]*[kg])', r'ŋ\1') 65 | ]] 66 | 67 | 68 | def symbols_to_japanese(text): 69 | for regex, replacement in _symbols_to_japanese: 70 | text = re.sub(regex, replacement, text) 71 | return text 72 | 73 | 74 | def japanese_to_romaji_with_accent(text): 75 | '''Reference https://r9y9.github.io/ttslearn/latest/notebooks/ch10_Recipe-Tacotron.html''' 76 | import pyopenjtalk 77 | text = symbols_to_japanese(text) 78 | sentences = re.split(_japanese_marks, text) 79 | marks = re.findall(_japanese_marks, text) 80 | text = '' 81 | for i, sentence in enumerate(sentences): 82 | if re.match(_japanese_characters, sentence): 83 | if text != '': 84 | text += ' ' 85 | labels = pyopenjtalk.extract_fullcontext(sentence) 86 | for n, label in enumerate(labels): 87 | phoneme = re.search(r'\-([^\+]*)\+', label).group(1) 88 | if phoneme not in ['sil', 'pau']: 89 | text += phoneme.replace('ch', 'ʧ').replace('sh', 90 | 'ʃ').replace('cl', 'Q') 91 | else: 92 | continue 93 | # n_moras = int(re.search(r'/F:(\d+)_', label).group(1)) 94 | a1 = int(re.search(r"/A:(\-?[0-9]+)\+", label).group(1)) 95 | a2 = int(re.search(r"\+(\d+)\+", label).group(1)) 96 | a3 = int(re.search(r"\+(\d+)/", label).group(1)) 97 | if re.search(r'\-([^\+]*)\+', labels[n + 1]).group(1) in ['sil', 'pau']: 98 | a2_next = -1 99 | else: 100 | a2_next = int( 101 | re.search(r"\+(\d+)\+", labels[n + 1]).group(1)) 102 | # Accent phrase boundary 103 | if a3 == 1 and a2_next == 1: 104 | text += ' ' 105 | # Falling 106 | elif a1 == 0 and a2_next == a2 + 1: 107 | text += '↓' 108 | # Rising 109 | elif a2 == 1 and a2_next == 2: 110 | text += '↑' 111 | if i < len(marks): 112 | text += unidecode(marks[i]).replace(' ', '') 113 | return text 114 | 115 | 116 | def get_real_sokuon(text): 117 | for regex, replacement in _real_sokuon: 118 | text = re.sub(regex, replacement, text) 119 | return text 120 | 121 | 122 | def get_real_hatsuon(text): 123 | for regex, replacement in _real_hatsuon: 124 | text = re.sub(regex, replacement, text) 125 | return text 126 | 127 | 128 | def japanese_to_ipa(text): 129 | text = japanese_to_romaji_with_accent(text).replace('...', '…') 130 | text = re.sub( 131 | r'([aiueo])\1+', lambda x: x.group(0)[0]+'ː'*(len(x.group(0))-1), text) 132 | text = get_real_sokuon(text) 133 | text = get_real_hatsuon(text) 134 | for regex, replacement in _romaji_to_ipa: 135 | text = re.sub(regex, replacement, text) 136 | return text 137 | 138 | 139 | def japanese_to_ipa2(text): 140 | text = japanese_to_romaji_with_accent(text).replace('...', '…') 141 | text = get_real_sokuon(text) 142 | text = get_real_hatsuon(text) 143 | for regex, replacement in _romaji_to_ipa2: 144 | text = re.sub(regex, replacement, text) 145 | return text 146 | 147 | 148 | def japanese_to_ipa3(text): 149 | text = japanese_to_ipa2(text).replace('n^', 'ȵ').replace( 150 | 'ʃ', 'ɕ').replace('*', '\u0325').replace('#', '\u031a') 151 | text = re.sub( 152 | r'([aiɯeo])\1+', lambda x: 
x.group(0)[0]+'ː'*(len(x.group(0))-1), text) 153 | text = re.sub(r'((?:^|\s)(?:ts|tɕ|[kpt]))', r'\1ʰ', text) 154 | return text 155 | -------------------------------------------------------------------------------- /utils/g2p/mandarin.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import re 4 | import jieba 5 | import cn2an 6 | import logging 7 | 8 | 9 | # List of (Latin alphabet, bopomofo) pairs: 10 | _latin_to_bopomofo = [(re.compile('%s' % x[0], re.IGNORECASE), x[1]) for x in [ 11 | ('a', 'ㄟˉ'), 12 | ('b', 'ㄅㄧˋ'), 13 | ('c', 'ㄙㄧˉ'), 14 | ('d', 'ㄉㄧˋ'), 15 | ('e', 'ㄧˋ'), 16 | ('f', 'ㄝˊㄈㄨˋ'), 17 | ('g', 'ㄐㄧˋ'), 18 | ('h', 'ㄝˇㄑㄩˋ'), 19 | ('i', 'ㄞˋ'), 20 | ('j', 'ㄐㄟˋ'), 21 | ('k', 'ㄎㄟˋ'), 22 | ('l', 'ㄝˊㄛˋ'), 23 | ('m', 'ㄝˊㄇㄨˋ'), 24 | ('n', 'ㄣˉ'), 25 | ('o', 'ㄡˉ'), 26 | ('p', 'ㄆㄧˉ'), 27 | ('q', 'ㄎㄧㄡˉ'), 28 | ('r', 'ㄚˋ'), 29 | ('s', 'ㄝˊㄙˋ'), 30 | ('t', 'ㄊㄧˋ'), 31 | ('u', 'ㄧㄡˉ'), 32 | ('v', 'ㄨㄧˉ'), 33 | ('w', 'ㄉㄚˋㄅㄨˋㄌㄧㄡˋ'), 34 | ('x', 'ㄝˉㄎㄨˋㄙˋ'), 35 | ('y', 'ㄨㄞˋ'), 36 | ('z', 'ㄗㄟˋ') 37 | ]] 38 | 39 | # List of (bopomofo, romaji) pairs: 40 | _bopomofo_to_romaji = [(re.compile('%s' % x[0]), x[1]) for x in [ 41 | ('ㄅㄛ', 'p⁼wo'), 42 | ('ㄆㄛ', 'pʰwo'), 43 | ('ㄇㄛ', 'mwo'), 44 | ('ㄈㄛ', 'fwo'), 45 | ('ㄅ', 'p⁼'), 46 | ('ㄆ', 'pʰ'), 47 | ('ㄇ', 'm'), 48 | ('ㄈ', 'f'), 49 | ('ㄉ', 't⁼'), 50 | ('ㄊ', 'tʰ'), 51 | ('ㄋ', 'n'), 52 | ('ㄌ', 'l'), 53 | ('ㄍ', 'k⁼'), 54 | ('ㄎ', 'kʰ'), 55 | ('ㄏ', 'h'), 56 | ('ㄐ', 'ʧ⁼'), 57 | ('ㄑ', 'ʧʰ'), 58 | ('ㄒ', 'ʃ'), 59 | ('ㄓ', 'ʦ`⁼'), 60 | ('ㄔ', 'ʦ`ʰ'), 61 | ('ㄕ', 's`'), 62 | ('ㄖ', 'ɹ`'), 63 | ('ㄗ', 'ʦ⁼'), 64 | ('ㄘ', 'ʦʰ'), 65 | ('ㄙ', 's'), 66 | ('ㄚ', 'a'), 67 | ('ㄛ', 'o'), 68 | ('ㄜ', 'ə'), 69 | ('ㄝ', 'e'), 70 | ('ㄞ', 'ai'), 71 | ('ㄟ', 'ei'), 72 | ('ㄠ', 'au'), 73 | ('ㄡ', 'ou'), 74 | ('ㄧㄢ', 'yeNN'), 75 | ('ㄢ', 'aNN'), 76 | ('ㄧㄣ', 'iNN'), 77 | ('ㄣ', 'əNN'), 78 | ('ㄤ', 'aNg'), 79 | ('ㄧㄥ', 'iNg'), 80 | ('ㄨㄥ', 'uNg'), 81 | ('ㄩㄥ', 'yuNg'), 82 | ('ㄥ', 'əNg'), 83 | ('ㄦ', 'əɻ'), 84 | ('ㄧ', 'i'), 85 | ('ㄨ', 'u'), 86 | ('ㄩ', 'ɥ'), 87 | ('ˉ', '→'), 88 | ('ˊ', '↑'), 89 | ('ˇ', '↓↑'), 90 | ('ˋ', '↓'), 91 | ('˙', ''), 92 | (',', ','), 93 | ('。', '.'), 94 | ('!', '!'), 95 | ('?', '?'), 96 | ('—', '-') 97 | ]] 98 | 99 | # List of (romaji, ipa) pairs: 100 | _romaji_to_ipa = [(re.compile('%s' % x[0], re.IGNORECASE), x[1]) for x in [ 101 | ('ʃy', 'ʃ'), 102 | ('ʧʰy', 'ʧʰ'), 103 | ('ʧ⁼y', 'ʧ⁼'), 104 | ('NN', 'n'), 105 | ('Ng', 'ŋ'), 106 | ('y', 'j'), 107 | ('h', 'x') 108 | ]] 109 | 110 | # List of (bopomofo, ipa) pairs: 111 | _bopomofo_to_ipa = [(re.compile('%s' % x[0]), x[1]) for x in [ 112 | ('ㄅㄛ', 'p⁼wo'), 113 | ('ㄆㄛ', 'pʰwo'), 114 | ('ㄇㄛ', 'mwo'), 115 | ('ㄈㄛ', 'fwo'), 116 | ('ㄅ', 'p⁼'), 117 | ('ㄆ', 'pʰ'), 118 | ('ㄇ', 'm'), 119 | ('ㄈ', 'f'), 120 | ('ㄉ', 't⁼'), 121 | ('ㄊ', 'tʰ'), 122 | ('ㄋ', 'n'), 123 | ('ㄌ', 'l'), 124 | ('ㄍ', 'k⁼'), 125 | ('ㄎ', 'kʰ'), 126 | ('ㄏ', 'x'), 127 | ('ㄐ', 'tʃ⁼'), 128 | ('ㄑ', 'tʃʰ'), 129 | ('ㄒ', 'ʃ'), 130 | ('ㄓ', 'ts`⁼'), 131 | ('ㄔ', 'ts`ʰ'), 132 | ('ㄕ', 's`'), 133 | ('ㄖ', 'ɹ`'), 134 | ('ㄗ', 'ts⁼'), 135 | ('ㄘ', 'tsʰ'), 136 | ('ㄙ', 's'), 137 | ('ㄚ', 'a'), 138 | ('ㄛ', 'o'), 139 | ('ㄜ', 'ə'), 140 | ('ㄝ', 'ɛ'), 141 | ('ㄞ', 'aɪ'), 142 | ('ㄟ', 'eɪ'), 143 | ('ㄠ', 'ɑʊ'), 144 | ('ㄡ', 'oʊ'), 145 | ('ㄧㄢ', 'jɛn'), 146 | ('ㄩㄢ', 'ɥæn'), 147 | ('ㄢ', 'an'), 148 | ('ㄧㄣ', 'in'), 149 | ('ㄩㄣ', 'ɥn'), 150 | ('ㄣ', 'ən'), 151 | ('ㄤ', 'ɑŋ'), 152 | ('ㄧㄥ', 'iŋ'), 153 | ('ㄨㄥ', 'ʊŋ'), 154 | ('ㄩㄥ', 'jʊŋ'), 155 | ('ㄥ', 'əŋ'), 156 | ('ㄦ', 'əɻ'), 157 | ('ㄧ', 'i'), 158 | ('ㄨ', 'u'), 159 | ('ㄩ', 'ɥ'), 160 | ('ˉ', '→'), 161 | ('ˊ', '↑'), 162 | 
('ˇ', '↓↑'), 163 | ('ˋ', '↓'), 164 | ('˙', ''), 165 | (',', ','), 166 | ('。', '.'), 167 | ('!', '!'), 168 | ('?', '?'), 169 | ('—', '-') 170 | ]] 171 | 172 | # List of (bopomofo, ipa2) pairs: 173 | _bopomofo_to_ipa2 = [(re.compile('%s' % x[0]), x[1]) for x in [ 174 | ('ㄅㄛ', 'pwo'), 175 | ('ㄆㄛ', 'pʰwo'), 176 | ('ㄇㄛ', 'mwo'), 177 | ('ㄈㄛ', 'fwo'), 178 | ('ㄅ', 'p'), 179 | ('ㄆ', 'pʰ'), 180 | ('ㄇ', 'm'), 181 | ('ㄈ', 'f'), 182 | ('ㄉ', 't'), 183 | ('ㄊ', 'tʰ'), 184 | ('ㄋ', 'n'), 185 | ('ㄌ', 'l'), 186 | ('ㄍ', 'k'), 187 | ('ㄎ', 'kʰ'), 188 | ('ㄏ', 'h'), 189 | ('ㄐ', 'tɕ'), 190 | ('ㄑ', 'tɕʰ'), 191 | ('ㄒ', 'ɕ'), 192 | ('ㄓ', 'tʂ'), 193 | ('ㄔ', 'tʂʰ'), 194 | ('ㄕ', 'ʂ'), 195 | ('ㄖ', 'ɻ'), 196 | ('ㄗ', 'ts'), 197 | ('ㄘ', 'tsʰ'), 198 | ('ㄙ', 's'), 199 | ('ㄚ', 'a'), 200 | ('ㄛ', 'o'), 201 | ('ㄜ', 'ɤ'), 202 | ('ㄝ', 'ɛ'), 203 | ('ㄞ', 'aɪ'), 204 | ('ㄟ', 'eɪ'), 205 | ('ㄠ', 'ɑʊ'), 206 | ('ㄡ', 'oʊ'), 207 | ('ㄧㄢ', 'jɛn'), 208 | ('ㄩㄢ', 'yæn'), 209 | ('ㄢ', 'an'), 210 | ('ㄧㄣ', 'in'), 211 | ('ㄩㄣ', 'yn'), 212 | ('ㄣ', 'ən'), 213 | ('ㄤ', 'ɑŋ'), 214 | ('ㄧㄥ', 'iŋ'), 215 | ('ㄨㄥ', 'ʊŋ'), 216 | ('ㄩㄥ', 'jʊŋ'), 217 | ('ㄥ', 'ɤŋ'), 218 | ('ㄦ', 'əɻ'), 219 | ('ㄧ', 'i'), 220 | ('ㄨ', 'u'), 221 | ('ㄩ', 'y'), 222 | ('ˉ', '˥'), 223 | ('ˊ', '˧˥'), 224 | ('ˇ', '˨˩˦'), 225 | ('ˋ', '˥˩'), 226 | ('˙', ''), 227 | (',', ','), 228 | ('。', '.'), 229 | ('!', '!'), 230 | ('?', '?'), 231 | ('—', '-') 232 | ]] 233 | 234 | 235 | def number_to_chinese(text): 236 | numbers = re.findall(r'\d+(?:\.?\d+)?', text) 237 | for number in numbers: 238 | text = text.replace(number, cn2an.an2cn(number), 1) 239 | return text 240 | 241 | 242 | def chinese_to_bopomofo(text): 243 | from pypinyin import lazy_pinyin, BOPOMOFO 244 | text = text.replace('、', ',').replace(';', ',').replace(':', ',') 245 | words = jieba.lcut(text, cut_all=False) 246 | text = '' 247 | for word in words: 248 | bopomofos = lazy_pinyin(word, BOPOMOFO) 249 | if not re.search('[\u4e00-\u9fff]', word): 250 | text += word 251 | continue 252 | for i in range(len(bopomofos)): 253 | bopomofos[i] = re.sub(r'([\u3105-\u3129])$', r'\1ˉ', bopomofos[i]) 254 | if text != '': 255 | text += ' ' 256 | text += ''.join(bopomofos) 257 | return text 258 | 259 | 260 | def latin_to_bopomofo(text): 261 | for regex, replacement in _latin_to_bopomofo: 262 | text = re.sub(regex, replacement, text) 263 | return text 264 | 265 | 266 | def bopomofo_to_romaji(text): 267 | for regex, replacement in _bopomofo_to_romaji: 268 | text = re.sub(regex, replacement, text) 269 | return text 270 | 271 | 272 | def bopomofo_to_ipa(text): 273 | for regex, replacement in _bopomofo_to_ipa: 274 | text = re.sub(regex, replacement, text) 275 | return text 276 | 277 | 278 | def bopomofo_to_ipa2(text): 279 | for regex, replacement in _bopomofo_to_ipa2: 280 | text = re.sub(regex, replacement, text) 281 | return text 282 | 283 | 284 | def chinese_to_romaji(text): 285 | text = number_to_chinese(text) 286 | text = chinese_to_bopomofo(text) 287 | text = latin_to_bopomofo(text) 288 | text = bopomofo_to_romaji(text) 289 | text = re.sub('i([aoe])', r'y\1', text) 290 | text = re.sub('u([aoəe])', r'w\1', text) 291 | text = re.sub('([ʦsɹ]`[⁼ʰ]?)([→↓↑ ]+|$)', 292 | r'\1ɹ`\2', text).replace('ɻ', 'ɹ`') 293 | text = re.sub('([ʦs][⁼ʰ]?)([→↓↑ ]+|$)', r'\1ɹ\2', text) 294 | return text 295 | 296 | 297 | def chinese_to_lazy_ipa(text): 298 | text = chinese_to_romaji(text) 299 | for regex, replacement in _romaji_to_ipa: 300 | text = re.sub(regex, replacement, text) 301 | return text 302 | 303 | 304 | def chinese_to_ipa(text): 305 | text = number_to_chinese(text) 306 | 
text = chinese_to_bopomofo(text) 307 | text = latin_to_bopomofo(text) 308 | text = bopomofo_to_ipa(text) 309 | text = re.sub('i([aoe])', r'j\1', text) 310 | text = re.sub('u([aoəe])', r'w\1', text) 311 | text = re.sub('([sɹ]`[⁼ʰ]?)([→↓↑ ]+|$)', 312 | r'\1ɹ`\2', text).replace('ɻ', 'ɹ`') 313 | text = re.sub('([s][⁼ʰ]?)([→↓↑ ]+|$)', r'\1ɹ\2', text) 314 | return text 315 | 316 | 317 | def chinese_to_ipa2(text): 318 | text = number_to_chinese(text) 319 | text = chinese_to_bopomofo(text) 320 | text = latin_to_bopomofo(text) 321 | text = bopomofo_to_ipa2(text) 322 | text = re.sub(r'i([aoe])', r'j\1', text) 323 | text = re.sub(r'u([aoəe])', r'w\1', text) 324 | text = re.sub(r'([ʂɹ]ʰ?)([˩˨˧˦˥ ]+|$)', r'\1ʅ\2', text) 325 | text = re.sub(r'(sʰ?)([˩˨˧˦˥ ]+|$)', r'\1ɿ\2', text) 326 | return text 327 | -------------------------------------------------------------------------------- /utils/g2p/symbols.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Defines the set of symbols used in text input to the model. 3 | ''' 4 | 5 | # japanese_cleaners 6 | # _pad = '_' 7 | # _punctuation = ',.!?-' 8 | # _letters = 'AEINOQUabdefghijkmnoprstuvwyzʃʧ↓↑ ' 9 | 10 | 11 | '''# japanese_cleaners2 12 | _pad = '_' 13 | _punctuation = ',.!?-~…' 14 | _letters = 'AEINOQUabdefghijkmnoprstuvwyzʃʧʦ↓↑ ' 15 | ''' 16 | 17 | 18 | '''# korean_cleaners 19 | _pad = '_' 20 | _punctuation = ',.!?…~' 21 | _letters = 'ㄱㄴㄷㄹㅁㅂㅅㅇㅈㅊㅋㅌㅍㅎㄲㄸㅃㅆㅉㅏㅓㅗㅜㅡㅣㅐㅔ ' 22 | ''' 23 | 24 | '''# chinese_cleaners 25 | _pad = '_' 26 | _punctuation = ',。!?—…' 27 | _letters = 'ㄅㄆㄇㄈㄉㄊㄋㄌㄍㄎㄏㄐㄑㄒㄓㄔㄕㄖㄗㄘㄙㄚㄛㄜㄝㄞㄟㄠㄡㄢㄣㄤㄥㄦㄧㄨㄩˉˊˇˋ˙ ' 28 | ''' 29 | 30 | # # zh_ja_mixture_cleaners 31 | # _pad = '_' 32 | # _punctuation = ',.!?-~…' 33 | # _letters = 'AEINOQUabdefghijklmnoprstuvwyzʃʧʦɯɹəɥ⁼ʰ`→↓↑ ' 34 | 35 | 36 | '''# sanskrit_cleaners 37 | _pad = '_' 38 | _punctuation = '।' 39 | _letters = 'ँंःअआइईउऊऋएऐओऔकखगघङचछजझञटठडढणतथदधनपफबभमयरलळवशषसहऽािीुूृॄेैोौ्ॠॢ ' 40 | ''' 41 | 42 | '''# cjks_cleaners 43 | _pad = '_' 44 | _punctuation = ',.!?-~…' 45 | _letters = 'NQabdefghijklmnopstuvwxyzʃʧʥʦɯɹəɥçɸɾβŋɦː⁼ʰ`^#*=→↓↑ ' 46 | ''' 47 | 48 | '''# thai_cleaners 49 | _pad = '_' 50 | _punctuation = '.!? 
' 51 | _letters = 'กขฃคฆงจฉชซฌญฎฏฐฑฒณดตถทธนบปผฝพฟภมยรฤลวศษสหฬอฮฯะัาำิีึืุูเแโใไๅๆ็่้๊๋์' 52 | ''' 53 | 54 | # # cjke_cleaners2 55 | _pad = '_' 56 | _punctuation = ',.!?-~…' 57 | _letters = 'NQabdefghijklmnopstuvwxyzɑæʃʑçɯɪɔɛɹðəɫɥɸʊɾʒθβŋɦ⁼ʰ`^#*=ˈˌ→↓↑ ' 58 | 59 | 60 | '''# shanghainese_cleaners 61 | _pad = '_' 62 | _punctuation = ',.!?…' 63 | _letters = 'abdfghiklmnopstuvyzøŋȵɑɔɕəɤɦɪɿʑʔʰ̩̃ᴀᴇ15678 ' 64 | ''' 65 | 66 | '''# chinese_dialect_cleaners 67 | _pad = '_' 68 | _punctuation = ',.!?~…─' 69 | _letters = '#Nabdefghijklmnoprstuvwxyzæçøŋœȵɐɑɒɓɔɕɗɘəɚɛɜɣɤɦɪɭɯɵɷɸɻɾɿʂʅʊʋʌʏʑʔʦʮʰʷˀː˥˦˧˨˩̥̩̃̚ᴀᴇ↑↓∅ⱼ ' 70 | ''' 71 | 72 | # Export all symbols: 73 | symbols = [_pad] + list(_punctuation) + list(_letters) 74 | 75 | # Special symbol ids 76 | SPACE_ID = symbols.index(" ") 77 | -------------------------------------------------------------------------------- /utils/generation.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | import os 3 | import torch 4 | from vocos import Vocos 5 | import logging 6 | import langid 7 | langid.set_languages(['en', 'zh', 'ja']) 8 | 9 | import pathlib 10 | import platform 11 | if platform.system().lower() == 'windows': 12 | temp = pathlib.PosixPath 13 | pathlib.PosixPath = pathlib.WindowsPath 14 | else: 15 | temp = pathlib.WindowsPath 16 | pathlib.WindowsPath = pathlib.PosixPath 17 | 18 | import numpy as np 19 | from data.tokenizer import ( 20 | AudioTokenizer, 21 | tokenize_audio, 22 | ) 23 | from data.collation import get_text_token_collater 24 | from models.vallex import VALLE 25 | from utils.g2p import PhonemeBpeTokenizer 26 | from utils.sentence_cutter import split_text_into_sentences 27 | 28 | from macros import * 29 | 30 | device = torch.device("cpu") 31 | if torch.cuda.is_available(): 32 | device = torch.device("cuda", 0) 33 | 34 | url = 'https://huggingface.co/Plachta/VALL-E-X/resolve/main/vallex-checkpoint.pt' 35 | 36 | checkpoints_dir = "./checkpoints/" 37 | 38 | model_checkpoint_name = "vallex-checkpoint.pt" 39 | 40 | model = None 41 | 42 | codec = None 43 | 44 | vocos = None 45 | 46 | text_tokenizer = PhonemeBpeTokenizer(tokenizer_path="./utils/g2p/bpe_69.json") 47 | text_collater = get_text_token_collater() 48 | 49 | def preload_models(): 50 | global model, codec, vocos 51 | if not os.path.exists(checkpoints_dir): os.mkdir(checkpoints_dir) 52 | if not os.path.exists(os.path.join(checkpoints_dir, model_checkpoint_name)): 53 | import wget 54 | try: 55 | logging.info( 56 | "Downloading model from https://huggingface.co/Plachta/VALL-E-X/resolve/main/vallex-checkpoint.pt ...") 57 | # download from https://huggingface.co/Plachta/VALL-E-X/resolve/main/vallex-checkpoint.pt to ./checkpoints/vallex-checkpoint.pt 58 | wget.download("https://huggingface.co/Plachta/VALL-E-X/resolve/main/vallex-checkpoint.pt", 59 | out="./checkpoints/vallex-checkpoint.pt", bar=wget.bar_adaptive) 60 | except Exception as e: 61 | logging.info(e) 62 | raise Exception( 63 | "\n Model weights download failed, please go to 'https://huggingface.co/Plachta/VALL-E-X/resolve/main/vallex-checkpoint.pt'" 64 | "\n manually download model weights and put it to {} .".format(os.getcwd() + "\checkpoints")) 65 | # VALL-E 66 | model = VALLE( 67 | N_DIM, 68 | NUM_HEAD, 69 | NUM_LAYERS, 70 | norm_first=True, 71 | add_prenet=False, 72 | prefix_mode=PREFIX_MODE, 73 | share_embedding=True, 74 | nar_scale_factor=1.0, 75 | prepend_bos=True, 76 | num_quantizers=NUM_QUANTIZERS, 77 | ).to(device) 78 | checkpoint = torch.load(os.path.join(checkpoints_dir, 
model_checkpoint_name), map_location='cpu') 79 | missing_keys, unexpected_keys = model.load_state_dict( 80 | checkpoint["model"], strict=True 81 | ) 82 | assert not missing_keys 83 | model.eval() 84 | 85 | # Encodec 86 | codec = AudioTokenizer(device) 87 | 88 | vocos = Vocos.from_pretrained('charactr/vocos-encodec-24khz').to(device) 89 | 90 | @torch.no_grad() 91 | def generate_audio(text, prompt=None, language='auto', accent='no-accent'): 92 | global model, codec, vocos, text_tokenizer, text_collater 93 | text = text.replace("\n", "").strip(" ") 94 | # detect language 95 | if language == "auto": 96 | language = langid.classify(text)[0] 97 | lang_token = lang2token[language] 98 | lang = token2lang[lang_token] 99 | text = lang_token + text + lang_token 100 | 101 | # load prompt 102 | if prompt is not None: 103 | prompt_path = prompt 104 | if not os.path.exists(prompt_path): 105 | prompt_path = "./presets/" + prompt + ".npz" 106 | if not os.path.exists(prompt_path): 107 | prompt_path = "./customs/" + prompt + ".npz" 108 | if not os.path.exists(prompt_path): 109 | raise ValueError(f"Cannot find prompt {prompt}") 110 | prompt_data = np.load(prompt_path) 111 | audio_prompts = prompt_data['audio_tokens'] 112 | text_prompts = prompt_data['text_tokens'] 113 | lang_pr = prompt_data['lang_code'] 114 | lang_pr = code2lang[int(lang_pr)] 115 | 116 | # numpy to tensor 117 | audio_prompts = torch.tensor(audio_prompts).type(torch.int32).to(device) 118 | text_prompts = torch.tensor(text_prompts).type(torch.int32) 119 | else: 120 | audio_prompts = torch.zeros([1, 0, NUM_QUANTIZERS]).type(torch.int32).to(device) 121 | text_prompts = torch.zeros([1, 0]).type(torch.int32) 122 | lang_pr = lang if lang != 'mix' else 'en' 123 | 124 | enroll_x_lens = text_prompts.shape[-1] 125 | logging.info(f"synthesize text: {text}") 126 | phone_tokens, langs = text_tokenizer.tokenize(text=f"_{text}".strip()) 127 | text_tokens, text_tokens_lens = text_collater( 128 | [ 129 | phone_tokens 130 | ] 131 | ) 132 | text_tokens = torch.cat([text_prompts, text_tokens], dim=-1) 133 | text_tokens_lens += enroll_x_lens 134 | # accent control 135 | lang = lang if accent == "no-accent" else token2lang[langdropdown2token[accent]] 136 | encoded_frames = model.inference( 137 | text_tokens.to(device), 138 | text_tokens_lens.to(device), 139 | audio_prompts, 140 | enroll_x_lens=enroll_x_lens, 141 | top_k=-100, 142 | temperature=1, 143 | prompt_language=lang_pr, 144 | text_language=langs if accent == "no-accent" else lang, 145 | ) 146 | # Decode with Vocos 147 | frames = encoded_frames.permute(2,0,1) 148 | features = vocos.codes_to_features(frames) 149 | samples = vocos.decode(features, bandwidth_id=torch.tensor([2], device=device)) 150 | 151 | return samples.squeeze().cpu().numpy() 152 | 153 | @torch.no_grad() 154 | def generate_audio_from_long_text(text, prompt=None, language='auto', accent='no-accent', mode='sliding-window'): 155 | """ 156 | For long audio generation, two modes are available. 157 | fixed-prompt: This mode will keep using the same prompt the user has provided, and generate audio sentence by sentence. 158 | sliding-window: This mode will use the last sentence as the prompt for the next sentence, but has some concern on speaker maintenance. 
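    To limit that drift, the implementation below alternates at random (probability 0.5 per sentence) between reusing the most recently generated sentence as the prompt and falling back to the original prompt.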
159 | """ 160 | global model, codec, vocos, text_tokenizer, text_collater 161 | if prompt is None or prompt == "": 162 | mode = 'sliding-window' # If no prompt is given, use sliding-window mode 163 | sentences = split_text_into_sentences(text) 164 | # detect language 165 | if language == "auto": 166 | language = langid.classify(text)[0] 167 | 168 | # if initial prompt is given, encode it 169 | if prompt is not None and prompt != "": 170 | prompt_path = prompt 171 | if not os.path.exists(prompt_path): 172 | prompt_path = "./presets/" + prompt + ".npz" 173 | if not os.path.exists(prompt_path): 174 | prompt_path = "./customs/" + prompt + ".npz" 175 | if not os.path.exists(prompt_path): 176 | raise ValueError(f"Cannot find prompt {prompt}") 177 | prompt_data = np.load(prompt_path) 178 | audio_prompts = prompt_data['audio_tokens'] 179 | text_prompts = prompt_data['text_tokens'] 180 | lang_pr = prompt_data['lang_code'] 181 | lang_pr = code2lang[int(lang_pr)] 182 | 183 | # numpy to tensor 184 | audio_prompts = torch.tensor(audio_prompts).type(torch.int32).to(device) 185 | text_prompts = torch.tensor(text_prompts).type(torch.int32) 186 | else: 187 | audio_prompts = torch.zeros([1, 0, NUM_QUANTIZERS]).type(torch.int32).to(device) 188 | text_prompts = torch.zeros([1, 0]).type(torch.int32) 189 | lang_pr = language if language != 'mix' else 'en' 190 | if mode == 'fixed-prompt': 191 | complete_tokens = torch.zeros([1, NUM_QUANTIZERS, 0]).type(torch.LongTensor).to(device) 192 | for text in sentences: 193 | text = text.replace("\n", "").strip(" ") 194 | if text == "": 195 | continue 196 | lang_token = lang2token[language] 197 | lang = token2lang[lang_token] 198 | text = lang_token + text + lang_token 199 | 200 | enroll_x_lens = text_prompts.shape[-1] 201 | logging.info(f"synthesize text: {text}") 202 | phone_tokens, langs = text_tokenizer.tokenize(text=f"_{text}".strip()) 203 | text_tokens, text_tokens_lens = text_collater( 204 | [ 205 | phone_tokens 206 | ] 207 | ) 208 | text_tokens = torch.cat([text_prompts, text_tokens], dim=-1) 209 | text_tokens_lens += enroll_x_lens 210 | # accent control 211 | lang = lang if accent == "no-accent" else token2lang[langdropdown2token[accent]] 212 | encoded_frames = model.inference( 213 | text_tokens.to(device), 214 | text_tokens_lens.to(device), 215 | audio_prompts, 216 | enroll_x_lens=enroll_x_lens, 217 | top_k=-100, 218 | temperature=1, 219 | prompt_language=lang_pr, 220 | text_language=langs if accent == "no-accent" else lang, 221 | ) 222 | complete_tokens = torch.cat([complete_tokens, encoded_frames.transpose(2, 1)], dim=-1) 223 | # Decode with Vocos 224 | frames = complete_tokens.permute(1,0,2) 225 | features = vocos.codes_to_features(frames) 226 | samples = vocos.decode(features, bandwidth_id=torch.tensor([2], device=device)) 227 | return samples.squeeze().cpu().numpy() 228 | elif mode == "sliding-window": 229 | complete_tokens = torch.zeros([1, NUM_QUANTIZERS, 0]).type(torch.LongTensor).to(device) 230 | original_audio_prompts = audio_prompts 231 | original_text_prompts = text_prompts 232 | for text in sentences: 233 | text = text.replace("\n", "").strip(" ") 234 | if text == "": 235 | continue 236 | lang_token = lang2token[language] 237 | lang = token2lang[lang_token] 238 | text = lang_token + text + lang_token 239 | 240 | enroll_x_lens = text_prompts.shape[-1] 241 | logging.info(f"synthesize text: {text}") 242 | phone_tokens, langs = text_tokenizer.tokenize(text=f"_{text}".strip()) 243 | text_tokens, text_tokens_lens = text_collater( 244 | [ 245 | phone_tokens 
246 | ] 247 | ) 248 | text_tokens = torch.cat([text_prompts, text_tokens], dim=-1) 249 | text_tokens_lens += enroll_x_lens 250 | # accent control 251 | lang = lang if accent == "no-accent" else token2lang[langdropdown2token[accent]] 252 | encoded_frames = model.inference( 253 | text_tokens.to(device), 254 | text_tokens_lens.to(device), 255 | audio_prompts, 256 | enroll_x_lens=enroll_x_lens, 257 | top_k=-100, 258 | temperature=1, 259 | prompt_language=lang_pr, 260 | text_language=langs if accent == "no-accent" else lang, 261 | ) 262 | complete_tokens = torch.cat([complete_tokens, encoded_frames.transpose(2, 1)], dim=-1) 263 | if torch.rand(1) < 0.5: 264 | audio_prompts = encoded_frames[:, :, -NUM_QUANTIZERS:] 265 | text_prompts = text_tokens[:, enroll_x_lens:] 266 | else: 267 | audio_prompts = original_audio_prompts 268 | text_prompts = original_text_prompts 269 | # Decode with Vocos 270 | frames = complete_tokens.permute(1,0,2) 271 | features = vocos.codes_to_features(frames) 272 | samples = vocos.decode(features, bandwidth_id=torch.tensor([2], device=device)) 273 | return samples.squeeze().cpu().numpy() 274 | else: 275 | raise ValueError(f"No such mode {mode}") -------------------------------------------------------------------------------- /utils/prompt_making.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torchaudio 4 | import logging 5 | import langid 6 | import whisper 7 | langid.set_languages(['en', 'zh', 'ja']) 8 | 9 | import numpy as np 10 | from data.tokenizer import ( 11 | AudioTokenizer, 12 | tokenize_audio, 13 | ) 14 | from data.collation import get_text_token_collater 15 | from utils.g2p import PhonemeBpeTokenizer 16 | 17 | from macros import * 18 | 19 | text_tokenizer = PhonemeBpeTokenizer(tokenizer_path="./utils/g2p/bpe_69.json") 20 | text_collater = get_text_token_collater() 21 | 22 | device = torch.device("cpu") 23 | if torch.cuda.is_available(): 24 | device = torch.device("cuda", 0) 25 | 26 | codec = AudioTokenizer(device) 27 | 28 | if not os.path.exists("./whisper/"): os.mkdir("./whisper/") 29 | whisper_model = None 30 | 31 | @torch.no_grad() 32 | def transcribe_one(model, audio_path): 33 | # load audio and pad/trim it to fit 30 seconds 34 | audio = whisper.load_audio(audio_path) 35 | audio = whisper.pad_or_trim(audio) 36 | 37 | # make log-Mel spectrogram and move to the same device as the model 38 | mel = whisper.log_mel_spectrogram(audio).to(model.device) 39 | 40 | # detect the spoken language 41 | _, probs = model.detect_language(mel) 42 | print(f"Detected language: {max(probs, key=probs.get)}") 43 | lang = max(probs, key=probs.get) 44 | # decode the audio 45 | options = whisper.DecodingOptions(temperature=1.0, best_of=5, fp16=False if device == torch.device("cpu") else True, sample_len=150) 46 | result = whisper.decode(model, mel, options) 47 | 48 | # print the recognized text 49 | print(result.text) 50 | 51 | text_pr = result.text 52 | if text_pr.strip(" ")[-1] not in "?!.,。,?!。、": 53 | text_pr += "." 
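    # (The "." above is only appended when the Whisper transcript lacks terminal punctuation, Western or CJK.)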
54 | return lang, text_pr 55 | 56 | def make_prompt(name, audio_prompt_path, transcript=None): 57 | global model, text_collater, text_tokenizer, codec 58 | wav_pr, sr = torchaudio.load(audio_prompt_path) 59 | # check length 60 | if wav_pr.size(-1) / sr > 15: 61 | raise ValueError(f"Prompt too long, expect length below 15 seconds, got {wav_pr / sr} seconds.") 62 | if wav_pr.size(0) == 2: 63 | wav_pr = wav_pr.mean(0, keepdim=True) 64 | text_pr, lang_pr = make_transcript(name, wav_pr, sr, transcript) 65 | 66 | # tokenize audio 67 | encoded_frames = tokenize_audio(codec, (wav_pr, sr)) 68 | audio_tokens = encoded_frames[0][0].transpose(2, 1).cpu().numpy() 69 | 70 | # tokenize text 71 | phonemes, langs = text_tokenizer.tokenize(text=f"{text_pr}".strip()) 72 | text_tokens, enroll_x_lens = text_collater( 73 | [ 74 | phonemes 75 | ] 76 | ) 77 | 78 | message = f"Detected language: {lang_pr}\n Detected text {text_pr}\n" 79 | 80 | # save as npz file 81 | save_path = os.path.join("./customs/", f"{name}.npz") 82 | np.savez(save_path, audio_tokens=audio_tokens, text_tokens=text_tokens, lang_code=lang2code[lang_pr]) 83 | logging.info(f"Successful. Prompt saved to {save_path}") 84 | 85 | 86 | def make_transcript(name, wav, sr, transcript=None): 87 | 88 | if not isinstance(wav, torch.FloatTensor): 89 | wav = torch.tensor(wav) 90 | if wav.abs().max() > 1: 91 | wav /= wav.abs().max() 92 | if wav.size(-1) == 2: 93 | wav = wav.mean(-1, keepdim=False) 94 | if wav.ndim == 1: 95 | wav = wav.unsqueeze(0) 96 | assert wav.ndim and wav.size(0) == 1 97 | if transcript is None or transcript == "": 98 | logging.info("Transcript not given, using Whisper...") 99 | global whisper_model 100 | if whisper_model is None: 101 | whisper_model = whisper.load_model("medium", download_root=os.path.join(os.getcwd(), "whisper")) 102 | whisper_model.to(device) 103 | torchaudio.save(f"./prompts/{name}.wav", wav, sr) 104 | lang, text = transcribe_one(whisper_model, f"./prompts/{name}.wav") 105 | lang_token = lang2token[lang] 106 | text = lang_token + text + lang_token 107 | os.remove(f"./prompts/{name}.wav") 108 | whisper_model.cpu() 109 | else: 110 | text = transcript 111 | lang, _ = langid.classify(text) 112 | lang_token = lang2token[lang] 113 | text = lang_token + text + lang_token 114 | 115 | torch.cuda.empty_cache() 116 | return text, lang -------------------------------------------------------------------------------- /utils/sentence_cutter.py: -------------------------------------------------------------------------------- 1 | import nltk 2 | import jieba 3 | import sudachipy 4 | import langid 5 | langid.set_languages(['en', 'zh', 'ja']) 6 | 7 | def split_text_into_sentences(text): 8 | if langid.classify(text)[0] == "en": 9 | sentences = nltk.tokenize.sent_tokenize(text) 10 | 11 | return sentences 12 | elif langid.classify(text)[0] == "zh": 13 | sentences = [] 14 | segs = jieba.cut(text, cut_all=False) 15 | segs = list(segs) 16 | start = 0 17 | for i, seg in enumerate(segs): 18 | if seg in ["。", "!", "?", "……"]: 19 | sentences.append("".join(segs[start:i + 1])) 20 | start = i + 1 21 | if start < len(segs): 22 | sentences.append("".join(segs[start:])) 23 | 24 | return sentences 25 | elif langid.classify(text)[0] == "ja": 26 | sentences = [] 27 | tokenizer = sudachipy.Dictionary().create() 28 | tokens = tokenizer.tokenize(text) 29 | current_sentence = "" 30 | 31 | for token in tokens: 32 | current_sentence += token.surface() 33 | if token.part_of_speech()[0] == "補助記号" and token.part_of_speech()[1] == "句点": 34 | 
sentences.append(current_sentence) 35 | current_sentence = "" 36 | 37 | if current_sentence: 38 | sentences.append(current_sentence) 39 | 40 | return sentences 41 | 42 | raise RuntimeError("It is impossible to reach here.") 43 | 44 | long_text = """ 45 | This is a very long paragraph, so most TTS model is unable to handle it. Hence, we have to split it into several sentences. With the help of NLTK, we can split it into sentences. However, the punctuation is not preserved, so we have to add it back. How are we going to do write this code? Let's see. 46 | """ 47 | 48 | long_text = """ 49 | 现在我们要来尝试一下中文分句。因为很不幸的是,NLTK不支持中文分句。幸运的是,我们可以使用jieba来分句。但是,jieba分句后,标点符号会丢失,所以我们要手动添加回去。我现在正在想办法把这个例句写的更长更复杂一点,来测试jieba分句的性能。嗯......省略号,感觉不太好,因为省略号不是句号,所以jieba不会把它当作句子的结尾。会这样吗?我们来试试看。 50 | """ 51 | 52 | long_text = """ 53 | これなら、英語と中国語の分句もできる。でも、日本語はどうする?まつわ、ChatGPTに僕と教えてください。ちょーと待ってください。あ、出来た! 54 | """ -------------------------------------------------------------------------------- /utils/symbol_table.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Mobvoi Inc. (authors: Fangjun Kuang) 2 | # 3 | # See ../../../LICENSE for clarification regarding multiple authors 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | from dataclasses import dataclass 18 | from dataclasses import field 19 | from typing import Dict 20 | from typing import Generic 21 | from typing import List 22 | from typing import Optional 23 | from typing import TypeVar 24 | from typing import Union 25 | 26 | Symbol = TypeVar('Symbol') 27 | 28 | 29 | # Disable __repr__ otherwise it could freeze e.g. Jupyter. 30 | @dataclass(repr=False) 31 | class SymbolTable(Generic[Symbol]): 32 | '''SymbolTable that maps symbol IDs, found on the FSA arcs to 33 | actual objects. These objects can be arbitrary Python objects 34 | that can serve as keys in a dictionary (i.e. they need to be 35 | hashable and immutable). 36 | 37 | The SymbolTable can only be read to/written from disk if the 38 | symbols are strings. 39 | ''' 40 | _id2sym: Dict[int, Symbol] = field(default_factory=dict) 41 | '''Map an integer to a symbol. 42 | ''' 43 | 44 | _sym2id: Dict[Symbol, int] = field(default_factory=dict) 45 | '''Map a symbol to an integer. 46 | ''' 47 | 48 | _next_available_id: int = 1 49 | '''A helper internal field that helps adding new symbols 50 | to the table efficiently. 51 | ''' 52 | 53 | eps: Symbol = '' 54 | '''Null symbol, always mapped to index 0. 
55 | ''' 56 | 57 | def __post_init__(self): 58 | for idx, sym in self._id2sym.items(): 59 | assert self._sym2id[sym] == idx 60 | assert idx >= 0 61 | 62 | for sym, idx in self._sym2id.items(): 63 | assert idx >= 0 64 | assert self._id2sym[idx] == sym 65 | 66 | if 0 not in self._id2sym: 67 | self._id2sym[0] = self.eps 68 | self._sym2id[self.eps] = 0 69 | else: 70 | assert self._id2sym[0] == self.eps 71 | assert self._sym2id[self.eps] == 0 72 | 73 | self._next_available_id = max(self._id2sym) + 1 74 | 75 | @staticmethod 76 | def from_str(s: str) -> 'SymbolTable': 77 | '''Build a symbol table from a string. 78 | 79 | The string consists of lines. Every line has two fields separated 80 | by space(s), tab(s) or both. The first field is the symbol and the 81 | second the integer id of the symbol. 82 | 83 | Args: 84 | s: 85 | The input string with the format described above. 86 | Returns: 87 | An instance of :class:`SymbolTable`. 88 | ''' 89 | id2sym: Dict[int, str] = dict() 90 | sym2id: Dict[str, int] = dict() 91 | 92 | for line in s.split('\n'): 93 | fields = line.split() 94 | if len(fields) == 0: 95 | continue # skip empty lines 96 | assert len(fields) == 2, \ 97 | f'Expect a line with 2 fields. Given: {len(fields)}' 98 | sym, idx = fields[0], int(fields[1]) 99 | assert sym not in sym2id, f'Duplicated symbol {sym}' 100 | assert idx not in id2sym, f'Duplicated id {idx}' 101 | id2sym[idx] = sym 102 | sym2id[sym] = idx 103 | 104 | eps = id2sym.get(0, '') 105 | 106 | return SymbolTable(_id2sym=id2sym, _sym2id=sym2id, eps=eps) 107 | 108 | @staticmethod 109 | def from_file(filename: str) -> 'SymbolTable': 110 | '''Build a symbol table from file. 111 | 112 | Every line in the symbol table file has two fields separated by 113 | space(s), tab(s) or both. The following is an example file: 114 | 115 | .. code-block:: 116 | 117 | 0 118 | a 1 119 | b 2 120 | c 3 121 | 122 | Args: 123 | filename: 124 | Name of the symbol table file. Its format is documented above. 125 | 126 | Returns: 127 | An instance of :class:`SymbolTable`. 128 | 129 | ''' 130 | with open(filename, 'r', encoding='utf-8') as f: 131 | return SymbolTable.from_str(f.read().strip()) 132 | 133 | def to_str(self) -> str: 134 | ''' 135 | Returns: 136 | Return a string representation of this object. You can pass 137 | it to the method ``from_str`` to recreate an identical object. 138 | ''' 139 | s = '' 140 | for idx, symbol in sorted(self._id2sym.items()): 141 | s += f'{symbol} {idx}\n' 142 | return s 143 | 144 | def to_file(self, filename: str): 145 | '''Serialize the SymbolTable to a file. 146 | 147 | Every line in the symbol table file has two fields separated by 148 | space(s), tab(s) or both. The following is an example file: 149 | 150 | .. code-block:: 151 | 152 | 0 153 | a 1 154 | b 2 155 | c 3 156 | 157 | Args: 158 | filename: 159 | Name of the symbol table file. Its format is documented above. 160 | ''' 161 | with open(filename, 'w', encoding='utf-8') as f: 162 | for idx, symbol in sorted(self._id2sym.items()): 163 | print(symbol, idx, file=f) 164 | 165 | def add(self, symbol: Symbol, index: Optional[int] = None) -> int: 166 | '''Add a new symbol to the SymbolTable. 167 | 168 | Args: 169 | symbol: 170 | The symbol to be added. 171 | index: 172 | Optional int id to which the symbol should be assigned. 173 | If it is not available, a ValueError will be raised. 174 | 175 | Returns: 176 | The int id to which the symbol has been assigned. 177 | ''' 178 | # Already in the table? Return its ID. 
179 | if symbol in self._sym2id: 180 | return self._sym2id[symbol] 181 | # Specific ID not provided - use next available. 182 | if index is None: 183 | index = self._next_available_id 184 | # Specific ID provided but not available. 185 | if index in self._id2sym: 186 | raise ValueError(f"Cannot assign id '{index}' to '{symbol}' - " 187 | f"already occupied by {self._id2sym[index]}") 188 | self._sym2id[symbol] = index 189 | self._id2sym[index] = symbol 190 | 191 | # Update next available ID if needed 192 | if self._next_available_id <= index: 193 | self._next_available_id = index + 1 194 | 195 | return index 196 | 197 | def get(self, k: Union[int, Symbol]) -> Union[Symbol, int]: 198 | '''Get a symbol for an id or get an id for a symbol 199 | 200 | Args: 201 | k: 202 | If it is an id, it tries to find the symbol corresponding 203 | to the id; if it is a symbol, it tries to find the id 204 | corresponding to the symbol. 205 | 206 | Returns: 207 | An id or a symbol depending on the given `k`. 208 | ''' 209 | if isinstance(k, int): 210 | return self._id2sym[k] 211 | else: 212 | return self._sym2id[k] 213 | 214 | def merge(self, other: 'SymbolTable') -> 'SymbolTable': 215 | '''Create a union of two SymbolTables. 216 | Raises an AssertionError if the same IDs are occupied by 217 | different symbols. 218 | 219 | Args: 220 | other: 221 | A symbol table to merge with ``self``. 222 | 223 | Returns: 224 | A new symbol table. 225 | ''' 226 | self._check_compatible(other) 227 | 228 | id2sym = {**self._id2sym, **other._id2sym} 229 | sym2id = {**self._sym2id, **other._sym2id} 230 | 231 | return SymbolTable(_id2sym=id2sym, _sym2id=sym2id, eps=self.eps) 232 | 233 | def _check_compatible(self, other: 'SymbolTable') -> None: 234 | # Epsilon compatibility 235 | assert self.eps == other.eps, f'Mismatched epsilon symbol: ' \ 236 | f'{self.eps} != {other.eps}' 237 | # IDs compatibility 238 | common_ids = set(self._id2sym).intersection(other._id2sym) 239 | for idx in common_ids: 240 | assert self[idx] == other[idx], f'ID conflict for id: {idx}, ' \ 241 | f'self[idx] = "{self[idx]}", ' \ 242 | f'other[idx] = "{other[idx]}"' 243 | # Symbols compatibility 244 | common_symbols = set(self._sym2id).intersection(other._sym2id) 245 | for sym in common_symbols: 246 | assert self[sym] == other[sym], f'ID conflict for id: {sym}, ' \ 247 | f'self[sym] = "{self[sym]}", ' \ 248 | f'other[sym] = "{other[sym]}"' 249 | 250 | def __getitem__(self, item: Union[int, Symbol]) -> Union[Symbol, int]: 251 | return self.get(item) 252 | 253 | def __contains__(self, item: Union[int, Symbol]) -> bool: 254 | if isinstance(item, int): 255 | return item in self._id2sym 256 | else: 257 | return item in self._sym2id 258 | 259 | def __len__(self) -> int: 260 | return len(self._id2sym) 261 | 262 | def __eq__(self, other: 'SymbolTable') -> bool: 263 | if len(self) != len(other): 264 | return False 265 | 266 | for s in self.symbols: 267 | if self[s] != other[s]: 268 | return False 269 | 270 | return True 271 | 272 | @property 273 | def ids(self) -> List[int]: 274 | '''Returns a list of integer IDs corresponding to the symbols. 275 | ''' 276 | ans = list(self._id2sym.keys()) 277 | ans.sort() 278 | return ans 279 | 280 | @property 281 | def symbols(self) -> List[Symbol]: 282 | '''Returns a list of symbols (e.g., strings) corresponding to 283 | the integer IDs. 284 | ''' 285 | ans = list(self._sym2id.keys()) 286 | ans.sort() 287 | return ans 288 | --------------------------------------------------------------------------------
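A minimal usage sketch for the SymbolTable class above, assuming the repository root is on PYTHONPATH so the module is importable as utils.symbol_table (illustrative only, not part of the project):

    from utils.symbol_table import SymbolTable

    # Build a table from "symbol id" lines; the null symbol '' is auto-assigned id 0.
    table = SymbolTable.from_str('a 1\nb 2\nc 3')
    assert table['a'] == 1        # symbol -> id
    assert table[2] == 'b'        # id -> symbol
    assert table[''] == 0         # epsilon / null symbol

    # add() returns the assigned id: the next free one, or an explicit index.
    assert table.add('d') == 4
    table.add('e', index=10)

    # merge() unions two tables and asserts that shared ids/symbols agree.
    merged = table.merge(SymbolTable.from_str('a 1\nf 5'))
    assert merged['f'] == 5 and len(merged) == 7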
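And a minimal end-to-end sketch of how utils/prompt_making.py and utils/generation.py are meant to be combined. It assumes the working directory is the repository root (so the relative ./checkpoints, ./customs, ./presets and ./prompts paths resolve), that the model dependencies are installed, and that the voice name and text used here are placeholders:

    from utils.prompt_making import make_prompt
    from utils.generation import preload_models, generate_audio, generate_audio_from_long_text

    # Downloads vallex-checkpoint.pt on first use, then loads VALL-E X, EnCodec and Vocos.
    preload_models()

    # Register a custom voice from a clip shorter than 15 seconds; the tokens are saved
    # to ./customs/my_voice.npz. When transcript is omitted, Whisper transcribes the clip.
    make_prompt(name="my_voice", audio_prompt_path="./prompts/en-1.wav")

    # Short text: a single call. language='auto' picks en/zh/ja with langid.
    wav = generate_audio("Hello, this is a cloned voice speaking.", prompt="my_voice")

    # Long text: sentence-by-sentence synthesis, mode='fixed-prompt' or 'sliding-window'.
    long_text = ("This paragraph is long enough that it is split into sentences and "
                 "synthesized one by one, as described in the docstring above.")
    long_wav = generate_audio_from_long_text(long_text, prompt="my_voice", mode="fixed-prompt")

    # Both calls return a NumPy float array at 24 kHz (per the charactr/vocos-encodec-24khz vocoder).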