├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── app.py
├── attentions.py
├── bert
│   ├── ProsodyModel.py
│   ├── __init__.py
│   ├── config.json
│   ├── prosody_tool.py
│   └── vocab.txt
├── commons.py
├── configs
│   ├── bert_vits.json
│   └── bert_vits_student.json
├── data
│   └── 000001-010000.txt
├── data_utils.py
├── filelists
│   ├── all.txt
│   ├── train.txt
│   └── valid.txt
├── losses.py
├── mel_processing.py
├── model_onnx.py
├── model_onnx_stream.py
├── models.py
├── modules.py
├── monotonic_align
│   ├── __init__.py
│   ├── core.pyx
│   └── setup.py
├── requirements.txt
├── text
│   ├── __init__.py
│   ├── pinyin-local.txt
│   └── symbols.py
├── train.py
├── transforms.py
├── utils.py
├── vits_infer.py
├── vits_infer_item.txt
├── vits_infer_no_bert.py
├── vits_infer_onnx.py
├── vits_infer_onnx_stream.py
├── vits_infer_out
│   ├── bert_vits_1.wav
│   ├── bert_vits_2.wav
│   └── bert_vits_3.wav
├── vits_infer_pause.py
├── vits_infer_stream.py
├── vits_pinyin.py
├── vits_prepare.py
└── vits_resample.py
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Contributor Covenant Code of Conduct
2 |
3 | ## Our Pledge
4 |
5 | We as members, contributors, and leaders pledge to make participation in our
6 | community a harassment-free experience for everyone, regardless of age, body
7 | size, visible or invisible disability, ethnicity, sex characteristics, gender
8 | identity and expression, level of experience, education, socio-economic status,
9 | nationality, personal appearance, race, caste, color, religion, or sexual
10 | identity and orientation.
11 |
12 | We pledge to act and interact in ways that contribute to an open, welcoming,
13 | diverse, inclusive, and healthy community.
14 |
15 | ## Our Standards
16 |
17 | Examples of behavior that contributes to a positive environment for our
18 | community include:
19 |
20 | * Demonstrating empathy and kindness toward other people
21 | * Being respectful of differing opinions, viewpoints, and experiences
22 | * Giving and gracefully accepting constructive feedback
23 | * Accepting responsibility and apologizing to those affected by our mistakes,
24 | and learning from the experience
25 | * Focusing on what is best not just for us as individuals, but for the overall
26 | community
27 |
28 | Examples of unacceptable behavior include:
29 |
30 | * The use of sexualized language or imagery, and sexual attention or advances of
31 | any kind
32 | * Trolling, insulting or derogatory comments, and personal or political attacks
33 | * Public or private harassment
34 | * Publishing others' private information, such as a physical or email address,
35 | without their explicit permission
36 | * Other conduct which could reasonably be considered inappropriate in a
37 | professional setting
38 |
39 | ## Enforcement Responsibilities
40 |
41 | Community leaders are responsible for clarifying and enforcing our standards of
42 | acceptable behavior and will take appropriate and fair corrective action in
43 | response to any behavior that they deem inappropriate, threatening, offensive,
44 | or harmful.
45 |
46 | Community leaders have the right and responsibility to remove, edit, or reject
47 | comments, commits, code, wiki edits, issues, and other contributions that are
48 | not aligned to this Code of Conduct, and will communicate reasons for moderation
49 | decisions when appropriate.
50 |
51 | ## Scope
52 |
53 | This Code of Conduct applies within all community spaces, and also applies when
54 | an individual is officially representing the community in public spaces.
55 | Examples of representing our community include using an official e-mail address,
56 | posting via an official social media account, or acting as an appointed
57 | representative at an online or offline event.
58 |
59 | ## Enforcement
60 |
61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be
62 | reported to the community leaders responsible for enforcement.
63 | All complaints will be reviewed and investigated promptly and fairly.
64 |
65 | All community leaders are obligated to respect the privacy and security of the
66 | reporter of any incident.
67 |
68 | ## Enforcement Guidelines
69 |
70 | Community leaders will follow these Community Impact Guidelines in determining
71 | the consequences for any action they deem in violation of this Code of Conduct:
72 |
73 | ### 1. Correction
74 |
75 | **Community Impact**: Use of inappropriate language or other behavior deemed
76 | unprofessional or unwelcome in the community.
77 |
78 | **Consequence**: A private, written warning from community leaders, providing
79 | clarity around the nature of the violation and an explanation of why the
80 | behavior was inappropriate. A public apology may be requested.
81 |
82 | ### 2. Warning
83 |
84 | **Community Impact**: A violation through a single incident or series of
85 | actions.
86 |
87 | **Consequence**: A warning with consequences for continued behavior. No
88 | interaction with the people involved, including unsolicited interaction with
89 | those enforcing the Code of Conduct, for a specified period of time. This
90 | includes avoiding interactions in community spaces as well as external channels
91 | like social media. Violating these terms may lead to a temporary or permanent
92 | ban.
93 |
94 | ### 3. Temporary Ban
95 |
96 | **Community Impact**: A serious violation of community standards, including
97 | sustained inappropriate behavior.
98 |
99 | **Consequence**: A temporary ban from any sort of interaction or public
100 | communication with the community for a specified period of time. No public or
101 | private interaction with the people involved, including unsolicited interaction
102 | with those enforcing the Code of Conduct, is allowed during this period.
103 | Violating these terms may lead to a permanent ban.
104 |
105 | ### 4. Permanent Ban
106 |
107 | **Community Impact**: Demonstrating a pattern of violation of community
108 | standards, including sustained inappropriate behavior, harassment of an
109 | individual, or aggression toward or disparagement of classes of individuals.
110 |
111 | **Consequence**: A permanent ban from any sort of public interaction within the
112 | community.
113 |
114 | ## Attribution
115 |
116 | This Code of Conduct is adapted from the [Contributor Covenant][homepage],
117 | version 2.1, available at
118 | [https://www.contributor-covenant.org/version/2/1/code_of_conduct.html][v2.1].
119 |
120 | Community Impact Guidelines were inspired by
121 | [Mozilla's code of conduct enforcement ladder][Mozilla CoC].
122 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | First off, thanks for taking the time to contribute!👏
4 |
5 | ## Fork the Repository 🍴
6 |
7 | 1. Start by forking the repository. You can do this by clicking the "Fork" button in the upper right corner of the repository page. This will create a copy of the repository in your GitHub account.
8 |
9 | ## Clone Your Fork 📥
10 |
11 | 2. Clone your newly created fork of the repository to your local machine with the following command:
12 |
13 | ```bash
14 | git clone https://github.com/your-username/vits_chinese.git
15 | ```
16 |
17 | ## Create a New Branch 🌿
18 |
19 | 3. Create a new branch for the specific issue or feature you are working on. Use a descriptive branch name:
20 |
21 | ```bash
22 | git checkout -b "branch_name"
23 | ```
24 |
25 | ## Submitting Changes 🚀
26 |
27 | 4. Make your desired changes to the codebase.
28 |
29 | 5. Stage your changes using the following command:
30 |
31 | ```bash
32 | git add .
33 | ```
34 |
35 | 6. Commit your changes with a clear and concise commit message:
36 |
37 | ```bash
38 | git commit -m "A brief summary of the commit."
39 | ```
40 |
41 | ## Push Your Changes 🚢
42 |
43 | 7. Push your local commits to your remote repository:
44 |
45 | ```bash
46 | git push origin "branch_name"
47 | ```
48 |
49 | ## Create a Pull Request 🌟
50 |
51 | 8. Go to your forked repository on GitHub and click on the "New Pull Request" button. This will open a new pull request to the original repository.
52 |
53 | ## Coding Style 📝
54 |
55 | Start reading the code, and you'll get the hang of it. It is optimized for readability:
56 |
57 | - Variables must be uppercase and should begin with MY\_.
58 | - Functions must be lowercase.
59 | - Check your shell scripts with ShellCheck before submitting.
60 | - Please use tabs to indent.
61 |
62 | One more thing:
63 |
64 | Keep it simple! 👍
65 |
66 | Thanks! ❤️❤️❤️
67 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 PlayVoice
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Best practice TTS based on BERT and VITS, with some Natural Speech features of Microsoft
2 |
3 | [](https://huggingface.co/spaces/maxmax20160403/vits_chinese)
4 |
5 |
6 |
7 |
8 |
9 | ## This is a project for learning TTS algorithms. If you are looking for a production-ready TTS, this project may not be a good fit!
10 | https://user-images.githubusercontent.com/16432329/220678182-4775dec8-9229-4578-870f-2eebc3a5d660.mp4
11 |
12 | > 天空呈现的透心的蓝,像极了当年。总在这样的时候,透过窗棂,心,在天空里无尽的游弋!柔柔的,浓浓的,痴痴的风,牵引起心底灵动的思潮;情愫悠悠,思情绵绵,风里默坐,红尘中的浅醉,诗词中的优柔,任那自在飞花轻似梦的情怀,裁一束霓衣,织就清浅淡薄的安寂。
13 | >
14 | > 风的影子翻阅过淡蓝色的信笺,柔和的文字浅浅地漫过我安静的眸,一如几朵悠闲的云儿,忽而氤氲成汽,忽而修饰成花,铅华洗尽后的透彻和靓丽,爽爽朗朗,轻轻盈盈
15 | >
16 | > 时光仿佛有穿越到了从前,在你诗情画意的眼波中,在你舒适浪漫的暇思里,我如风中的思绪徜徉广阔天际,仿佛一片沾染了快乐的羽毛,在云环影绕颤动里浸润着风的呼吸,风的诗韵,那清新的耳语,那婉约的甜蜜,那恬淡的温馨,将一腔情澜染得愈发的缠绵。
17 |
18 | ### Features
19 | 1. Hidden prosody embedding from **BERT**: natural pauses that follow the grammar
20 |
21 | 2. Infer loss from **NaturalSpeech**: fewer pronunciation errors
22 |
23 | 3. Framework of **VITS**: high audio quality
24 |
25 | 4. Module-wise distillation: faster inference
26 |
27 | :heartpulse:**Tip**: It is recommended to fine-tune with the **Infer Loss** after the base model is trained, and to freeze the **PosteriorEncoder** during fine-tuning.
28 |
29 | :heartpulse:**In other words: during initial training, do not use loss_kl_r; once the base model is trained, add loss_kl_r and continue training for a short while. If the audio quality drops, multiply loss_kl_r by a coefficient smaller than 1 to reduce its contribution to the model. During the continued training you can also try freezing the audio encoder (Posterior Encoder). In short, there are many ways to play with this, so experiment!**
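A minimal sketch of that recipe, purely illustrative: the `enc_q` attribute for the posterior encoder follows the upstream VITS naming, and the loss names are placeholders rather than the actual variables in train.py.

``` python
# Hypothetical fine-tuning helpers; names are assumptions, not this repo's train.py.

def freeze_posterior_encoder(net_g):
    # Upstream VITS keeps the posterior encoder in `enc_q`.
    for p in net_g.enc_q.parameters():
        p.requires_grad = False

def total_generator_loss(losses, use_infer_loss=False, kl_r_weight=1.0):
    # Base training: leave loss_kl_r out.  Fine-tuning: add it back,
    # optionally scaled by a factor < 1 if audio quality degrades.
    total = losses["gen"] + losses["fm"] + losses["mel"] + losses["dur"] + losses["kl"]
    if use_infer_loss:
        total = total + kl_r_weight * losses["kl_r"]
    return total
```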
30 |
31 |
32 |
33 | 
34 |
35 |
36 |
37 | ### Why not upgrade to VITS2
38 | The most important change in VITS2 is replacing the WaveNet module of the Flow with a Transformer; in a streaming TTS implementation, however, the Transformer usually has to be replaced with a pure CNN.
39 |
40 | ### Online demo
41 | https://huggingface.co/spaces/maxmax20160403/vits_chinese
42 |
43 | ### Install dependencies and MAS alignment
44 |
45 | > pip install -r requirements.txt
46 |
47 | > cd monotonic_align
48 |
49 | > python setup.py build_ext --inplace
50 |
51 | ### Infer with the pretrained model
52 |
53 | Get the models from the release page [vits_chinese/releases/](https://github.com/PlayVoice/vits_chinese/releases/tag/v1.0)
54 |
55 | Put [prosody_model.pt](https://github.com/PlayVoice/vits_chinese/releases/tag/v1.0) at ./bert/prosody_model.pt
56 |
57 | Put [vits_bert_model.pth](https://github.com/PlayVoice/vits_chinese/releases/tag/v1.0) at ./vits_bert_model.pth
58 |
59 | ```
60 | python vits_infer.py --config ./configs/bert_vits.json --model vits_bert_model.pth
61 | ```
62 |
63 | The synthesized waves are saved in ./vits_infer_out; have a listen!
64 |
65 | ### Infer with chunked wave streaming output
66 |
67 | The key parameter is ***hop_frame = ∑ decoder.ups.padding*** :heartpulse: (a sketch follows the command below)
68 |
69 | ```
70 | python vits_infer_stream.py --config ./configs/bert_vits.json --model vits_bert_model.pth
71 | ```
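A minimal sketch of how hop_frame could be read off a loaded model, assuming the generator exposes its transposed-convolution upsampling layers as `dec.ups` (as in the upstream VITS code); the attribute names are an assumption, not guaranteed by this repo:

``` python
# Sketch only: sum the padding of every upsampling layer of the HiFi-GAN-style decoder.
import utils
from models import SynthesizerTrn
from text.symbols import symbols

hps = utils.get_hparams_from_file("./configs/bert_vits.json")
net_g = SynthesizerTrn(
    len(symbols),
    hps.data.filter_length // 2 + 1,
    hps.train.segment_size // hps.data.hop_length,
    **hps.model)
utils.load_model("vits_bert_model.pth", net_g)

# `net_g.dec.ups` holding ConvTranspose1d layers is assumed from upstream VITS.
hop_frame = sum(up.padding[0] for up in net_g.dec.ups)
print("hop_frame =", hop_frame)
```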
72 |
73 | ### Ceiling the duration affects naturalness
74 | So change **w_ceil = torch.ceil(w)** to **w_ceil = torch.ceil(w + 0.35)**, as sketched below.
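A sketch of the step being changed, paraphrased from the upstream VITS `infer()` path (the exact lines in models.py may differ slightly):

``` python
import torch

def duration_to_frames(logw, x_mask, length_scale=1.0, offset=0.35):
    # logw: predicted log-durations, x_mask: text mask, both from the text encoder.
    # Adding a small offset before ceil() lengthens each phone slightly,
    # which tends to sound more natural than a plain ceil.
    w = torch.exp(logw) * x_mask * length_scale
    return torch.ceil(w + offset)
```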
75 |
76 | ### All Thanks To Our Contributors:
77 |
78 |
79 |
80 |
81 | ### Train
82 | Download the Baker dataset: [https://aistudio.baidu.com/datasetdetail/36741](https://aistudio.baidu.com/datasetdetail/36741); more info: https://www.data-baker.com/data/index/TNtts/
83 |
84 | Change the sample rate of the waves to **16kHz** and put them in ./data/waves
85 |
86 | ```
87 | python vits_resample.py -w [input path]:[./data/Wave/] -o ./data/waves -s 16000
88 | ```
89 |
90 | Put 000001-010000.txt at ./data/000001-010000.txt
91 |
92 | ```
93 | python vits_prepare.py -c ./configs/bert_vits.json
94 | ```
95 |
96 | ```
97 | python train.py -c configs/bert_vits.json -m bert_vits
98 | ```
99 |
100 | 
101 |
102 | ### Additional notes
103 |
104 | The original annotation format is:
105 | ``` c
106 | 000001 卡尔普#2陪外孙#1玩滑梯#4。
107 | ka2 er2 pu3 pei2 wai4 sun1 wan2 hua2 ti1
108 | 000002 假语村言#2别再#1拥抱我#4。
109 | jia2 yu3 cun1 yan2 bie2 zai4 yong1 bao4 wo3
110 | ```
111 |
112 | After normalizing the annotation:
113 | - BERT needs the Chinese characters `卡尔普陪外孙玩滑梯。` (punctuation included)
114 | - TTS needs the initials and finals `sil k a2 ^ er2 p u3 p ei2 ^ uai4 s uen1 ^ uan2 h ua2 t i1 sp sil` (a conversion sketch follows the example below)
115 | ``` c
116 | 000001 卡尔普陪外孙玩滑梯。
117 | ka2 er2 pu3 pei2 wai4 sun1 wan2 hua2 ti1
118 | sil k a2 ^ er2 p u3 p ei2 ^ uai4 s uen1 ^ uan2 h ua2 t i1 sp sil
119 | 000002 假语村言别再拥抱我。
120 | jia2 yu3 cun1 yan2 bie2 zai4 yong1 bao4 wo3
121 | sil j ia2 ^ v3 c uen1 ^ ian2 b ie2 z ai4 ^ iong1 b ao4 ^ uo3 sp sil
122 | ```
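The mapping from numbered pinyin to initials and finals follows `pinyin_dict` in bert/prosody_tool.py. A toy sketch of the conversion (the real pipeline in vits_pinyin.py / vits_prepare.py handles punctuation, pauses and special cases):

``` python
from bert.prosody_tool import pinyin_dict

def pinyin_to_phones(pinyin_line):
    # "ka2 er2 pu3 ..." -> "sil k a2 ^ er2 p u3 ... sp sil"
    phones = ["sil"]
    for syllable in pinyin_line.split():
        base, tone = syllable[:-1], syllable[-1]   # e.g. "ka", "2"
        initial, final = pinyin_dict[base]         # e.g. ("k", "a")
        phones += [initial, final + tone]
    phones += ["sp", "sil"]
    return " ".join(phones)

print(pinyin_to_phones("ka2 er2 pu3 pei2 wai4 sun1 wan2 hua2 ti1"))
# sil k a2 ^ er2 p u3 p ei2 ^ uai4 s uen1 ^ uan2 h ua2 t i1 sp sil
```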
123 |
124 | The training annotations then look like:
125 | ```
126 | ./data/wavs/000001.wav|./data/temps/000001.spec.pt|./data/berts/000001.npy|sil k a2 ^ er2 p u3 p ei2 ^ uai4 s uen1 ^ uan2 h ua2 t i1 sp sil
127 | ./data/wavs/000002.wav|./data/temps/000002.spec.pt|./data/berts/000002.npy|sil j ia2 ^ v3 c uen1 ^ ian2 b ie2 z ai4 ^ iong1 b ao4 ^ uo3 sp sil
128 | ```
129 |
130 | This sentence will cause an error:
131 | ```
132 | 002365 这图#2难不成#2是#1P过的#4?
133 | zhe4 tu2 nan2 bu4 cheng2 shi4 P IY1 guo4 de5
134 | ```
135 |
136 | ### Fixing pinyin errors
137 | Write the correct word and its pinyin into [./text/pinyin-local.txt](./text/pinyin-local.txt):
138 | ```
139 | 渐渐 jian4 jian4
140 | 浅浅 qian3 qian3
141 | ```
142 |
143 | ### Number verbalization support
144 | Supported, based on [WeTextProcessing](https://github.com/wenet-e2e/WeTextProcessing) from the WeNet open-source community; of course, it cannot be perfect. A usage sketch is shown below.
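For reference, WeTextProcessing is typically used like this (a hedged sketch; check the WeTextProcessing README for the current import path and options):

``` python
from tn.chinese.normalizer import Normalizer  # provided by the WeTextProcessing package

normalizer = Normalizer()
# Digits, dates, amounts, etc. are rewritten as Chinese words before phonemization.
print(normalizer.normalize("2023年5月23日,我付了100.5元"))
```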
145 |
146 | ### Inference without BERT
147 | ```
148 | python vits_infer_no_bert.py --config ./configs/bert_vits.json --model vits_bert_model.pth
149 | ```
150 | Although BERT is used during training, inference can run entirely without it, trading the natural pauses for compatibility with low-compute devices such as phones.
151 |
152 | Low-resource devices usually synthesize sentence by sentence, so the sacrificed natural pauses are not that noticeable.
153 |
154 | ### ONNX non-streaming
155 | Export (there will be many warnings; just ignore them):
156 | ```
157 | python model_onnx.py --config configs/bert_vits.json --model vits_bert_model.pth
158 | ```
159 | Inference:
160 | ```
161 | python vits_infer_onnx.py --model vits-chinese.onnx
162 | ```
163 |
164 | ### ONNX streaming
165 |
166 | Concretely, VITS is split into two models, named Encoder and Decoder.
167 |
168 | - The Encoder contains the TextEncoder, the DurationPredictor, etc.;
169 |
170 | - The Decoder contains the ResidualCouplingBlock, the Generator, etc.;
171 |
172 | - The ResidualCouplingBlock, i.e. the Flow, can be placed in either the Encoder or the Decoder; placing it in the Decoder requires a larger **hop_frame**
173 |
174 | The inference logic is split accordingly; in particular, sampling from the prior distribution is done inside the Encoder:
175 | ```
176 | z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
177 | ```
178 |
179 | Export the streaming ONNX models:
180 | ```
181 | python model_onnx_stream.py --config configs/bert_vits.json --model vits_bert_model.pth
182 | ```
183 |
184 | Run streaming ONNX inference:
185 | ```
186 | python vits_infer_onnx_stream.py --encoder vits-chinese-encoder.onnx --decoder vits-chinese-decoder.onnx
187 | ```
188 |
189 | In streaming inference, **hop_frame** is an important parameter; a suitable value has to be found by experiment. A hypothetical sketch of the underlying overlap-and-trim idea follows.
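This is not the repo's vits_infer_onnx_stream.py; it assumes a single-input decoder whose input name is read from the session, and hop_length = 256 from configs/bert_vits.json:

``` python
import onnxruntime

HOP_LENGTH = 256   # from configs/bert_vits.json
HOP_FRAME = 12     # needs tuning, as noted above
CHUNK = 48         # latent frames decoded per streaming step

decoder = onnxruntime.InferenceSession("vits-chinese-decoder.onnx")
z_name = decoder.get_inputs()[0].name  # assumption: the first input is the z_p chunk

def stream_decode(z_p):
    # z_p: numpy array [1, channels, frames] produced by the encoder ONNX model.
    total = z_p.shape[2]
    for start in range(0, total, CHUNK):
        # decode the chunk plus hop_frame frames of context on each side ...
        left = max(start - HOP_FRAME, 0)
        right = min(start + CHUNK + HOP_FRAME, total)
        audio = decoder.run(None, {z_name: z_p[:, :, left:right]})[0].reshape(-1)
        # ... then drop the samples generated from the extra context frames
        cut_l = (start - left) * HOP_LENGTH
        cut_r = max(right - min(start + CHUNK, total), 0) * HOP_LENGTH
        yield audio[cut_l: len(audio) - cut_r]
```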
190 |
191 | ### Model compression based on knowledge distillation (or should it be called transfer learning?)
192 | The student model is 53M in size and runs at about 3× the speed of the teacher model.
193 |
194 | To train:
195 |
196 | ```
197 | python train.py -c configs/bert_vits_student.json -m bert_vits_student
198 | ```
199 |
200 | To infer, get the [student model](https://github.com/PlayVoice/vits_chinese/releases/tag/v2.0) from the release page:
201 |
202 | ```
203 | python vits_infer.py --config ./configs/bert_vits_student.json --model vits_bert_student.pth
204 | ```
205 |
206 | ### Code sources
207 | [Microsoft's NaturalSpeech: End-to-End Text to Speech Synthesis with Human-Level Quality](https://arxiv.org/abs/2205.04421)
208 |
209 | [Nix-TTS: Lightweight and End-to-End Text-to-Speech via Module-wise Distillation](https://arxiv.org/abs/2203.15643)
210 |
211 | https://github.com/Executedone/Chinese-FastSpeech2 **bert prosody**
212 |
213 | https://github.com/wenet-e2e/WeTextProcessing
214 |
215 | [https://github.com/TensorSpeech/TensorFlowTTS](https://github.com/TensorSpeech/TensorFlowTTS/blob/master/tensorflow_tts/processor/baker.py) **heavily relied upon**
216 |
217 | https://github.com/jaywalnut310/vits
218 |
219 | https://github.com/wenet-e2e/wetts
220 |
221 | https://github.com/csukuangfj **onnx and android**
222 |
223 | ### BERT applied to TTS
224 | 2019 BERT+Tacotron2: Pre-trained text embeddings for enhanced text-to-speech synthesis.
225 |
226 | 2020 BERT+Tacotron2-MultiSpeaker: Improving prosody with linguistic and BERT derived features in multi-speaker based Mandarin Chinese neural TTS.
227 |
228 | 2021 BERT+Tacotron2: Extracting and predicting word-level style variations for speech synthesis.
229 |
230 | 2022 https://github.com/Executedone/Chinese-FastSpeech2
231 |
232 | 2023 BERT+VISINGER: Towards Improving the Expressiveness of Singing Voice Synthesis with BERT Derived Semantic Information.
233 |
234 | # AISHELL3 multi-speaker training (the trained model can be used for voice cloning)
235 | Switch to the code branch [bert_vits_aishell3](https://github.com/PlayVoice/vits_chinese/tree/bert_vits_aishell3); comparing the branches shows **the changes made for multiple speakers**
236 |
237 | ## Data download
238 | http://www.openslr.org/93/
239 |
240 | ## Sample-rate conversion
241 | ```
242 | python prep_resample.py --wav aishell-3/train/wav/ --out vits_data/waves-16k
243 | ```
244 |
245 | ## Label normalization (labels.txt; the file name must not be changed)
246 | ```
247 | python prep_format_label.py --txt aishell-3/train/content.txt --out vits_data/labels.txt
248 | ```
249 |
250 | - Original labels
251 | ```
252 | SSB00050001.wav 广 guang3 州 zhou1 女 nv3 大 da4 学 xue2 生 sheng1 登 deng1 山 shan1 失 shi1 联 lian2 四 si4 天 tian1 警 jing3 方 fang1 找 zhao3 到 dao4 疑 yi2 似 si4 女 nv3 尸 shi1
253 | SSB00050002.wav 尊 zhun1 重 zhong4 科 ke1 学 xue2 规 gui1 律 lv4 的 de5 要 yao1 求 qiu2
254 | SSB00050003.wav 七 qi1 路 lu4 无 wu2 人 ren2 售 shou4 票 piao4
255 | ```
256 | - Normalized labels (a conversion sketch follows the example)
257 | ```
258 | SSB00050001.wav 广州女大学生登山失联四天警方找到疑似女尸
259 | guang3 zhou1 nv3 da4 xue2 sheng1 deng1 shan1 shi1 lian2 si4 tian1 jing3 fang1 zhao3 dao4 yi2 si4 nv3 shi1
260 | SSB00050002.wav 尊重科学规律的要求
261 | zhun1 zhong4 ke1 xue2 gui1 lv4 de5 yao1 qiu2
262 | SSB00050003.wav 七路无人售票
263 | qi1 lu4 wu2 ren2 shou4 piao4
264 | ```
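A toy sketch of what the normalization amounts to for one line of content.txt (illustrative only; prep_format_label.py in the bert_vits_aishell3 branch is the authoritative implementation):

``` python
def format_label(line):
    # "SSB00050003.wav 七 qi1 路 lu4 ..."  ->  two lines: characters, then pinyin
    name, *tokens = line.split()
    hanzi = "".join(tokens[0::2])
    pinyin = " ".join(tokens[1::2])
    return f"{name} {hanzi}\n{pinyin}"

print(format_label("SSB00050003.wav 七 qi1 路 lu4 无 wu2 人 ren2 售 shou4 票 piao4"))
# SSB00050003.wav 七路无人售票
# qi1 lu4 wu2 ren2 shou4 piao4
```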
265 | ## Data preprocessing
266 | ```
267 | python prep_bert.py --conf configs/bert_vits.json --data vits_data/
268 | ```
269 |
270 | The printed messages show the **erhua (rhotacized) syllables**, which this project does not support, being filtered out
271 |
272 | Generates vits_data/speakers.txt
273 | ```
274 | {'SSB0005': 0, 'SSB0009': 1, 'SSB0011': 2..., 'SSB1956': 173}
275 | ```
276 | Generates the filelists
277 | ```
278 | 0|vits_data/waves-16k/SSB0005/SSB00050001.wav|vits_data/temps/SSB0005/SSB00050001.spec.pt|vits_data/berts/SSB0005/SSB00050001.npy|sil g uang3 zh ou1 n v3 d a4 x ve2 sh eng1 d eng1 sh an1 sh iii1 l ian2 s ii4 t ian1 j ing3 f ang1 zh ao3 d ao4 ^ i2 s ii4 n v3 sh iii1 sil
279 | 0|vits_data/waves-16k/SSB0005/SSB00050002.wav|vits_data/temps/SSB0005/SSB00050002.spec.pt|vits_data/berts/SSB0005/SSB00050002.npy|sil zh uen1 zh ong4 k e1 x ve2 g uei1 l v4 d e5 ^ iao1 q iou2 sil
280 | 0|vits_data/waves-16k/SSB0005/SSB00050004.wav|vits_data/temps/SSB0005/SSB00050004.spec.pt|vits_data/berts/SSB0005/SSB00050004.npy|sil h ei1 k e4 x van1 b u4 zh iii3 ^ iao4 b o1 d a2 m ou3 ^ i2 g e4 d ian4 h ua4 sil
281 | ```
282 | ## Data debugging
283 | ```
284 | python prep_debug.py
285 | ```
286 |
287 | ## Start training
288 |
289 | ```
290 | cd monotonic_align
291 |
292 | python setup.py build_ext --inplace
293 |
294 | cd -
295 |
296 | python train.py -c configs/bert_vits.json -m bert_vits
297 | ```
298 |
299 | ## Download weights
300 | AISHELL3_G.pth: https://github.com/PlayVoice/vits_chinese/releases/v4.0
301 |
302 | ## Inference test
303 | ```
304 | python vits_infer.py -c configs/bert_vits.json -m AISHELL3_G.pth -i 0
305 | ```
306 | -i is the speaker index; valid values are 0 ~ 173
307 |
308 | **The AISHELL3 training data consists of short single sentences, so the text used for inference must not contain punctuation**
309 |
310 | ## The AISHELL3 model is trained with its weights initialized from the AISHELL3 model open-sourced by the Xiaomi K2 community, to save training time
311 |
312 | K2 open-source model (download): https://huggingface.co/jackyqs/vits-aishell3-175-chinese/tree/main
313 |
314 | K2 online demo: https://huggingface.co/spaces/k2-fsa/text-to-speech
315 |
--------------------------------------------------------------------------------
/app.py:
--------------------------------------------------------------------------------
1 | from models import SynthesizerTrn
2 | from vits_pinyin import VITS_PinYin
3 | from text import cleaned_text_to_sequence
4 | from text.symbols import symbols
5 | import gradio as gr
6 | import utils
7 | import torch
8 | import argparse
9 | import os
10 | import re
11 | import logging
12 |
13 | logging.getLogger('numba').setLevel(logging.WARNING)
14 | limitation = os.getenv("SYSTEM") == "spaces"
15 |
16 |
17 | def create_calback(net_g: SynthesizerTrn, tts_front: VITS_PinYin):
18 | def tts_calback(text, dur_scale):
19 | if limitation:
20 | text_len = len(re.sub("\[([A-Z]{2})\]", "", text))
21 | max_len = 150
22 | if text_len > max_len:
23 | return "Error: Text is too long", None
24 |
25 | phonemes, char_embeds = tts_front.chinese_to_phonemes(text)
26 | input_ids = cleaned_text_to_sequence(phonemes)
27 | with torch.no_grad():
28 | x_tst = torch.LongTensor(input_ids).unsqueeze(0).to(device)
29 | x_tst_lengths = torch.LongTensor([len(input_ids)]).to(device)
30 | x_tst_prosody = torch.FloatTensor(
31 | char_embeds).unsqueeze(0).to(device)
32 | audio = net_g.infer(x_tst, x_tst_lengths, x_tst_prosody, noise_scale=0.5,
33 | length_scale=dur_scale)[0][0, 0].data.cpu().float().numpy()
34 | del x_tst, x_tst_lengths, x_tst_prosody
35 | return "Success", (16000, audio)
36 |
37 | return tts_calback
38 |
39 |
40 | example = [['天空呈现的透心的蓝,像极了当年。总在这样的时候,透过窗棂,心,在天空里无尽的游弋!柔柔的,浓浓的,痴痴的风,牵引起心底灵动的思潮;情愫悠悠,思情绵绵,风里默坐,红尘中的浅醉,诗词中的优柔,任那自在飞花轻似梦的情怀,裁一束霓衣,织就清浅淡薄的安寂。', 1],
41 | ['风的影子翻阅过淡蓝色的信笺,柔和的文字浅浅地漫过我安静的眸,一如几朵悠闲的云儿,忽而氤氲成汽,忽而修饰成花,铅华洗尽后的透彻和靓丽,爽爽朗朗,轻轻盈盈', 1],
42 | ['时光仿佛有穿越到了从前,在你诗情画意的眼波中,在你舒适浪漫的暇思里,我如风中的思绪徜徉广阔天际,仿佛一片沾染了快乐的羽毛,在云环影绕颤动里浸润着风的呼吸,风的诗韵,那清新的耳语,那婉约的甜蜜,那恬淡的温馨,将一腔情澜染得愈发的缠绵。', 1],]
43 |
44 |
45 | if __name__ == "__main__":
46 | parser = argparse.ArgumentParser()
47 | parser.add_argument("--share", action="store_true",
48 | default=False, help="share gradio app")
49 | args = parser.parse_args()
50 |
51 | device = torch.device("cpu")
52 |
53 | # pinyin
54 | tts_front = VITS_PinYin("./bert", device)
55 |
56 | # config
57 | hps = utils.get_hparams_from_file("./configs/bert_vits.json")
58 |
59 | # model
60 | net_g = SynthesizerTrn(
61 | len(symbols),
62 | hps.data.filter_length // 2 + 1,
63 | hps.train.segment_size // hps.data.hop_length,
64 | **hps.model)
65 |
66 | model_path = "vits_bert_model.pth"
67 | utils.load_model(model_path, net_g)
68 | net_g.eval()
69 | net_g.to(device)
70 |
71 | tts_calback = create_calback(net_g, tts_front)
72 |
73 | app = gr.Blocks()
74 | with app:
75 | gr.Markdown("# Best TTS based on BERT and VITS with some Natural Speech Features Of Microsoft\n\n"
76 | "code : github.com/PlayVoice/vits_chinese\n\n"
77 | "1, Hidden prosody embedding from BERT,get natural pauses in grammar\n\n"
78 | "2, Infer loss from NaturalSpeech,get less sound error\n\n"
79 | "3, Framework of VITS,get high audio quality\n\n"
80 | "\n\n"
81 | "\n\n"
82 | "\n\n"
83 | )
84 |
85 | with gr.Tabs():
86 | with gr.TabItem("TTS"):
87 | with gr.Row():
88 | with gr.Column():
89 | textbox = gr.TextArea(label="Text",
90 | placeholder="Type your sentence here (Maximum 150 words)",
91 | value="中文语音合成", elem_id=f"tts-input")
92 | duration_slider = gr.Slider(minimum=0.1, maximum=5, value=1, step=0.1,
93 | label='速度 Speed')
94 | with gr.Column():
95 | text_output = gr.Textbox(label="Message")
96 | audio_output = gr.Audio(
97 | label="Output Audio", elem_id="tts-audio")
98 | btn = gr.Button("Generate!")
99 | btn.click(tts_calback,
100 | inputs=[textbox, duration_slider],
101 | outputs=[text_output, audio_output])
102 | gr.Examples(
103 | examples=example,
104 | inputs=[textbox, duration_slider],
105 | outputs=[text_output, audio_output],
106 | fn=tts_calback
107 | )
108 | app.queue(concurrency_count=3).launch(show_api=False, share=args.share)
109 |
--------------------------------------------------------------------------------
/attentions.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import math
3 | import numpy as np
4 | import torch
5 | from torch import nn
6 | from torch.nn import functional as F
7 |
8 | import commons
9 | import modules
10 | from modules import LayerNorm
11 |
12 |
13 | class Encoder(nn.Module):
14 | def __init__(
15 | self,
16 | hidden_channels,
17 | filter_channels,
18 | n_heads,
19 | n_layers,
20 | kernel_size=1,
21 | p_dropout=0.0,
22 | window_size=4,
23 | **kwargs
24 | ):
25 | super().__init__()
26 | self.hidden_channels = hidden_channels
27 | self.filter_channels = filter_channels
28 | self.n_heads = n_heads
29 | self.n_layers = n_layers
30 | self.kernel_size = kernel_size
31 | self.p_dropout = p_dropout
32 | self.window_size = window_size
33 |
34 | self.drop = nn.Dropout(p_dropout)
35 | self.attn_layers = nn.ModuleList()
36 | self.norm_layers_1 = nn.ModuleList()
37 | self.ffn_layers = nn.ModuleList()
38 | self.norm_layers_2 = nn.ModuleList()
39 | for i in range(self.n_layers):
40 | self.attn_layers.append(
41 | MultiHeadAttention(
42 | hidden_channels,
43 | hidden_channels,
44 | n_heads,
45 | p_dropout=p_dropout,
46 | window_size=window_size,
47 | )
48 | )
49 | self.norm_layers_1.append(LayerNorm(hidden_channels))
50 | self.ffn_layers.append(
51 | FFN(
52 | hidden_channels,
53 | hidden_channels,
54 | filter_channels,
55 | kernel_size,
56 | p_dropout=p_dropout,
57 | )
58 | )
59 | self.norm_layers_2.append(LayerNorm(hidden_channels))
60 |
61 | def forward(self, x, x_mask):
62 | attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
63 | x = x * x_mask
64 | for i in range(self.n_layers):
65 | y = self.attn_layers[i](x, x, attn_mask)
66 | y = self.drop(y)
67 | x = self.norm_layers_1[i](x + y)
68 |
69 | y = self.ffn_layers[i](x, x_mask)
70 | y = self.drop(y)
71 | x = self.norm_layers_2[i](x + y)
72 | x = x * x_mask
73 | return x
74 |
75 |
76 | class Decoder(nn.Module):
77 | def __init__(
78 | self,
79 | hidden_channels,
80 | filter_channels,
81 | n_heads,
82 | n_layers,
83 | kernel_size=1,
84 | p_dropout=0.0,
85 | proximal_bias=False,
86 | proximal_init=True,
87 | **kwargs
88 | ):
89 | super().__init__()
90 | self.hidden_channels = hidden_channels
91 | self.filter_channels = filter_channels
92 | self.n_heads = n_heads
93 | self.n_layers = n_layers
94 | self.kernel_size = kernel_size
95 | self.p_dropout = p_dropout
96 | self.proximal_bias = proximal_bias
97 | self.proximal_init = proximal_init
98 |
99 | self.drop = nn.Dropout(p_dropout)
100 | self.self_attn_layers = nn.ModuleList()
101 | self.norm_layers_0 = nn.ModuleList()
102 | self.encdec_attn_layers = nn.ModuleList()
103 | self.norm_layers_1 = nn.ModuleList()
104 | self.ffn_layers = nn.ModuleList()
105 | self.norm_layers_2 = nn.ModuleList()
106 | for i in range(self.n_layers):
107 | self.self_attn_layers.append(
108 | MultiHeadAttention(
109 | hidden_channels,
110 | hidden_channels,
111 | n_heads,
112 | p_dropout=p_dropout,
113 | proximal_bias=proximal_bias,
114 | proximal_init=proximal_init,
115 | )
116 | )
117 | self.norm_layers_0.append(LayerNorm(hidden_channels))
118 | self.encdec_attn_layers.append(
119 | MultiHeadAttention(
120 | hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout
121 | )
122 | )
123 | self.norm_layers_1.append(LayerNorm(hidden_channels))
124 | self.ffn_layers.append(
125 | FFN(
126 | hidden_channels,
127 | hidden_channels,
128 | filter_channels,
129 | kernel_size,
130 | p_dropout=p_dropout,
131 | causal=True,
132 | )
133 | )
134 | self.norm_layers_2.append(LayerNorm(hidden_channels))
135 |
136 | def forward(self, x, x_mask, h, h_mask):
137 | """
138 | x: decoder input
139 | h: encoder output
140 | """
141 | self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(
142 | device=x.device, dtype=x.dtype
143 | )
144 | encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
145 | x = x * x_mask
146 | for i in range(self.n_layers):
147 | y = self.self_attn_layers[i](x, x, self_attn_mask)
148 | y = self.drop(y)
149 | x = self.norm_layers_0[i](x + y)
150 |
151 | y = self.encdec_attn_layers[i](x, h, encdec_attn_mask)
152 | y = self.drop(y)
153 | x = self.norm_layers_1[i](x + y)
154 |
155 | y = self.ffn_layers[i](x, x_mask)
156 | y = self.drop(y)
157 | x = self.norm_layers_2[i](x + y)
158 | x = x * x_mask
159 | return x
160 |
161 |
162 | class MultiHeadAttention(nn.Module):
163 | def __init__(
164 | self,
165 | channels,
166 | out_channels,
167 | n_heads,
168 | p_dropout=0.0,
169 | window_size=None,
170 | heads_share=True,
171 | block_length=None,
172 | proximal_bias=False,
173 | proximal_init=False,
174 | ):
175 | super().__init__()
176 | assert channels % n_heads == 0
177 |
178 | self.channels = channels
179 | self.out_channels = out_channels
180 | self.n_heads = n_heads
181 | self.p_dropout = p_dropout
182 | self.window_size = window_size
183 | self.heads_share = heads_share
184 | self.block_length = block_length
185 | self.proximal_bias = proximal_bias
186 | self.proximal_init = proximal_init
187 | self.attn = None
188 |
189 | self.k_channels = channels // n_heads
190 | self.conv_q = nn.Conv1d(channels, channels, 1)
191 | self.conv_k = nn.Conv1d(channels, channels, 1)
192 | self.conv_v = nn.Conv1d(channels, channels, 1)
193 | self.conv_o = nn.Conv1d(channels, out_channels, 1)
194 | self.drop = nn.Dropout(p_dropout)
195 |
196 | if window_size is not None:
197 | n_heads_rel = 1 if heads_share else n_heads
198 | rel_stddev = self.k_channels**-0.5
199 | self.emb_rel_k = nn.Parameter(
200 | torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
201 | * rel_stddev
202 | )
203 | self.emb_rel_v = nn.Parameter(
204 | torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
205 | * rel_stddev
206 | )
207 |
208 | nn.init.xavier_uniform_(self.conv_q.weight)
209 | nn.init.xavier_uniform_(self.conv_k.weight)
210 | nn.init.xavier_uniform_(self.conv_v.weight)
211 | if proximal_init:
212 | with torch.no_grad():
213 | self.conv_k.weight.copy_(self.conv_q.weight)
214 | self.conv_k.bias.copy_(self.conv_q.bias)
215 |
216 | def forward(self, x, c, attn_mask=None):
217 | q = self.conv_q(x)
218 | k = self.conv_k(c)
219 | v = self.conv_v(c)
220 |
221 | x, self.attn = self.attention(q, k, v, mask=attn_mask)
222 |
223 | x = self.conv_o(x)
224 | return x
225 |
226 | def attention(self, query, key, value, mask=None):
227 | # reshape [b, d, t] -> [b, n_h, t, d_k]
228 | b, d, t_s, t_t = (*key.size(), query.size(2))
229 | query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
230 | key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
231 | value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
232 |
233 | scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
234 | if self.window_size is not None:
235 | assert (
236 | t_s == t_t
237 | ), "Relative attention is only available for self-attention."
238 | key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
239 | rel_logits = self._matmul_with_relative_keys(
240 | query / math.sqrt(self.k_channels), key_relative_embeddings
241 | )
242 | scores_local = self._relative_position_to_absolute_position(rel_logits)
243 | scores = scores + scores_local
244 | if self.proximal_bias:
245 | assert t_s == t_t, "Proximal bias is only available for self-attention."
246 | scores = scores + self._attention_bias_proximal(t_s).to(
247 | device=scores.device, dtype=scores.dtype
248 | )
249 | if mask is not None:
250 | scores = scores.masked_fill(mask == 0, -1e4)
251 | if self.block_length is not None:
252 | assert (
253 | t_s == t_t
254 | ), "Local attention is only available for self-attention."
255 | block_mask = (
256 | torch.ones_like(scores)
257 | .triu(-self.block_length)
258 | .tril(self.block_length)
259 | )
260 | scores = scores.masked_fill(block_mask == 0, -1e4)
261 | p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s]
262 | p_attn = self.drop(p_attn)
263 | output = torch.matmul(p_attn, value)
264 | if self.window_size is not None:
265 | relative_weights = self._absolute_position_to_relative_position(p_attn)
266 | value_relative_embeddings = self._get_relative_embeddings(
267 | self.emb_rel_v, t_s
268 | )
269 | output = output + self._matmul_with_relative_values(
270 | relative_weights, value_relative_embeddings
271 | )
272 | output = (
273 | output.transpose(2, 3).contiguous().view(b, d, t_t)
274 | ) # [b, n_h, t_t, d_k] -> [b, d, t_t]
275 | return output, p_attn
276 |
277 | def _matmul_with_relative_values(self, x, y):
278 | """
279 | x: [b, h, l, m]
280 | y: [h or 1, m, d]
281 | ret: [b, h, l, d]
282 | """
283 | ret = torch.matmul(x, y.unsqueeze(0))
284 | return ret
285 |
286 | def _matmul_with_relative_keys(self, x, y):
287 | """
288 | x: [b, h, l, d]
289 | y: [h or 1, m, d]
290 | ret: [b, h, l, m]
291 | """
292 | ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
293 | return ret
294 |
295 | def _get_relative_embeddings(self, relative_embeddings, length):
296 | max_relative_position = 2 * self.window_size + 1
297 | # Pad first before slice to avoid using cond ops.
298 | pad_length = max(length - (self.window_size + 1), 0)
299 | slice_start_position = max((self.window_size + 1) - length, 0)
300 | slice_end_position = slice_start_position + 2 * length - 1
301 | if pad_length > 0:
302 | padded_relative_embeddings = F.pad(
303 | relative_embeddings,
304 | commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
305 | )
306 | else:
307 | padded_relative_embeddings = relative_embeddings
308 | used_relative_embeddings = padded_relative_embeddings[
309 | :, slice_start_position:slice_end_position
310 | ]
311 | return used_relative_embeddings
312 |
313 | def _relative_position_to_absolute_position(self, x):
314 | """
315 | x: [b, h, l, 2*l-1]
316 | ret: [b, h, l, l]
317 | """
318 | batch, heads, length, _ = x.size()
319 | # Concat columns of pad to shift from relative to absolute indexing.
320 | x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
321 |
322 | # Concat extra elements so to add up to shape (len+1, 2*len-1).
323 | x_flat = x.view([batch, heads, length * 2 * length])
324 | x_flat = F.pad(
325 | x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])
326 | )
327 |
328 | # Reshape and slice out the padded elements.
329 | x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
330 | :, :, :length, length - 1 :
331 | ]
332 | return x_final
333 |
334 | def _absolute_position_to_relative_position(self, x):
335 | """
336 | x: [b, h, l, l]
337 | ret: [b, h, l, 2*l-1]
338 | """
339 | batch, heads, length, _ = x.size()
340 | # padd along column
341 | x = F.pad(
342 | x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
343 | )
344 | x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
345 | # add 0's in the beginning that will skew the elements after reshape
346 | x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
347 | x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
348 | return x_final
349 |
350 | def _attention_bias_proximal(self, length):
351 | """Bias for self-attention to encourage attention to close positions.
352 | Args:
353 | length: an integer scalar.
354 | Returns:
355 | a Tensor with shape [1, 1, length, length]
356 | """
357 | r = torch.arange(length, dtype=torch.float32)
358 | diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
359 | return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
360 |
361 |
362 | class FFN(nn.Module):
363 | def __init__(
364 | self,
365 | in_channels,
366 | out_channels,
367 | filter_channels,
368 | kernel_size,
369 | p_dropout=0.0,
370 | activation=None,
371 | causal=False,
372 | ):
373 | super().__init__()
374 | self.in_channels = in_channels
375 | self.out_channels = out_channels
376 | self.filter_channels = filter_channels
377 | self.kernel_size = kernel_size
378 | self.p_dropout = p_dropout
379 | self.activation = activation
380 | self.causal = causal
381 |
382 | if causal:
383 | self.padding = self._causal_padding
384 | else:
385 | self.padding = self._same_padding
386 |
387 | self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
388 | self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
389 | self.drop = nn.Dropout(p_dropout)
390 |
391 | def forward(self, x, x_mask):
392 | x = self.conv_1(self.padding(x * x_mask))
393 | if self.activation == "gelu":
394 | x = x * torch.sigmoid(1.702 * x)
395 | else:
396 | x = torch.relu(x)
397 | x = self.drop(x)
398 | x = self.conv_2(self.padding(x * x_mask))
399 | return x * x_mask
400 |
401 | def _causal_padding(self, x):
402 | if self.kernel_size == 1:
403 | return x
404 | pad_l = self.kernel_size - 1
405 | pad_r = 0
406 | padding = [[0, 0], [0, 0], [pad_l, pad_r]]
407 | x = F.pad(x, commons.convert_pad_shape(padding))
408 | return x
409 |
410 | def _same_padding(self, x):
411 | if self.kernel_size == 1:
412 | return x
413 | pad_l = (self.kernel_size - 1) // 2
414 | pad_r = self.kernel_size // 2
415 | padding = [[0, 0], [0, 0], [pad_l, pad_r]]
416 | x = F.pad(x, commons.convert_pad_shape(padding))
417 | return x
418 |
--------------------------------------------------------------------------------
/bert/ProsodyModel.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from transformers import BertModel, BertConfig, BertTokenizer
7 |
8 |
9 | class CharEmbedding(nn.Module):
10 | def __init__(self, model_dir):
11 | super().__init__()
12 | self.tokenizer = BertTokenizer.from_pretrained(model_dir)
13 | self.bert_config = BertConfig.from_pretrained(model_dir)
14 | self.hidden_size = self.bert_config.hidden_size
15 | self.bert = BertModel(self.bert_config)
16 | self.proj = nn.Linear(self.hidden_size, 256)
17 | self.linear = nn.Linear(256, 3)
18 |
19 | def text2Token(self, text):
20 | token = self.tokenizer.tokenize(text)
21 | txtid = self.tokenizer.convert_tokens_to_ids(token)
22 | return txtid
23 |
24 | def forward(self, inputs_ids, inputs_masks, tokens_type_ids):
25 | out_seq = self.bert(input_ids=inputs_ids,
26 | attention_mask=inputs_masks,
27 | token_type_ids=tokens_type_ids)[0]
28 | out_seq = self.proj(out_seq)
29 | return out_seq
30 |
31 |
32 | class TTSProsody(object):
33 | def __init__(self, path, device):
34 | self.device = device
35 | self.char_model = CharEmbedding(path)
36 | self.char_model.load_state_dict(
37 | torch.load(
38 | os.path.join(path, 'prosody_model.pt'),
39 | map_location="cpu"
40 | ),
41 | strict=False
42 | )
43 | self.char_model.eval()
44 | self.char_model.to(self.device)
45 |
46 | def get_char_embeds(self, text):
47 | input_ids = self.char_model.text2Token(text)
48 | input_masks = [1] * len(input_ids)
49 | type_ids = [0] * len(input_ids)
50 | input_ids = torch.LongTensor([input_ids]).to(self.device)
51 | input_masks = torch.LongTensor([input_masks]).to(self.device)
52 | type_ids = torch.LongTensor([type_ids]).to(self.device)
53 |
54 | with torch.no_grad():
55 | char_embeds = self.char_model(
56 | input_ids, input_masks, type_ids).squeeze(0).cpu()
57 | return char_embeds
58 |
59 | def expand_for_phone(self, char_embeds, length): # length of phones for char
60 | assert char_embeds.size(0) == len(length)
61 | expand_vecs = list()
62 | for vec, leng in zip(char_embeds, length):
63 | vec = vec.expand(leng, -1)
64 | expand_vecs.append(vec)
65 | expand_embeds = torch.cat(expand_vecs, 0)
66 | assert expand_embeds.size(0) == sum(length)
67 | return expand_embeds.numpy()
68 |
69 |
70 | if __name__ == "__main__":
71 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
72 | prosody = TTSProsody('./bert/', device)
73 | while True:
74 | text = input("请输入文本:")
75 | prosody.get_char_embeds(text)
76 |
--------------------------------------------------------------------------------
/bert/__init__.py:
--------------------------------------------------------------------------------
1 | from .ProsodyModel import TTSProsody
--------------------------------------------------------------------------------
/bert/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "attention_probs_dropout_prob": 0.1,
3 | "directionality": "bidi",
4 | "hidden_act": "gelu",
5 | "hidden_dropout_prob": 0.1,
6 | "hidden_size": 768,
7 | "initializer_range": 0.02,
8 | "intermediate_size": 3072,
9 | "max_position_embeddings": 512,
10 | "num_attention_heads": 12,
11 | "num_hidden_layers": 12,
12 | "pooler_fc_size": 768,
13 | "pooler_num_attention_heads": 12,
14 | "pooler_num_fc_layers": 3,
15 | "pooler_size_per_head": 128,
16 | "pooler_type": "first_token_transform",
17 | "type_vocab_size": 2,
18 | "vocab_size": 21128
19 | }
20 |
--------------------------------------------------------------------------------
/bert/prosody_tool.py:
--------------------------------------------------------------------------------
1 | def is_chinese(uchar):
2 | if uchar >= u'\u4e00' and uchar <= u'\u9fa5':
3 | return True
4 | else:
5 | return False
6 |
7 |
8 | pinyin_dict = {
9 | "a": ("^", "a"),
10 | "ai": ("^", "ai"),
11 | "an": ("^", "an"),
12 | "ang": ("^", "ang"),
13 | "ao": ("^", "ao"),
14 | "ba": ("b", "a"),
15 | "bai": ("b", "ai"),
16 | "ban": ("b", "an"),
17 | "bang": ("b", "ang"),
18 | "bao": ("b", "ao"),
19 | "be": ("b", "e"),
20 | "bei": ("b", "ei"),
21 | "ben": ("b", "en"),
22 | "beng": ("b", "eng"),
23 | "bi": ("b", "i"),
24 | "bian": ("b", "ian"),
25 | "biao": ("b", "iao"),
26 | "bie": ("b", "ie"),
27 | "bin": ("b", "in"),
28 | "bing": ("b", "ing"),
29 | "bo": ("b", "o"),
30 | "bu": ("b", "u"),
31 | "ca": ("c", "a"),
32 | "cai": ("c", "ai"),
33 | "can": ("c", "an"),
34 | "cang": ("c", "ang"),
35 | "cao": ("c", "ao"),
36 | "ce": ("c", "e"),
37 | "cen": ("c", "en"),
38 | "ceng": ("c", "eng"),
39 | "cha": ("ch", "a"),
40 | "chai": ("ch", "ai"),
41 | "chan": ("ch", "an"),
42 | "chang": ("ch", "ang"),
43 | "chao": ("ch", "ao"),
44 | "che": ("ch", "e"),
45 | "chen": ("ch", "en"),
46 | "cheng": ("ch", "eng"),
47 | "chi": ("ch", "iii"),
48 | "chong": ("ch", "ong"),
49 | "chou": ("ch", "ou"),
50 | "chu": ("ch", "u"),
51 | "chua": ("ch", "ua"),
52 | "chuai": ("ch", "uai"),
53 | "chuan": ("ch", "uan"),
54 | "chuang": ("ch", "uang"),
55 | "chui": ("ch", "uei"),
56 | "chun": ("ch", "uen"),
57 | "chuo": ("ch", "uo"),
58 | "ci": ("c", "ii"),
59 | "cong": ("c", "ong"),
60 | "cou": ("c", "ou"),
61 | "cu": ("c", "u"),
62 | "cuan": ("c", "uan"),
63 | "cui": ("c", "uei"),
64 | "cun": ("c", "uen"),
65 | "cuo": ("c", "uo"),
66 | "da": ("d", "a"),
67 | "dai": ("d", "ai"),
68 | "dan": ("d", "an"),
69 | "dang": ("d", "ang"),
70 | "dao": ("d", "ao"),
71 | "de": ("d", "e"),
72 | "dei": ("d", "ei"),
73 | "den": ("d", "en"),
74 | "deng": ("d", "eng"),
75 | "di": ("d", "i"),
76 | "dia": ("d", "ia"),
77 | "dian": ("d", "ian"),
78 | "diao": ("d", "iao"),
79 | "die": ("d", "ie"),
80 | "ding": ("d", "ing"),
81 | "diu": ("d", "iou"),
82 | "dong": ("d", "ong"),
83 | "dou": ("d", "ou"),
84 | "du": ("d", "u"),
85 | "duan": ("d", "uan"),
86 | "dui": ("d", "uei"),
87 | "dun": ("d", "uen"),
88 | "duo": ("d", "uo"),
89 | "e": ("^", "e"),
90 | "ei": ("^", "ei"),
91 | "en": ("^", "en"),
92 | "ng": ("^", "en"),
93 | "eng": ("^", "eng"),
94 | "er": ("^", "er"),
95 | "fa": ("f", "a"),
96 | "fan": ("f", "an"),
97 | "fang": ("f", "ang"),
98 | "fei": ("f", "ei"),
99 | "fen": ("f", "en"),
100 | "feng": ("f", "eng"),
101 | "fo": ("f", "o"),
102 | "fou": ("f", "ou"),
103 | "fu": ("f", "u"),
104 | "ga": ("g", "a"),
105 | "gai": ("g", "ai"),
106 | "gan": ("g", "an"),
107 | "gang": ("g", "ang"),
108 | "gao": ("g", "ao"),
109 | "ge": ("g", "e"),
110 | "gei": ("g", "ei"),
111 | "gen": ("g", "en"),
112 | "geng": ("g", "eng"),
113 | "gong": ("g", "ong"),
114 | "gou": ("g", "ou"),
115 | "gu": ("g", "u"),
116 | "gua": ("g", "ua"),
117 | "guai": ("g", "uai"),
118 | "guan": ("g", "uan"),
119 | "guang": ("g", "uang"),
120 | "gui": ("g", "uei"),
121 | "gun": ("g", "uen"),
122 | "guo": ("g", "uo"),
123 | "ha": ("h", "a"),
124 | "hai": ("h", "ai"),
125 | "han": ("h", "an"),
126 | "hang": ("h", "ang"),
127 | "hao": ("h", "ao"),
128 | "he": ("h", "e"),
129 | "hei": ("h", "ei"),
130 | "hen": ("h", "en"),
131 | "heng": ("h", "eng"),
132 | "hong": ("h", "ong"),
133 | "hou": ("h", "ou"),
134 | "hu": ("h", "u"),
135 | "hua": ("h", "ua"),
136 | "huai": ("h", "uai"),
137 | "huan": ("h", "uan"),
138 | "huang": ("h", "uang"),
139 | "hui": ("h", "uei"),
140 | "hun": ("h", "uen"),
141 | "huo": ("h", "uo"),
142 | "ji": ("j", "i"),
143 | "jia": ("j", "ia"),
144 | "jian": ("j", "ian"),
145 | "jiang": ("j", "iang"),
146 | "jiao": ("j", "iao"),
147 | "jie": ("j", "ie"),
148 | "jin": ("j", "in"),
149 | "jing": ("j", "ing"),
150 | "jiong": ("j", "iong"),
151 | "jiu": ("j", "iou"),
152 | "ju": ("j", "v"),
153 | "juan": ("j", "van"),
154 | "jue": ("j", "ve"),
155 | "jun": ("j", "vn"),
156 | "ka": ("k", "a"),
157 | "kai": ("k", "ai"),
158 | "kan": ("k", "an"),
159 | "kang": ("k", "ang"),
160 | "kao": ("k", "ao"),
161 | "ke": ("k", "e"),
162 | "kei": ("k", "ei"),
163 | "ken": ("k", "en"),
164 | "keng": ("k", "eng"),
165 | "kong": ("k", "ong"),
166 | "kou": ("k", "ou"),
167 | "ku": ("k", "u"),
168 | "kua": ("k", "ua"),
169 | "kuai": ("k", "uai"),
170 | "kuan": ("k", "uan"),
171 | "kuang": ("k", "uang"),
172 | "kui": ("k", "uei"),
173 | "kun": ("k", "uen"),
174 | "kuo": ("k", "uo"),
175 | "la": ("l", "a"),
176 | "lai": ("l", "ai"),
177 | "lan": ("l", "an"),
178 | "lang": ("l", "ang"),
179 | "lao": ("l", "ao"),
180 | "le": ("l", "e"),
181 | "lei": ("l", "ei"),
182 | "leng": ("l", "eng"),
183 | "li": ("l", "i"),
184 | "lia": ("l", "ia"),
185 | "lian": ("l", "ian"),
186 | "liang": ("l", "iang"),
187 | "liao": ("l", "iao"),
188 | "lie": ("l", "ie"),
189 | "lin": ("l", "in"),
190 | "ling": ("l", "ing"),
191 | "liu": ("l", "iou"),
192 | "lo": ("l", "o"),
193 | "long": ("l", "ong"),
194 | "lou": ("l", "ou"),
195 | "lu": ("l", "u"),
196 | "lv": ("l", "v"),
197 | "luan": ("l", "uan"),
198 | "lve": ("l", "ve"),
199 | "lue": ("l", "ve"),
200 | "lun": ("l", "uen"),
201 | "luo": ("l", "uo"),
202 | "ma": ("m", "a"),
203 | "mai": ("m", "ai"),
204 | "man": ("m", "an"),
205 | "mang": ("m", "ang"),
206 | "mao": ("m", "ao"),
207 | "me": ("m", "e"),
208 | "mei": ("m", "ei"),
209 | "men": ("m", "en"),
210 | "meng": ("m", "eng"),
211 | "mi": ("m", "i"),
212 | "mian": ("m", "ian"),
213 | "miao": ("m", "iao"),
214 | "mie": ("m", "ie"),
215 | "min": ("m", "in"),
216 | "ming": ("m", "ing"),
217 | "miu": ("m", "iou"),
218 | "mo": ("m", "o"),
219 | "mou": ("m", "ou"),
220 | "mu": ("m", "u"),
221 | "na": ("n", "a"),
222 | "nai": ("n", "ai"),
223 | "nan": ("n", "an"),
224 | "nang": ("n", "ang"),
225 | "nao": ("n", "ao"),
226 | "ne": ("n", "e"),
227 | "nei": ("n", "ei"),
228 | "nen": ("n", "en"),
229 | "neng": ("n", "eng"),
230 | "ni": ("n", "i"),
231 | "nia": ("n", "ia"),
232 | "nian": ("n", "ian"),
233 | "niang": ("n", "iang"),
234 | "niao": ("n", "iao"),
235 | "nie": ("n", "ie"),
236 | "nin": ("n", "in"),
237 | "ning": ("n", "ing"),
238 | "niu": ("n", "iou"),
239 | "nong": ("n", "ong"),
240 | "nou": ("n", "ou"),
241 | "nu": ("n", "u"),
242 | "nv": ("n", "v"),
243 | "nuan": ("n", "uan"),
244 | "nve": ("n", "ve"),
245 | "nue": ("n", "ve"),
246 | "nuo": ("n", "uo"),
247 | "o": ("^", "o"),
248 | "ou": ("^", "ou"),
249 | "pa": ("p", "a"),
250 | "pai": ("p", "ai"),
251 | "pan": ("p", "an"),
252 | "pang": ("p", "ang"),
253 | "pao": ("p", "ao"),
254 | "pe": ("p", "e"),
255 | "pei": ("p", "ei"),
256 | "pen": ("p", "en"),
257 | "peng": ("p", "eng"),
258 | "pi": ("p", "i"),
259 | "pian": ("p", "ian"),
260 | "piao": ("p", "iao"),
261 | "pie": ("p", "ie"),
262 | "pin": ("p", "in"),
263 | "ping": ("p", "ing"),
264 | "po": ("p", "o"),
265 | "pou": ("p", "ou"),
266 | "pu": ("p", "u"),
267 | "qi": ("q", "i"),
268 | "qia": ("q", "ia"),
269 | "qian": ("q", "ian"),
270 | "qiang": ("q", "iang"),
271 | "qiao": ("q", "iao"),
272 | "qie": ("q", "ie"),
273 | "qin": ("q", "in"),
274 | "qing": ("q", "ing"),
275 | "qiong": ("q", "iong"),
276 | "qiu": ("q", "iou"),
277 | "qu": ("q", "v"),
278 | "quan": ("q", "van"),
279 | "que": ("q", "ve"),
280 | "qun": ("q", "vn"),
281 | "ran": ("r", "an"),
282 | "rang": ("r", "ang"),
283 | "rao": ("r", "ao"),
284 | "re": ("r", "e"),
285 | "ren": ("r", "en"),
286 | "reng": ("r", "eng"),
287 | "ri": ("r", "iii"),
288 | "rong": ("r", "ong"),
289 | "rou": ("r", "ou"),
290 | "ru": ("r", "u"),
291 | "rua": ("r", "ua"),
292 | "ruan": ("r", "uan"),
293 | "rui": ("r", "uei"),
294 | "run": ("r", "uen"),
295 | "ruo": ("r", "uo"),
296 | "sa": ("s", "a"),
297 | "sai": ("s", "ai"),
298 | "san": ("s", "an"),
299 | "sang": ("s", "ang"),
300 | "sao": ("s", "ao"),
301 | "se": ("s", "e"),
302 | "sen": ("s", "en"),
303 | "seng": ("s", "eng"),
304 | "sha": ("sh", "a"),
305 | "shai": ("sh", "ai"),
306 | "shan": ("sh", "an"),
307 | "shang": ("sh", "ang"),
308 | "shao": ("sh", "ao"),
309 | "she": ("sh", "e"),
310 | "shei": ("sh", "ei"),
311 | "shen": ("sh", "en"),
312 | "sheng": ("sh", "eng"),
313 | "shi": ("sh", "iii"),
314 | "shou": ("sh", "ou"),
315 | "shu": ("sh", "u"),
316 | "shua": ("sh", "ua"),
317 | "shuai": ("sh", "uai"),
318 | "shuan": ("sh", "uan"),
319 | "shuang": ("sh", "uang"),
320 | "shui": ("sh", "uei"),
321 | "shun": ("sh", "uen"),
322 | "shuo": ("sh", "uo"),
323 | "si": ("s", "ii"),
324 | "song": ("s", "ong"),
325 | "sou": ("s", "ou"),
326 | "su": ("s", "u"),
327 | "suan": ("s", "uan"),
328 | "sui": ("s", "uei"),
329 | "sun": ("s", "uen"),
330 | "suo": ("s", "uo"),
331 | "ta": ("t", "a"),
332 | "tai": ("t", "ai"),
333 | "tan": ("t", "an"),
334 | "tang": ("t", "ang"),
335 | "tao": ("t", "ao"),
336 | "te": ("t", "e"),
337 | "tei": ("t", "ei"),
338 | "teng": ("t", "eng"),
339 | "ti": ("t", "i"),
340 | "tian": ("t", "ian"),
341 | "tiao": ("t", "iao"),
342 | "tie": ("t", "ie"),
343 | "ting": ("t", "ing"),
344 | "tong": ("t", "ong"),
345 | "tou": ("t", "ou"),
346 | "tu": ("t", "u"),
347 | "tuan": ("t", "uan"),
348 | "tui": ("t", "uei"),
349 | "tun": ("t", "uen"),
350 | "tuo": ("t", "uo"),
351 | "wa": ("^", "ua"),
352 | "wai": ("^", "uai"),
353 | "wan": ("^", "uan"),
354 | "wang": ("^", "uang"),
355 | "wei": ("^", "uei"),
356 | "wen": ("^", "uen"),
357 | "weng": ("^", "ueng"),
358 | "wo": ("^", "uo"),
359 | "wu": ("^", "u"),
360 | "xi": ("x", "i"),
361 | "xia": ("x", "ia"),
362 | "xian": ("x", "ian"),
363 | "xiang": ("x", "iang"),
364 | "xiao": ("x", "iao"),
365 | "xie": ("x", "ie"),
366 | "xin": ("x", "in"),
367 | "xing": ("x", "ing"),
368 | "xiong": ("x", "iong"),
369 | "xiu": ("x", "iou"),
370 | "xu": ("x", "v"),
371 | "xuan": ("x", "van"),
372 | "xue": ("x", "ve"),
373 | "xun": ("x", "vn"),
374 | "ya": ("^", "ia"),
375 | "yan": ("^", "ian"),
376 | "yang": ("^", "iang"),
377 | "yao": ("^", "iao"),
378 | "ye": ("^", "ie"),
379 | "yi": ("^", "i"),
380 | "yin": ("^", "in"),
381 | "ying": ("^", "ing"),
382 | "yo": ("^", "iou"),
383 | "yong": ("^", "iong"),
384 | "you": ("^", "iou"),
385 | "yu": ("^", "v"),
386 | "yuan": ("^", "van"),
387 | "yue": ("^", "ve"),
388 | "yun": ("^", "vn"),
389 | "za": ("z", "a"),
390 | "zai": ("z", "ai"),
391 | "zan": ("z", "an"),
392 | "zang": ("z", "ang"),
393 | "zao": ("z", "ao"),
394 | "ze": ("z", "e"),
395 | "zei": ("z", "ei"),
396 | "zen": ("z", "en"),
397 | "zeng": ("z", "eng"),
398 | "zha": ("zh", "a"),
399 | "zhai": ("zh", "ai"),
400 | "zhan": ("zh", "an"),
401 | "zhang": ("zh", "ang"),
402 | "zhao": ("zh", "ao"),
403 | "zhe": ("zh", "e"),
404 | "zhei": ("zh", "ei"),
405 | "zhen": ("zh", "en"),
406 | "zheng": ("zh", "eng"),
407 | "zhi": ("zh", "iii"),
408 | "zhong": ("zh", "ong"),
409 | "zhou": ("zh", "ou"),
410 | "zhu": ("zh", "u"),
411 | "zhua": ("zh", "ua"),
412 | "zhuai": ("zh", "uai"),
413 | "zhuan": ("zh", "uan"),
414 | "zhuang": ("zh", "uang"),
415 | "zhui": ("zh", "uei"),
416 | "zhun": ("zh", "uen"),
417 | "zhuo": ("zh", "uo"),
418 | "zi": ("z", "ii"),
419 | "zong": ("z", "ong"),
420 | "zou": ("z", "ou"),
421 | "zu": ("z", "u"),
422 | "zuan": ("z", "uan"),
423 | "zui": ("z", "uei"),
424 | "zun": ("z", "uen"),
425 | "zuo": ("z", "uo"),
426 | }
427 |
--------------------------------------------------------------------------------
/commons.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numpy as np
3 | import torch
4 | from torch import nn
5 | from torch.nn import functional as F
6 |
7 |
8 | def init_weights(m, mean=0.0, std=0.01):
9 | classname = m.__class__.__name__
10 | if classname.find("Conv") != -1:
11 | m.weight.data.normal_(mean, std)
12 |
13 |
14 | def get_padding(kernel_size, dilation=1):
15 | return int((kernel_size * dilation - dilation) / 2)
16 |
17 |
18 | def convert_pad_shape(pad_shape):
19 | l = pad_shape[::-1]
20 | pad_shape = [item for sublist in l for item in sublist]
21 | return pad_shape
22 |
23 |
24 | def intersperse(lst, item):
25 | result = [item] * (len(lst) * 2 + 1)
26 | result[1::2] = lst
27 | return result
28 |
29 |
30 | def kl_divergence(m_p, logs_p, m_q, logs_q):
31 | """KL(P||Q)"""
32 | kl = (logs_q - logs_p) - 0.5
33 | kl += (
34 | 0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
35 | )
36 | return kl
37 |
38 |
39 | def rand_gumbel(shape):
40 | """Sample from the Gumbel distribution, protect from overflows."""
41 | uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
42 | return -torch.log(-torch.log(uniform_samples))
43 |
44 |
45 | def rand_gumbel_like(x):
46 | g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
47 | return g
48 |
49 |
50 | def slice_segments(x, ids_str, segment_size=4):
51 | ret = torch.zeros_like(x[:, :, :segment_size])
52 | for i in range(x.size(0)):
53 | idx_str = ids_str[i]
54 | idx_end = idx_str + segment_size
55 | ret[i] = x[i, :, idx_str:idx_end]
56 | return ret
57 |
58 |
59 | def rand_slice_segments(x, x_lengths=None, segment_size=4):
60 | b, d, t = x.size()
61 | if x_lengths is None:
62 | x_lengths = t
63 | ids_str_max = x_lengths - segment_size + 1
64 | ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
65 | ret = slice_segments(x, ids_str, segment_size)
66 | return ret, ids_str
67 |
68 |
69 | def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
70 | position = torch.arange(length, dtype=torch.float)
71 | num_timescales = channels // 2
72 | log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (
73 | num_timescales - 1
74 | )
75 | inv_timescales = min_timescale * torch.exp(
76 | torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
77 | )
78 | scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
79 | signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
80 | signal = F.pad(signal, [0, 0, 0, channels % 2])
81 | signal = signal.view(1, channels, length)
82 | return signal
83 |
84 |
85 | def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
86 | b, channels, length = x.size()
87 | signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
88 | return x + signal.to(dtype=x.dtype, device=x.device)
89 |
90 |
91 | def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
92 | b, channels, length = x.size()
93 | signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
94 | return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)
95 |
96 |
97 | def subsequent_mask(length):
98 | mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
99 | return mask
100 |
101 |
102 | @torch.jit.script
103 | def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
104 | n_channels_int = n_channels[0]
105 | in_act = input_a + input_b
106 | t_act = torch.tanh(in_act[:, :n_channels_int, :])
107 | s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
108 | acts = t_act * s_act
109 | return acts
110 |
111 |
112 | def convert_pad_shape(pad_shape):
113 | l = pad_shape[::-1]
114 | pad_shape = [item for sublist in l for item in sublist]
115 | return pad_shape
116 |
117 |
118 | def shift_1d(x):
119 | x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
120 | return x
121 |
122 |
123 | def sequence_mask(length, max_length=None):
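    # boolean mask over time steps, e.g. sequence_mask(torch.tensor([2, 3])) ->
    # [[True, True, False], [True, True, True]]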
124 | if max_length is None:
125 | max_length = length.max()
126 | x = torch.arange(max_length, dtype=length.dtype, device=length.device)
127 | return x.unsqueeze(0) < length.unsqueeze(1)
128 |
129 |
130 | def generate_path(duration, mask):
131 | """
132 | duration: [b, 1, t_x]
133 | mask: [b, 1, t_y, t_x]
134 | """
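    # expands per-token durations into a hard monotonic alignment of shape
    # [b, 1, t_y, t_x]: output frame j is assigned to token i iff
    # cum_duration[i-1] <= j < cum_duration[i]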
135 | device = duration.device
136 |
137 | b, _, t_y, t_x = mask.shape
138 | cum_duration = torch.cumsum(duration, -1)
139 |
140 | cum_duration_flat = cum_duration.view(b * t_x)
141 | path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
142 | path = path.view(b, t_x, t_y)
143 | path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
144 | path = path.unsqueeze(1).transpose(2, 3) * mask
145 | return path
146 |
147 |
148 | def clip_grad_value_(parameters, clip_value, norm_type=2):
149 | if isinstance(parameters, torch.Tensor):
150 | parameters = [parameters]
151 | parameters = list(filter(lambda p: p.grad is not None, parameters))
152 | norm_type = float(norm_type)
153 | if clip_value is not None:
154 | clip_value = float(clip_value)
155 |
156 | total_norm = 0
157 | for p in parameters:
158 | param_norm = p.grad.data.norm(norm_type)
159 | total_norm += param_norm.item() ** norm_type
160 | if clip_value is not None:
161 | p.grad.data.clamp_(min=-clip_value, max=clip_value)
162 | total_norm = total_norm ** (1.0 / norm_type)
163 | return total_norm
164 |
--------------------------------------------------------------------------------
/configs/bert_vits.json:
--------------------------------------------------------------------------------
1 | {
2 | "train": {
3 | "train_class": "models.SynthesizerTrn",
4 | "eval_class": "models.SynthesizerEval",
5 | "log_interval": 100,
6 | "eval_interval": 10000,
7 | "seed": 1234,
8 | "epochs": 20000,
9 | "learning_rate": 1e-4,
10 | "betas": [0.8, 0.99],
11 | "eps": 1e-9,
12 | "batch_size": 8,
13 | "fp16_run": false,
14 | "lr_decay": 0.999875,
15 | "segment_size": 12800,
16 | "init_lr_ratio": 1,
17 | "warmup_epochs": 0,
18 | "c_mel": 45,
19 | "c_kl": 1.0
20 | },
21 | "data": {
22 | "training_files": "filelists/train.txt",
23 | "validation_files": "filelists/valid.txt",
24 | "max_wav_value": 32768.0,
25 | "sampling_rate": 16000,
26 | "filter_length": 1024,
27 | "hop_length": 256,
28 | "win_length": 1024,
29 | "n_mel_channels": 80,
30 | "mel_fmin": 0.0,
31 | "mel_fmax": null,
32 | "add_blank": false,
33 | "n_speakers": 0
34 | },
35 | "model": {
36 | "inter_channels": 192,
37 | "hidden_channels": 192,
38 | "filter_channels": 768,
39 | "n_heads": 2,
40 | "n_layers": 6,
41 | "kernel_size": 3,
42 | "p_dropout": 0.1,
43 | "resblock": "1",
44 | "resblock_kernel_sizes": [3,7,11],
45 | "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
46 | "upsample_rates": [8,8,2,2],
47 | "upsample_initial_channel": 512,
48 | "upsample_kernel_sizes": [16,16,4,4],
49 | "n_layers_q": 3,
50 | "use_spectral_norm": false
51 | }
52 | }
--------------------------------------------------------------------------------
/configs/bert_vits_student.json:
--------------------------------------------------------------------------------
1 | {
2 | "train": {
3 | "train_class": "models.SyntStudentTrn",
4 | "eval_class": "models.SynthesizerEval",
5 | "teacher": "./vits_bert_model.pth",
6 | "log_interval": 100,
7 | "eval_interval": 10000,
8 | "seed": 1234,
9 | "epochs": 20000,
10 | "learning_rate": 1e-4,
11 | "betas": [0.8, 0.99],
12 | "eps": 1e-9,
13 | "batch_size": 8,
14 | "fp16_run": false,
15 | "lr_decay": 0.999875,
16 | "segment_size": 12800,
17 | "init_lr_ratio": 1,
18 | "warmup_epochs": 0,
19 | "c_mel": 45,
20 | "c_kl": 1.0
21 | },
22 | "data": {
23 | "training_files":"filelists/train.txt",
24 | "validation_files":"filelists/valid.txt",
25 | "max_wav_value": 32768.0,
26 | "sampling_rate": 16000,
27 | "filter_length": 1024,
28 | "hop_length": 256,
29 | "win_length": 1024,
30 | "n_mel_channels": 80,
31 | "mel_fmin": 0.0,
32 | "mel_fmax": null,
33 | "add_blank": false,
34 | "n_speakers": 0
35 | },
36 | "model": {
37 | "inter_channels": 192,
38 | "hidden_channels": 192,
39 | "filter_channels": 512,
40 | "n_heads": 2,
41 | "n_layers": 5,
42 | "kernel_size": 3,
43 | "p_dropout": 0.1,
44 | "resblock": "1",
45 | "resblock_kernel_sizes": [3,7,11],
46 | "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
47 | "upsample_rates": [8,8,2,2],
48 | "upsample_initial_channel": 256,
49 | "upsample_kernel_sizes": [16,16,4,4],
50 | "n_layers_q": 3,
51 | "use_spectral_norm": false
52 | }
53 | }
54 |
--------------------------------------------------------------------------------
/data_utils.py:
--------------------------------------------------------------------------------
1 | import time
2 | import os
3 | import random
4 | import numpy as np
5 | import torch
6 | import torch.utils.data
7 |
8 | import commons
9 | from utils import load_wav_to_torch, load_filepaths_and_text
10 | from text import cleaned_text_to_sequence
11 |
12 |
13 | class TextAudioLoader(torch.utils.data.Dataset):
14 | """
15 |     1) loads audio, text pairs (with paths to precomputed spectrogram and BERT files)
16 |     2) converts cleaned text into sequences of integers
17 |     3) loads precomputed spectrograms and BERT prosody embeddings from disk.
18 | """
19 |
20 | def __init__(self, audiopaths_and_text, hparams):
21 | self.audiopaths_and_text = load_filepaths_and_text(audiopaths_and_text)
22 | self.max_wav_value = hparams.max_wav_value
23 | self.sampling_rate = hparams.sampling_rate
24 | self.filter_length = hparams.filter_length
25 | self.hop_length = hparams.hop_length
26 | self.win_length = hparams.win_length
27 | self.sampling_rate = hparams.sampling_rate
28 |
29 | self.cleaned_text = getattr(hparams, "cleaned_text", False)
30 |
31 | self.add_blank = hparams.add_blank
32 | self.min_text_len = getattr(hparams, "min_text_len", 1)
33 | self.max_text_len = getattr(hparams, "max_text_len", 100)
34 |
35 |         # shuffle is not needed for a single speaker
36 | # random.seed(1234)
37 | # random.shuffle(self.audiopaths_and_text)
38 | self._filter()
39 |
40 | def _filter(self):
41 | """
42 | Filter text & store spec lengths
43 | """
44 | # Store spectrogram lengths for Bucketing
45 | # wav_length ~= file_size / (wav_channels * Bytes per dim) = file_size / (1 * 2)
46 | # spec_length = wav_length // hop_length
47 |
48 | audiopaths_and_text_new = []
49 | lengths = []
50 | for audiopath, spec, bert, text in self.audiopaths_and_text:
51 | length = len(text.split())
52 | if self.min_text_len <= length and length <= self.max_text_len:
53 | audiopaths_and_text_new.append([audiopath, spec, bert, text])
54 | lengths.append(os.path.getsize(audiopath) // (2 * self.hop_length))
55 | self.audiopaths_and_text = audiopaths_and_text_new
56 | self.lengths = lengths
57 |
58 | def get_audio_text_pair(self, audiopath_and_text):
59 | # separate filename and text
60 | audiopath, spec = audiopath_and_text[0], audiopath_and_text[1]
61 | bert, text = audiopath_and_text[2], audiopath_and_text[3]
62 | wave = self.get_audio(audiopath)
63 | spec = torch.load(spec)
64 | text = self.get_text(text)
65 | bert = self.get_bert(bert)
66 | return (spec, wave, text, bert)
67 |
68 | def get_audio(self, filename):
69 | audio, sampling_rate = load_wav_to_torch(filename)
70 | if sampling_rate != self.sampling_rate:
71 | raise ValueError(
72 |                 "{} SR doesn't match target {} SR".format(
73 | sampling_rate, self.sampling_rate
74 | )
75 | )
76 | audio_norm = audio / self.max_wav_value
77 | audio_norm = audio_norm.unsqueeze(0)
78 | return audio_norm
79 |
80 | def get_bert(self, bert):
81 | bert_embed = np.load(bert)
82 | bert_embed = bert_embed.astype(np.float32)
83 | bert_embed = torch.FloatTensor(bert_embed)
84 | return bert_embed
85 |
86 | def get_text(self, text):
87 | text_norm = cleaned_text_to_sequence(text)
88 | if self.add_blank:
89 | text_norm = commons.intersperse(text_norm, 0)
90 | text_norm = torch.LongTensor(text_norm)
91 | return text_norm
92 |
93 | def __getitem__(self, index):
94 | return self.get_audio_text_pair(self.audiopaths_and_text[index])
95 |
96 | def __len__(self):
97 | return len(self.audiopaths_and_text)
98 |
99 |
100 | class TextAudioCollate():
101 | """ Zero-pads model inputs and targets
102 | """
103 |
104 | def __init__(self, return_ids=False):
105 | self.return_ids = return_ids
106 |
107 | def __call__(self, batch):
108 |         """Collates the training batch from normalized text and audio
109 | PARAMS
110 | ------
111 | batch: [text_normalized, spec_normalized, wav_normalized]
112 | """
113 | # Right zero-pad all one-hot text sequences to max input length
114 | _, ids_sorted_decreasing = torch.sort(
115 | torch.LongTensor([x[0].size(1) for x in batch]), dim=0, descending=True
116 | )
117 |
118 | max_spec_len = max([x[0].size(1) for x in batch])
119 | max_wav_len = max([x[1].size(1) for x in batch])
120 | max_text_len = max([len(x[2]) for x in batch])
121 |
122 | spec_lengths = torch.LongTensor(len(batch))
123 | wav_lengths = torch.LongTensor(len(batch))
124 | text_lengths = torch.LongTensor(len(batch))
125 |
126 | spec_padded = torch.FloatTensor(len(batch), batch[0][0].size(0), max_spec_len)
127 | wav_padded = torch.FloatTensor(len(batch), 1, max_wav_len)
128 | text_padded = torch.LongTensor(len(batch), max_text_len)
129 | bert_padded = torch.FloatTensor(len(batch), max_text_len, 256)
130 |
131 | spec_padded.zero_()
132 | wav_padded.zero_()
133 | text_padded.zero_()
134 | bert_padded.zero_()
135 | for i in range(len(ids_sorted_decreasing)):
136 | row = batch[ids_sorted_decreasing[i]]
137 |
138 | spec = row[0]
139 | spec_padded[i, :, :spec.size(1)] = spec
140 | spec_lengths[i] = spec.size(1)
141 |
142 | wav = row[1]
143 | wav_padded[i, :, :wav.size(1)] = wav
144 | wav_lengths[i] = wav.size(1)
145 |
146 | text = row[2]
147 | text_padded[i, :text.size(0)] = text
148 | text_lengths[i] = text.size(0)
149 |
150 | bert = row[3]
151 | bert_padded[i, :bert.size(0), :] = bert
152 |
153 | if self.return_ids:
154 | return text_padded, text_lengths, bert_padded, spec_padded, spec_lengths, wav_padded, wav_lengths, ids_sorted_decreasing
155 | return text_padded, text_lengths, bert_padded, spec_padded, spec_lengths, wav_padded, wav_lengths
156 |
157 |
158 | class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler):
159 | """
160 | Maintain similar input lengths in a batch.
161 | Length groups are specified by boundaries.
162 |     Ex) boundaries = [b1, b2, b3] -> any batch is included in either {x | b1 < length(x) <= b2} or {x | b2 < length(x) <= b3}.
163 |
164 | It removes samples which are not included in the boundaries.
165 | Ex) boundaries = [b1, b2, b3] -> any x s.t. length(x) <= b1 or length(x) > b3 are discarded.
166 | """
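    # e.g. boundaries = [32, 300, 400, 500]: samples whose spectrogram length falls in
    # (32, 300] form bucket 0, (300, 400] bucket 1, (400, 500] bucket 2; lengths <= 32
    # or > 500 are dropped (_bisect() returns -1 for them)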
167 |
168 | def __init__(
169 | self,
170 | dataset,
171 | batch_size,
172 | boundaries,
173 | num_replicas=None,
174 | rank=None,
175 | shuffle=True,
176 | ):
177 | super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)
178 | self.lengths = dataset.lengths
179 | self.batch_size = batch_size
180 | self.boundaries = boundaries
181 |
182 | self.buckets, self.num_samples_per_bucket = self._create_buckets()
183 | self.total_size = sum(self.num_samples_per_bucket)
184 | self.num_samples = self.total_size // self.num_replicas
185 |
186 | def _create_buckets(self):
187 | buckets = [[] for _ in range(len(self.boundaries) - 1)]
188 | for i in range(len(self.lengths)):
189 | length = self.lengths[i]
190 | idx_bucket = self._bisect(length)
191 | if idx_bucket != -1:
192 | buckets[idx_bucket].append(i)
193 |
194 | for i in range(len(buckets) - 1, 0, -1):
195 | if len(buckets[i]) == 0:
196 | buckets.pop(i)
197 | self.boundaries.pop(i + 1)
198 |
199 | num_samples_per_bucket = []
200 | for i in range(len(buckets)):
201 | len_bucket = len(buckets[i])
202 | total_batch_size = self.num_replicas * self.batch_size
203 | rem = (
204 | total_batch_size - (len_bucket % total_batch_size)
205 | ) % total_batch_size
206 | num_samples_per_bucket.append(len_bucket + rem)
207 | return buckets, num_samples_per_bucket
208 |
209 | def __iter__(self):
210 | # deterministically shuffle based on epoch
211 | g = torch.Generator()
212 | g.manual_seed(self.epoch)
213 |
214 | indices = []
215 | if self.shuffle:
216 | for bucket in self.buckets:
217 | indices.append(torch.randperm(len(bucket), generator=g).tolist())
218 | else:
219 | for bucket in self.buckets:
220 | indices.append(list(range(len(bucket))))
221 |
222 | batches = []
223 | for i in range(len(self.buckets)):
224 | bucket = self.buckets[i]
225 | len_bucket = len(bucket)
226 | if (len_bucket == 0):
227 | continue
228 | ids_bucket = indices[i]
229 | num_samples_bucket = self.num_samples_per_bucket[i]
230 |
231 | # add extra samples to make it evenly divisible
232 | rem = num_samples_bucket - len_bucket
233 | ids_bucket = (
234 | ids_bucket
235 | + ids_bucket * (rem // len_bucket)
236 | + ids_bucket[: (rem % len_bucket)]
237 | )
238 |
239 | # subsample
240 | ids_bucket = ids_bucket[self.rank :: self.num_replicas]
241 |
242 | # batching
243 | for j in range(len(ids_bucket) // self.batch_size):
244 | batch = [
245 | bucket[idx]
246 | for idx in ids_bucket[
247 | j * self.batch_size : (j + 1) * self.batch_size
248 | ]
249 | ]
250 | batches.append(batch)
251 |
252 | if self.shuffle:
253 | batch_ids = torch.randperm(len(batches), generator=g).tolist()
254 | batches = [batches[i] for i in batch_ids]
255 | self.batches = batches
256 |
257 | assert len(self.batches) * self.batch_size == self.num_samples
258 | return iter(self.batches)
259 |
260 | def _bisect(self, x, lo=0, hi=None):
261 | if hi is None:
262 | hi = len(self.boundaries) - 1
263 |
264 | if hi > lo:
265 | mid = (hi + lo) // 2
266 | if self.boundaries[mid] < x and x <= self.boundaries[mid + 1]:
267 | return mid
268 | elif x <= self.boundaries[mid]:
269 | return self._bisect(x, lo, mid)
270 | else:
271 | return self._bisect(x, mid + 1, hi)
272 | else:
273 | return -1
274 |
275 | def __len__(self):
276 | return self.num_samples // self.batch_size
277 |
--------------------------------------------------------------------------------
/filelists/valid.txt:
--------------------------------------------------------------------------------
1 | ./data/waves/000001.wav|./data/berts/000001.npy|sil k a2 ^ er2 p u3 p ei2 ^ uai4 s uen1 ^ uan2 h ua2 t i1 sp sil
2 | ./data/waves/000002.wav|./data/berts/000002.npy|sil j ia2 ^ v3 c uen1 ^ ian2 b ie2 z ai4 ^ iong1 b ao4 ^ uo3 sp sil
3 | ./data/waves/000003.wav|./data/berts/000003.npy|sil b ao2 m a3 p ei4 g ua4 b o3 l uo2 ^ an1 sp d iao1 ch an2 ^ van4 zh en3 d ong3 ^ ueng1 t a4 sp sil
4 | ./data/waves/000004.wav|./data/berts/000004.npy|sil d eng4 x iao3 p ing2 ^ v3 s a4 q ie4 ^ er3 h uei4 ^ u4 sp sil
5 | ./data/waves/000005.wav|./data/berts/000005.npy|sil l ao2 h u3 ^ iou4 z ai3 ^ v2 ch ong3 ^ u4 q van3 ^ uan2 sh ua3 sp sil
6 | ./data/waves/000006.wav|./data/berts/000006.npy|sil sh en1 ch ang2 ^ ve1 ^ u2 ch iii3 ^ er4 c uen4 ^ u3 f en1 h uo4 ^ i3 sh ang4 sp sil
7 | ./data/waves/000007.wav|./data/berts/000007.npy|sil zh ao4 d i2 ^ ve1 c ao2 ^ vn2 t eng2 q v4 g uei3 ^ u1 sp sil
8 | ./data/waves/000008.wav|./data/berts/000008.npy|sil zh an2 p in3 s uei1 ^ iou3 sp zh an3 ^ van2 q ve4 t uei2 sp sil
9 | ./data/waves/000009.wav|./data/berts/000009.npy|sil ^ i2 s an3 j v1 ^ er2 t ong2 h e2 ^ iou4 t uo1 ^ er2 t ong2 ^ uei2 zh u3 sp sil
10 | ./data/waves/000010.wav|./data/berts/000010.npy|sil k e1 t e4 n i1 sh en1 ch uan1 b ao4 ^ uen2 d a4 ^ i1 sp sil
11 | ./data/waves/000011.wav|./data/berts/000011.npy|sil ^ in3 c ai2 ^ iao4 sh ai1 sp ^ iong4 c ai2 ^ ie3 ^ iao4 sh ai1 sp sil
12 | ./data/waves/000012.wav|./data/berts/000012.npy|sil n an2 ^ ve4 k uen1 l uen2 sh an1 ^ v3 x i1 z ang4 j ie1 r ang3 sp sil
13 | ./data/waves/000013.wav|./data/berts/000013.npy|sil ^ ueng1 k ai3 l an2 f ang2 r u3 x ian4 ^ ai2 x van1 ch uan2 zh ao4 sp sil
14 | ./data/waves/000014.wav|./data/berts/000014.npy|sil ^ uo3 h uei2 ^ iou4 h eng1 h eng5 z uo3 h eng1 h eng5 sp sil
15 | ./data/waves/000015.wav|./data/berts/000015.npy|sil ^ u2 ^ vn2 b ao2 n ai3 n ai5 t iao1 x van3 p i2 p a2 sp sil
16 | ./data/waves/000016.wav|./data/berts/000016.npy|sil c ii3 ^ uai4 g uang2 b en3 ^ ie3 j iang1 ^ iou2 sh ao4 zh uang4 p ai4 zh ang2 g uan3 sp sil
17 | ./data/waves/000017.wav|./data/berts/000017.npy|sil ^ iou4 m an2 ^ u3 n iao2 n iao3 sp z uo3 q iong2 g e1 x i1 x i1 sp sil
18 | ./data/waves/000018.wav|./data/berts/000018.npy|sil ^ ian3 k uang4 k uan1 k uo4 ^ er2 d i1 ^ ai3 sp b i2 d uan3 ^ er2 k uan1 sp sil
19 | ./data/waves/000019.wav|./data/berts/000019.npy|sil x ia2 p u3 x ian4 ^ ia2 ch eng2 zh en4 ^ u1 q i2 sp ^ ua3 ^ iao2 c uen1 sh uei3 ^ uei4 m eng2 zh ang3 sp sil
20 | ./data/waves/000020.wav|./data/berts/000020.npy|sil h uo4 s ii1 ^ ian4 l u4 b ei4 x iou4 r u3 g ou1 x ing4 g an3 r e2 h uo3 sp sil
21 | ./data/waves/000021.wav|./data/berts/000021.npy|sil ^ ie3 k e2 ^ i3 g ei3 b en2 ^ uei3 ^ van2 f an3 ^ ing4 ^ iou5 sp sil
22 | ./data/waves/000022.wav|./data/berts/000022.npy|sil ^ iao2 l an2 p ai2 g ai4 ^ uei2 j ian4 ^ ing1 ^ er2 p ei4 f ang1 n ai2 f en3 sp sil
23 | ./data/waves/000023.wav|./data/berts/000023.npy|sil ^ e4 ^ v4 l u2 ^ uan3 s u1 j v2 d i4 d a4 b ao4 ^ v3 sp sil
24 | ./data/waves/000024.wav|./data/berts/000024.npy|sil ^ i3 x ia4 ^ uei2 x ve1 m an2 z ii5 g uan1 d ian3 zh ai1 ^ iao4 sp sil
25 | ./data/waves/000025.wav|./data/berts/000025.npy|sil ch en2 ^ v2 l uo4 ^ ian4 sp b i4 ^ ve4 x iou1 h ua1 sp sil
26 | ./data/waves/000026.wav|./data/berts/000026.npy|sil ^ er2 x v3 g uan4 ^ ing1 ^ i4 j iang1 ^ v2 m ing2 r iii4 ch u1 b in4 sp sil
27 | ./data/waves/000027.wav|./data/berts/000027.npy|sil ^ a1 j iao1 ^ v2 b ai3 ^ uei4 sp g uei2 f en3 s ii1 sp k uang2 h uan1 sp sil
28 | ./data/waves/000028.wav|./data/berts/000028.npy|sil f ang2 ^ u1 m ai3 m ai4 q i4 ^ ve1 h e2 sh ou1 t iao2 sp sil
29 | ./data/waves/000030.wav|./data/berts/000030.npy|sil h ei1 x iong2 ch uang3 j in4 ^ uang2 m ing2 h uei1 j ia1 h ou4 ^ van4 m i4 sh iii2 sp sil
30 | ./data/waves/000031.wav|./data/berts/000031.npy|sil ^ u2 ^ ia4 j vn1 ^ v3 zh ang4 f u5 c ai4 k uei2 sp sil
31 | ./data/waves/000032.wav|./data/berts/000032.npy|sil ^ ia4 x iao1 s uan1 ^ ian2 ^ uei4 ^ uei1 x ian2 sp ^ i4 r ong2 ^ v2 sh uei3 sp sil
32 | ./data/waves/000033.wav|./data/berts/000033.npy|sil ^ van4 q iang2 b ei4 sh iii2 m ian2 ^ ua3 zh e1 ^ ian3 sh ang4 l e5 sp sil
33 | ./data/waves/000034.wav|./data/berts/000034.npy|sil ch eng2 ^ in1 t iao1 x van2 ^ uo2 ^ ian3 zh ao4 ^ v4 m in3 sp sil
34 | ./data/waves/000035.wav|./data/berts/000035.npy|sil ^ in2 ch iii1 m ao3 l iang2 sp ^ u2 ^ i4 ^ v2 ^ in3 zh en4 zh iii2 k e3 sp sil
35 | ./data/waves/000036.wav|./data/berts/000036.npy|sil l v2 z ii5 ^ ve4 p ao3 ^ ve4 k uai4 sp ^ ve4 p ao3 ^ ve4 f eng1 k uang2 sp sil
36 | ./data/waves/000038.wav|./data/berts/000038.npy|sil g uang1 g uen4 j ie2 ^ ing2 x iao1 ^ ing4 ^ vn4 ^ er2 sh eng1 sp sil
37 | ./data/waves/000039.wav|./data/berts/000039.npy|sil g ao1 ^ ia1 t ie2 t a3 x ia4 d e5 d i1 ^ ai3 p eng2 ^ u1 sp sil
38 | ./data/waves/000040.wav|./data/berts/000040.npy|sil d ian4 ^ ing2 h ai3 b ao4 sp p in2 m in2 k u1 b ai3 ^ uan4 f u4 ^ ueng1 sp sp sil
39 | ./data/waves/000041.wav|./data/berts/000041.npy|sil d iao1 ^ uei2 l ie4 ^ i3 b ei4 l ing4 ^ an4 ch u2 l i3 sp sil
40 | ./data/waves/000042.wav|./data/berts/000042.npy|sil ^ vn2 ^ u4 ^ iou3 sh iii2 ^ uan3 r u2 ^ v4 d ai4 p ing2 ^ uo4 f eng1 l uan2 sh an1 j ian1 sp ^ iou3 sh iii2 ch uei1 ^ ian1 n iao3 r ao4 sp b o2 ^ u4 q ing1 x van2 sp sil
41 | ./data/waves/000043.wav|./data/berts/000043.npy|sil z uei3 b a5 zh ou1 ^ uei2 l ve4 ^ uei1 zh ong2 q i3 sp sil
42 | ./data/waves/000044.wav|./data/berts/000044.npy|sil h uan1 sh eng1 x iao4 ^ v3 s a2 m an3 c uen1 zh uang1 sp sil
43 | ./data/waves/000045.wav|./data/berts/000045.npy|sil sh a1 ch ang3 j ing4 ^ vn3 m ing4 sp zh uang4 zh iii4 ^ ie3 ^ u2 ^ uei2 sp sil
44 | ./data/waves/000046.wav|./data/berts/000046.npy|sil ^ iou2 zh a2 d ou4 f u5 p en1 p en1 x iang1 sp s an3 z ii5 m a2 h ua1 b eng1 b eng1 c uei4 sp z ii3 m ei4 t uan2 z ii5 sh u3 ^ er4 j iang1 sp sil
45 | ./data/waves/000048.wav|./data/berts/000048.npy|sil l in4 d ong1 ^ ing1 d e5 zh u4 ^ van4 b ing4 ^ an4 sp sil
46 | ./data/waves/000049.wav|./data/berts/000049.npy|sil q van3 b i4 x v1 sh uan1 ^ iang3 h uo4 j van4 ^ iang3 sp sil
47 | ./data/waves/000050.wav|./data/berts/000050.npy|sil h e4 l i4 h e2 ^ in4 d u4 d a4 h eng1 ^ a1 l ang2 n ai4 ^ er3 ^ i3 j ing1 f en1 j v1 sh u4 ^ ve4 sp sil
48 | ./data/waves/000051.wav|./data/berts/000051.npy|sil n an4 x iong1 n an4 d i4 ^ iao4 h ao2 h ao3 q ie1 c uo1 q ie1 c uo1 sp sil
49 | ./data/waves/000052.wav|./data/berts/000052.npy|sil x iao3 j vn1 d a4 t uei3 n ei4 c e4 ^ v1 q ing1 sp sil
50 | ./data/waves/000053.wav|./data/berts/000053.npy|sil ^ i3 j i2 sp c ii2 x iong2 k ong2 z ii2 n iao3 sp f u4 ^ van2 t u2 sp sil
51 | ./data/waves/000054.wav|./data/berts/000054.npy|sil ^ v3 n v2 ^ ian3 ^ van2 p ai1 ^ uen3 x i4 sp ch en2 x iao3 ch uen1 b u2 p a4 ^ ing4 c ai3 ^ er2 ch iii1 c u4 sp sil
52 | ./data/waves/000055.wav|./data/berts/000055.npy|sil f ang3 zh en1 zh iii1 sp c u1 h ua1 n i2 d eng3 ^ in4 h ua1 m ian4 l iao4 d e5 q iao3 m iao4 ^ vn4 ^ iong4 sp sil
53 | ./data/waves/000056.wav|./data/berts/000056.npy|sil ^ v3 ^ ve1 s e4 f u1 t ie1 m ian4 l ou3 b ao4 sp sil
54 | ./data/waves/000057.wav|./data/berts/000057.npy|sil n u2 ^ er3 b ie2 ^ er3 d e2 ^ ie1 ^ ua2 q ing3 zh ou1 ^ iong3 k ang1 zh uan3 d a2 d uei4 ^ u2 b ang1 g uo2 ^ uei3 ^ van2 zh ang3 d e5 q in1 q ie4 ^ uen4 h ou4 sp sil
55 | ./data/waves/000058.wav|./data/berts/000058.npy|sil j in1 ^ iou3 zh uang1 m an2 ^ er2 x ia4 zh ou1 j van1 sh en4 sp sil
56 | ./data/waves/000059.wav|./data/berts/000059.npy|sil ^ uo2 ^ uo3 ^ uo2 ^ uo3 ^ uo2 ^ uo3 h e2 p ang4 z ii5 ^ i4 q i3 sp sil
57 | ./data/waves/000060.wav|./data/berts/000060.npy|sil s ao3 s ao5 g ei3 z uo4 d e5 ^ iou2 m ian4 ^ uo1 ^ uo5 sp h e2 ^ iang2 r ou4 d uen4 sh an1 ^ iao4 sp sil
58 | ./data/waves/000061.wav|./data/berts/000061.npy|sil m an4 s a4 n i2 ^ ve1 g ang2 k ou3 b ei4 g uan1 b i4 sp sil
59 | ./data/waves/000062.wav|./data/berts/000062.npy|sil l iao4 b i4 ^ er2 m a2 ^ ia3 sh u1 h e2 l ao3 g ong1 sh ai4 ^ en1 ^ ai4 sp sil
60 | ./data/waves/000063.wav|./data/berts/000063.npy|sil ^ ua4 z ii5 t an1 x iao3 l ao2 b an3 ch en2 ^ iong3 q van2 sp sil
61 | ./data/waves/000064.wav|./data/berts/000064.npy|sil h e2 j i4 ch ou2 z ii1 ^ i1 d ian3 ^ i1 ^ er4 j iou3 ^ i1 ^ u2 ^ u3 ^ i4 ^ van2 sp sil
62 | ./data/waves/000065.wav|./data/berts/000065.npy|sil n iao3 ^ er2 zh a1 zh a1 sp z ou4 q i3 ch en2 q v3 sp sil
63 | ./data/waves/000066.wav|./data/berts/000066.npy|sil t i3 q iang2 zh uang4 ^ er2 h ou4 x ve2 ^ uen4 d ao4 d e2 zh iii1 j in4 x iou1 ^ iong3 ^ er2 sh ou1 x iao4 ^ van3 sp sil
64 | ./data/waves/000067.wav|./data/berts/000067.npy|sil ^ uang3 g ou4 h uo4 ^ uang3 s ou1 g u3 m in2 x in4 x i1 sp ^ i3 g ao1 h uei2 b ao4 ^ uei2 ^ er3 sp q ing3 j vn1 r u4 ^ ueng4 sp sp sil
65 | ./data/waves/000068.wav|./data/berts/000068.npy|sil t u2 ^ uei2 k ai2 m u3 b o2 ^ i1 sp x ie2 x ian4 sp ch ong1 c ii4 sp sil
66 | ./data/waves/000069.wav|./data/berts/000069.npy|sil j i2 j iang1 d iao4 ^ v2 t ai2 sp h uang2 ^ uei2 ^ v3 sp ch iii4 ^ v3 s an1 d ao3 sh ang2 g ei3 sh eng4 x van1 h uai2 ^ uei2 ch an3 ^ ie4 sp g ong1 c ai3 ^ iao4 zh iii1 ^ iong4 sp sil
67 | ./data/waves/000070.wav|./data/berts/000070.npy|sil ^ iong4 b u2 ^ iong4 ^ uo3 t i4 n i2 ^ u3 zh e5 z uei3 sp sil
68 | ./data/waves/000071.wav|./data/berts/000071.npy|sil ^ er4 j ing4 j ia1 zh u4 b ei2 ^ u3 h uan2 ^ uai4 sp sh ang4 b an1 ^ iao4 q v4 ^ ia4 ^ vn4 c uen1 h ua2 t ang2 sh ang1 ch ang3 sp sil
69 | ./data/waves/000072.wav|./data/berts/000072.npy|sil ^ iou3 g uan1 n eng2 f ou2 g ai3 k ong4 sp ch uan4 m ou2 sp sp sil
70 | ./data/waves/000073.wav|./data/berts/000073.npy|sil b u4 l v3 n i2 ^ ie2 ^ v3 q ian2 f u1 zh e2 x ve2 j ia1 l a1 f ei2 ^ er3 ^ v4 ^ iou3 ^ i4 z ii3 sp sil
71 | ./data/waves/000074.wav|./data/berts/000074.npy|sil ^ iou4 ^ iou3 sh uei2 k en2 sh e3 sh iii1 g ei3 t a1 n e5 sp sil
72 | ./data/waves/000075.wav|./data/berts/000075.npy|sil q ie4 z ei2 l ou4 t uei3 z ii4 ch eng1 h uan4 sp m a2 f eng1 b ing4 sp sp sil
73 | ./data/waves/000076.wav|./data/berts/000076.npy|sil l ao3 z u2 m u3 ^ i3 m ou2 sh a1 z uei4 b ei4 sh ou1 ^ ia1 sp sil
74 | ./data/waves/000077.wav|./data/berts/000077.npy|sil ^ ai3 c uo2 q iong2 sp b ie2 x ia1 k ua1 sp sil
75 | ./data/waves/000078.wav|./data/berts/000078.npy|sil ^ v2 z ai4 sh uei2 l i2 d a3 x van2 d e5 ^ uang3 ^ uai4 l uan4 b eng4 sp sil
76 | ./data/waves/000079.wav|./data/berts/000079.npy|sil zh uo2 ^ uen2 j vn1 sh uo1 sp ^ uan3 r u2 sh an1 sh ang4 x ve3 sp j iao3 r uo4 ^ vn2 j ian1 ^ ve4 sp sil
77 | ./data/waves/000080.wav|./data/berts/000080.npy|sil ^ i2 l ei4 m iao2 k e3 ^ uei4 m ing4 ^ vn4 d uo1 ch uan3 sp sil
78 | ./data/waves/000081.wav|./data/berts/000081.npy|sil m ou2 x iao3 s ong2 k uai4 k uai4 k an4 k an5 sp sil
79 | ./data/waves/000082.wav|./data/berts/000082.npy|sil t ie3 sh a1 zh ang3 ^ ai4 h ao4 zh e3 zh ang1 h uei1 b iao2 ^ ian3 p i1 zh uan1 sp sil
80 | ./data/waves/000083.wav|./data/berts/000083.npy|sil ^ er2 l v3 ^ ie4 x in1 ^ ie3 b ei4 ^ i4 g en1 l in3 z ii5 ^ ia1 zh e5 sp sil
81 | ./data/waves/000084.wav|./data/berts/000084.npy|sil zh uen1 zh uen1 d ing1 n ing2 sp ^ u4 ^ uang4 t ian1 sh an1 b u3 ^ v4 q ing2 sp sil
82 | ./data/waves/000086.wav|./data/berts/000086.npy|sil ^ u4 ^ iong1 h uei4 ^ ian2 sp r en2 ^ iou3 j i4 x ing4 sp ^ i4 ^ iou3 ^ uang4 x ing4 sp sil
83 | ./data/waves/000087.wav|./data/berts/000087.npy|sil c an1 h ou4 sp ^ i4 x ing2 r en2 ^ iou4 p ei2 ^ uang1 x iao3 f ei1 d ao4 l ing4 ^ i4 zh en2 s uo3 k an4 ^ er3 b i2 h ou2 k e1 sp sil
84 | ./data/waves/000088.wav|./data/berts/000088.npy|sil g an4 m a2 ^ a5 n i3 ^ iou4 l ai2 g ou1 d a5 sh ei2 sp sil
85 | ./data/waves/000089.wav|./data/berts/000089.npy|sil ^ ve1 ^ uan2 k a3 ^ v2 t ie3 t uo1 m ei2 ^ iou3 h ai2 z ii5 sp sil
86 | ./data/waves/000090.wav|./data/berts/000090.npy|sil sh ou4 s uen3 s ong4 d ai4 g e1 ^ iao2 q ing1 ^ iou4 k uei2 b an4 k ou3 p an2 zh eng4 m ian4 s uen2 h uei3 h ou4 sp sil
87 | ./data/waves/000091.wav|./data/berts/000091.npy|sil ^ ie3 h uan1 ^ ing2 ^ uei4 sh eng1 m ian2 ch ang3 sh ang1 zh ao2 ^ uo3 d ai4 ^ ian2 sp sil
88 | ./data/waves/000092.wav|./data/berts/000092.npy|sil l ou2 ^ i2 ^ uang3 h uan2 k ong1 l ong2 m u3 sp q i2 l in2 m ai2 m o4 j i3 ch uen1 q iou1 sp sil
89 | ./data/waves/000093.wav|./data/berts/000093.npy|sil r u2 h u4 f a4 h e2 h u4 ^ ian3 ^ iao4 d ai4 sh ang4 g uei1 j iao1 ^ iong3 m ao4 h e2 zh ai1 ch u2 ^ in3 x ing2 ^ ian3 j ing4 sp sil
90 | ./data/waves/000094.wav|./data/berts/000094.npy|sil m a3 x vn4 ^ uei2 ^ v3 ^ iang2 ^ i2 d eng3 g ong4 d u4 sh eng1 r iii4 sp sil
91 | ./data/waves/000096.wav|./data/berts/000096.npy|sil t u2 ^ uei2 zh ong1 sh an1 zh an1 ^ van2 g uang3 ^ iao1 q ian1 ^ v2 ^ iou2 k e4 p in3 ch ang2 ch ong2 ^ iang2 ^ i4 sh ou4 q ian1 s ou3 ^ ian4 sp sil
92 | ./data/waves/000097.wav|./data/berts/000097.npy|sil x iao3 ^ uen2 j ie2 q van2 sh en1 k u1 sh ou4 r u2 ch ai2 sp sil
93 | ./data/waves/000098.wav|./data/berts/000098.npy|sil q in2 m ou3 ^ iong4 l ian3 p en2 ^ iao2 sh uei3 j iang1 h uo3 p u1 m ie4 sp sil
94 | ./data/waves/000099.wav|./data/berts/000099.npy|sil ^ an1 q van2 t ao4 z uo4 ch eng2 ^ u1 d u2 ^ ua2 ^ ua5 ^ iang4 z ii5 sp sil
95 | ./data/waves/000100.wav|./data/berts/000100.npy|sil d i4 s an1 sh iii4 k ao3 ^ ia1 q ve1 j in1 d uan2 l iang3 sp sil
96 | ./data/waves/000101.wav|./data/berts/000101.npy|sil ch uan2 m a3 ^ i1 l i2 h uai2 sh ang5 ^ er4 t ai1 sp sil
97 | ./data/waves/000102.wav|./data/berts/000102.npy|sil f eng1 zh eng5 ^ iao4 x ian4 sp b en1 m a3 ^ iao4 j iang1 sp h ang2 ch uan2 ^ iao4 d uo4 sp sil
98 | ./data/waves/000103.wav|./data/berts/000103.npy|sil j ia3 ^ van2 ^ van5 t uo1 g ao1 ^ ve4 sh ang4 l ou2 sp sil
99 | ./data/waves/000104.wav|./data/berts/000104.npy|sil h ei1 p ao2 n iao2 sp x iao3 h ei1 sp sh iii5 l in2 b a1 g e1 sp sil
100 | ./data/waves/000105.wav|./data/berts/000105.npy|sil t a1 ^ ve4 ^ en4 sp ^ uo3 ^ ve4 h uang1 sp f a1 d ong4 ^ ve4 m an4 sp sil
101 |
--------------------------------------------------------------------------------
/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn import functional as F
3 |
4 | import commons
5 |
6 |
7 | def feature_loss(fmap_r, fmap_g):
8 | loss = 0
9 | for dr, dg in zip(fmap_r, fmap_g):
10 | for rl, gl in zip(dr, dg):
11 | rl = rl.float().detach()
12 | gl = gl.float()
13 | loss += torch.mean(torch.abs(rl - gl))
14 |
15 | return loss * 2
16 |
17 |
18 | def discriminator_loss(disc_real_outputs, disc_generated_outputs):
19 | loss = 0
20 | r_losses = []
21 | g_losses = []
22 | for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
23 | dr = dr.float()
24 | dg = dg.float()
25 | r_loss = torch.mean((1 - dr) ** 2)
26 | g_loss = torch.mean(dg**2)
27 | loss += r_loss + g_loss
28 | r_losses.append(r_loss.item())
29 | g_losses.append(g_loss.item())
30 |
31 | return loss, r_losses, g_losses
32 |
33 |
34 | def generator_loss(disc_outputs):
35 | loss = 0
36 | gen_losses = []
37 | for dg in disc_outputs:
38 | dg = dg.float()
39 | l = torch.mean((1 - dg) ** 2)
40 | gen_losses.append(l)
41 | loss += l
42 |
43 | return loss, gen_losses
44 |
45 |
46 | def kl_loss(z_p, logs_q, m_p, logs_p, z_mask):
47 | """
48 | z_p, logs_q: [b, h, t_t]
49 | m_p, logs_p: [b, h, t_t]
50 | """
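    # Monte-Carlo estimate of KL(q || p): z_p is a sample from the posterior q with
    # log-std logs_q, (m_p, logs_p) parameterize the aligned prior p; the element-wise
    # KL is masked by z_mask and averaged over the valid frames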
51 | z_p = z_p.float()
52 | logs_q = logs_q.float()
53 | m_p = m_p.float()
54 | logs_p = logs_p.float()
55 | z_mask = z_mask.float()
56 |
57 | kl = logs_p - logs_q - 0.5
58 | kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p)
59 | kl = torch.sum(kl * z_mask)
60 | l = kl / torch.sum(z_mask)
61 | return l
62 |
--------------------------------------------------------------------------------
/mel_processing.py:
--------------------------------------------------------------------------------
1 | import math
2 | import os
3 | import random
4 | import torch
5 | from torch import nn
6 | import torch.nn.functional as F
7 | import torch.utils.data
8 | import numpy as np
9 | import librosa
10 | import librosa.util as librosa_util
11 | from librosa.util import normalize, pad_center, tiny
12 | from scipy.signal import get_window
13 | from scipy.io.wavfile import read
14 | from librosa.filters import mel as librosa_mel_fn
15 |
16 | MAX_WAV_VALUE = 32768.0
17 |
18 |
19 | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
20 | """
21 | PARAMS
22 | ------
23 | C: compression factor
24 | """
25 | return torch.log(torch.clamp(x, min=clip_val) * C)
26 |
27 |
28 | def dynamic_range_decompression_torch(x, C=1):
29 | """
30 | PARAMS
31 | ------
32 | C: compression factor used to compress
33 | """
34 | return torch.exp(x) / C
35 |
36 |
37 | def spectral_normalize_torch(magnitudes):
38 | output = dynamic_range_compression_torch(magnitudes)
39 | return output
40 |
41 |
42 | def spectral_de_normalize_torch(magnitudes):
43 | output = dynamic_range_decompression_torch(magnitudes)
44 | return output
45 |
46 |
47 | mel_basis = {}
48 | hann_window = {}
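# module-level caches so mel filterbanks and Hann windows are built once per
# fmax / win_size / dtype / device combination and reused across calls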
49 |
50 |
51 | def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
52 | if torch.min(y) < -1.0:
53 | print("min value is ", torch.min(y))
54 | if torch.max(y) > 1.0:
55 | print("max value is ", torch.max(y))
56 |
57 | global hann_window
58 | dtype_device = str(y.dtype) + "_" + str(y.device)
59 | wnsize_dtype_device = str(win_size) + "_" + dtype_device
60 | if wnsize_dtype_device not in hann_window:
61 | hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(
62 | dtype=y.dtype, device=y.device
63 | )
64 |
65 | y = torch.nn.functional.pad(
66 | y.unsqueeze(1),
67 | (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
68 | mode="reflect",
69 | )
70 | y = y.squeeze(1)
71 |
72 | spec = torch.stft(
73 | y,
74 | n_fft,
75 | hop_length=hop_size,
76 | win_length=win_size,
77 | window=hann_window[wnsize_dtype_device],
78 | center=center,
79 | pad_mode="reflect",
80 | normalized=False,
81 | onesided=True,
82 | return_complex=False,
83 | )
84 |
85 | spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
86 | return spec
87 |
88 |
89 | def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax):
90 | global mel_basis
91 | dtype_device = str(spec.dtype) + "_" + str(spec.device)
92 | fmax_dtype_device = str(fmax) + "_" + dtype_device
93 | if fmax_dtype_device not in mel_basis:
94 | mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
95 | mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(
96 | dtype=spec.dtype, device=spec.device
97 | )
98 | spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
99 | spec = spectral_normalize_torch(spec)
100 | return spec
101 |
102 |
103 | def mel_spectrogram_torch(
104 | y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False
105 | ):
106 | if torch.min(y) < -1.0:
107 | print("min value is ", torch.min(y))
108 | if torch.max(y) > 1.0:
109 | print("max value is ", torch.max(y))
110 |
111 | global mel_basis, hann_window
112 | dtype_device = str(y.dtype) + "_" + str(y.device)
113 | fmax_dtype_device = str(fmax) + "_" + dtype_device
114 | wnsize_dtype_device = str(win_size) + "_" + dtype_device
115 | if fmax_dtype_device not in mel_basis:
116 | mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
117 | mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(
118 | dtype=y.dtype, device=y.device
119 | )
120 | if wnsize_dtype_device not in hann_window:
121 | hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(
122 | dtype=y.dtype, device=y.device
123 | )
124 |
125 | y = torch.nn.functional.pad(
126 | y.unsqueeze(1),
127 | (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
128 | mode="reflect",
129 | )
130 | y = y.squeeze(1)
131 |
132 | spec = torch.stft(
133 | y,
134 | n_fft,
135 | hop_length=hop_size,
136 | win_length=win_size,
137 | window=hann_window[wnsize_dtype_device],
138 | center=center,
139 | pad_mode="reflect",
140 | normalized=False,
141 | onesided=True,
142 | return_complex=False,
143 | )
144 |
145 | spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
146 |
147 | spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
148 | spec = spectral_normalize_torch(spec)
149 |
150 | return spec
151 |
--------------------------------------------------------------------------------
/model_onnx.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | from pathlib import Path
3 | from typing import Any, Dict
4 |
5 | import math
6 | import onnx
7 | import torch
8 | import argparse
9 |
10 | from onnxruntime.quantization import QuantType, quantize_dynamic
11 |
12 | import utils
13 | import commons
14 | import attentions
15 | from torch import nn
16 | from models import DurationPredictor, ResidualCouplingBlock, Generator
17 | from text.symbols import symbols
18 |
19 |
20 | class TextEncoder(nn.Module):
21 | def __init__(
22 | self,
23 | n_vocab,
24 | out_channels,
25 | hidden_channels,
26 | filter_channels,
27 | n_heads,
28 | n_layers,
29 | kernel_size,
30 | p_dropout,
31 | ):
32 | super().__init__()
33 | self.n_vocab = n_vocab
34 | self.out_channels = out_channels
35 | self.hidden_channels = hidden_channels
36 | self.filter_channels = filter_channels
37 | self.n_heads = n_heads
38 | self.n_layers = n_layers
39 | self.kernel_size = kernel_size
40 | self.p_dropout = p_dropout
41 |
42 | self.emb = nn.Embedding(n_vocab, hidden_channels)
43 | # self.emb_bert = nn.Linear(256, hidden_channels)
44 | nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)
45 |
46 | self.encoder = attentions.Encoder(
47 | hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
48 | )
49 | self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
50 |
51 | def forward(self, x, x_lengths):
52 | x = self.emb(x) * math.sqrt(self.hidden_channels) # [b, t, h]
53 | # if bert is not None:
54 | # b = self.emb_bert(bert)
55 | # x = x + b
56 | x = torch.transpose(x, 1, -1) # [b, h, t]
57 | x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
58 | x.dtype
59 | )
60 |
61 | x = self.encoder(x * x_mask, x_mask)
62 | stats = self.proj(x) * x_mask
63 |
64 | m, logs = torch.split(stats, self.out_channels, dim=1)
65 | return x, m, logs, x_mask
66 |
67 |
68 | class SynthesizerEval(nn.Module):
69 | """
70 |     Synthesizer for inference (ONNX export)
71 | """
72 |
73 | def __init__(
74 | self,
75 | n_vocab,
76 | spec_channels,
77 | segment_size,
78 | inter_channels,
79 | hidden_channels,
80 | filter_channels,
81 | n_heads,
82 | n_layers,
83 | kernel_size,
84 | p_dropout,
85 | resblock,
86 | resblock_kernel_sizes,
87 | resblock_dilation_sizes,
88 | upsample_rates,
89 | upsample_initial_channel,
90 | upsample_kernel_sizes,
91 | n_speakers=0,
92 | gin_channels=0,
93 | use_sdp=False,
94 | **kwargs
95 | ):
96 |
97 | super().__init__()
98 | self.n_vocab = n_vocab
99 | self.spec_channels = spec_channels
100 | self.inter_channels = inter_channels
101 | self.hidden_channels = hidden_channels
102 | self.filter_channels = filter_channels
103 | self.n_heads = n_heads
104 | self.n_layers = n_layers
105 | self.kernel_size = kernel_size
106 | self.p_dropout = p_dropout
107 | self.resblock = resblock
108 | self.resblock_kernel_sizes = resblock_kernel_sizes
109 | self.resblock_dilation_sizes = resblock_dilation_sizes
110 | self.upsample_rates = upsample_rates
111 | self.upsample_initial_channel = upsample_initial_channel
112 | self.upsample_kernel_sizes = upsample_kernel_sizes
113 | self.segment_size = segment_size
114 | self.n_speakers = n_speakers
115 | self.gin_channels = gin_channels
116 |
117 | self.enc_p = TextEncoder(
118 | n_vocab,
119 | inter_channels,
120 | hidden_channels,
121 | filter_channels,
122 | n_heads,
123 | n_layers,
124 | kernel_size,
125 | p_dropout,
126 | )
127 | self.dec = Generator(
128 | inter_channels,
129 | resblock,
130 | resblock_kernel_sizes,
131 | resblock_dilation_sizes,
132 | upsample_rates,
133 | upsample_initial_channel,
134 | upsample_kernel_sizes,
135 | gin_channels=gin_channels,
136 | )
137 | self.flow = ResidualCouplingBlock(
138 | inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels
139 | )
140 | self.dp = DurationPredictor(
141 | hidden_channels, 256, 3, 0.5, gin_channels=gin_channels
142 | )
143 | if n_speakers > 1:
144 | self.emb_g = nn.Embedding(n_speakers, gin_channels)
145 |
146 | def remove_weight_norm(self):
147 | self.flow.remove_weight_norm()
148 |
149 | def infer(self, x, x_lengths, sid=None, noise_scale=1, length_scale=1):
150 | x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths)
151 | if self.n_speakers > 0:
152 | g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
153 | else:
154 | g = None
155 |
156 | logw = self.dp(x, x_mask, g=g)
157 | w = torch.exp(logw) * x_mask * length_scale
158 | w_ceil = torch.ceil(w + 0.35)
159 | y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
160 | y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, None), 1).to(
161 | x_mask.dtype
162 | )
163 | attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
164 | attn = commons.generate_path(w_ceil, attn_mask)
165 |
166 | m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(
167 | 1, 2
168 | ) # [b, t', t], [b, t, d] -> [b, d, t']
169 | logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(
170 | 1, 2
171 | ) # [b, t', t], [b, t, d] -> [b, d, t']
172 |
173 | z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
174 | z = self.flow(z_p, y_mask, g=g, reverse=True)
175 | o = self.dec((z * y_mask), g=g)
176 | return o.squeeze()
177 |
178 |
179 | class OnnxModel(torch.nn.Module):
180 | def __init__(self, model: SynthesizerEval):
181 | super().__init__()
182 | self.model = model
183 |
184 | def forward(
185 | self,
186 | x,
187 | x_lengths,
188 | noise_scale=1,
189 | length_scale=1,
190 | ):
191 | return self.model.infer(
192 | x=x,
193 | x_lengths=x_lengths,
194 | noise_scale=noise_scale,
195 | length_scale=length_scale,
196 | )
197 |
198 |
199 | def add_meta_data(filename: str, meta_data: Dict[str, Any]):
200 | """Add meta data to an ONNX model. It is changed in-place.
201 |
202 | Args:
203 | filename:
204 | Filename of the ONNX model to be changed.
205 | meta_data:
206 | Key-value pairs.
207 | """
208 | model = onnx.load(filename)
209 | for key, value in meta_data.items():
210 | meta = model.metadata_props.add()
211 | meta.key = key
212 | meta.value = str(value)
213 |
214 | onnx.save(model, filename)
215 |
216 |
217 | @torch.no_grad()
218 | def main():
219 |     parser = argparse.ArgumentParser(description='Export bert vits models to ONNX')
220 | parser.add_argument('--config', type=str, required=True)
221 | parser.add_argument('--model', type=str, required=True)
222 | args = parser.parse_args()
223 | config_file = args.config
224 | checkpoint = args.model
225 |
226 | hps = utils.get_hparams_from_file(config_file)
227 | print(hps)
228 |
229 | net_g = SynthesizerEval(
230 | len(symbols),
231 | hps.data.filter_length // 2 + 1,
232 | hps.train.segment_size // hps.data.hop_length,
233 | n_speakers=hps.data.n_speakers,
234 | **hps.model,
235 | )
236 |
237 | _ = net_g.eval()
238 | _ = utils.load_model(checkpoint, net_g)
239 | net_g.remove_weight_norm()
240 |
241 | x = torch.randint(low=0, high=100, size=(50,), dtype=torch.int64)
242 | x = x.unsqueeze(0)
243 |
244 | x_length = torch.tensor([x.shape[1]], dtype=torch.int64)
245 | noise_scale = torch.tensor([1], dtype=torch.float32)
246 | length_scale = torch.tensor([1], dtype=torch.float32)
247 |
248 | model = OnnxModel(net_g)
249 |
250 | opset_version = 13
251 |
252 | filename = "vits-chinese.onnx"
253 |
254 | torch.onnx.export(
255 | model,
256 | (x, x_length, noise_scale, length_scale),
257 | filename,
258 | opset_version=opset_version,
259 | input_names=[
260 | "x",
261 | "x_length",
262 | "noise_scale",
263 | "length_scale",
264 | ],
265 | output_names=["y"],
266 | dynamic_axes={
267 | "x": {1: "L"},
268 | "y": {0: "L"},
269 | },
270 | )
271 | meta_data = {
272 | "model_type": "vits",
273 | "comment": "csukuangfj",
274 | "language": "Chinese",
275 | "add_blank": int(hps.data.add_blank),
276 | "n_speakers": int(hps.data.n_speakers),
277 | "sample_rate": hps.data.sampling_rate,
278 | "punctuation": "",
279 | }
280 | print("meta_data", meta_data)
281 | add_meta_data(filename=filename, meta_data=meta_data)
282 |
283 | print("Generate int8 quantization models")
284 | filename_int8 = "vits-chinese.int8.onnx"
285 | quantize_dynamic(
286 | model_input=filename,
287 | model_output=filename_int8,
288 | weight_type=QuantType.QUInt8,
289 | )
290 | print(f"Saved to {filename} and {filename_int8}")
291 |
292 |
293 | if __name__ == "__main__":
294 | main()
295 |
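# Illustrative usage sketch (added for clarity, not part of the original file):
# running the exported "vits-chinese.onnx" with onnxruntime. Assumes numpy and
# onnxruntime are installed and that token ids have already been produced by the
# text front end; the input/output names match the torch.onnx.export call above.
def run_exported_onnx_example():
    import numpy as np
    import onnxruntime

    sess = onnxruntime.InferenceSession(
        "vits-chinese.onnx", providers=["CPUExecutionProvider"]
    )
    # placeholder token ids of shape [1, L]; real ids come from the text front end
    x = np.array([[10, 20, 30, 40]], dtype=np.int64)
    audio = sess.run(
        ["y"],
        {
            "x": x,
            "x_length": np.array([x.shape[1]], dtype=np.int64),
            "noise_scale": np.array([0.5], dtype=np.float32),
            "length_scale": np.array([1.0], dtype=np.float32),
        },
    )[0]
    return audio  # 1-D float waveform at hps.data.sampling_rate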
--------------------------------------------------------------------------------
/model_onnx_stream.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | from pathlib import Path
3 | from typing import Any, Dict
4 |
5 | import math
6 | import onnx
7 | import torch
8 | import argparse
9 |
10 | import utils
11 | import commons
12 | import attentions
13 | from torch import nn
14 | from models import DurationPredictor, ResidualCouplingBlock, Generator
15 | from text.symbols import symbols
16 |
17 |
18 | class TextEncoder(nn.Module):
19 | def __init__(
20 | self,
21 | n_vocab,
22 | out_channels,
23 | hidden_channels,
24 | filter_channels,
25 | n_heads,
26 | n_layers,
27 | kernel_size,
28 | p_dropout,
29 | ):
30 | super().__init__()
31 | self.n_vocab = n_vocab
32 | self.out_channels = out_channels
33 | self.hidden_channels = hidden_channels
34 | self.filter_channels = filter_channels
35 | self.n_heads = n_heads
36 | self.n_layers = n_layers
37 | self.kernel_size = kernel_size
38 | self.p_dropout = p_dropout
39 |
40 | self.emb = nn.Embedding(n_vocab, hidden_channels)
41 | # self.emb_bert = nn.Linear(256, hidden_channels)
42 | nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)
43 |
44 | self.encoder = attentions.Encoder(
45 | hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
46 | )
47 | self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
48 |
49 | def forward(self, x, x_lengths):
50 | x = self.emb(x) * math.sqrt(self.hidden_channels) # [b, t, h]
51 | # if bert is not None:
52 | # b = self.emb_bert(bert)
53 | # x = x + b
54 | x = torch.transpose(x, 1, -1) # [b, h, t]
55 | x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
56 | x.dtype
57 | )
58 |
59 | x = self.encoder(x * x_mask, x_mask)
60 | stats = self.proj(x) * x_mask
61 |
62 | m, logs = torch.split(stats, self.out_channels, dim=1)
63 | return x, m, logs, x_mask
64 |
65 |
66 | class VITS_Encoder(nn.Module):
67 |
68 | def __init__(
69 | self,
70 | n_vocab,
71 | spec_channels,
72 | segment_size,
73 | inter_channels,
74 | hidden_channels,
75 | filter_channels,
76 | n_heads,
77 | n_layers,
78 | kernel_size,
79 | p_dropout,
80 | resblock,
81 | resblock_kernel_sizes,
82 | resblock_dilation_sizes,
83 | upsample_rates,
84 | upsample_initial_channel,
85 | upsample_kernel_sizes,
86 | n_speakers=0,
87 | gin_channels=0,
88 | use_sdp=False,
89 | **kwargs
90 | ):
91 |
92 | super().__init__()
93 | self.n_speakers = n_speakers
94 | self.enc_p = TextEncoder(
95 | n_vocab,
96 | inter_channels,
97 | hidden_channels,
98 | filter_channels,
99 | n_heads,
100 | n_layers,
101 | kernel_size,
102 | p_dropout,
103 | )
104 |
105 | self.dp = DurationPredictor(
106 | hidden_channels, 256, 3, 0.5, gin_channels=gin_channels
107 | )
108 | if n_speakers > 1:
109 | self.emb_g = nn.Embedding(n_speakers, gin_channels)
110 |
111 | def infer(self, x, x_lengths, sid=None, noise_scale=1, length_scale=1):
112 | x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths)
113 | if self.n_speakers > 0:
114 | g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
115 | else:
116 | g = None
117 |
118 | logw = self.dp(x, x_mask, g=g)
119 | w = torch.exp(logw) * x_mask * length_scale
120 | w_ceil = torch.ceil(w + 0.35)
121 | y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
122 | y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, None), 1).to(
123 | x_mask.dtype
124 | )
125 | attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
126 | attn = commons.generate_path(w_ceil, attn_mask)
127 | m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(
128 | 1, 2
129 | ) # [b, t', t], [b, t, d] -> [b, d, t']
130 | logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(
131 | 1, 2
132 | ) # [b, t', t], [b, t, d] -> [b, d, t']
133 | z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
134 | return z_p, y_mask
135 |
136 |
137 | class VITS_Decoder(nn.Module):
138 |
139 | def __init__(
140 | self,
141 | n_vocab,
142 | spec_channels,
143 | segment_size,
144 | inter_channels,
145 | hidden_channels,
146 | filter_channels,
147 | n_heads,
148 | n_layers,
149 | kernel_size,
150 | p_dropout,
151 | resblock,
152 | resblock_kernel_sizes,
153 | resblock_dilation_sizes,
154 | upsample_rates,
155 | upsample_initial_channel,
156 | upsample_kernel_sizes,
157 | n_speakers=0,
158 | gin_channels=0,
159 | use_sdp=False,
160 | **kwargs
161 | ):
162 |
163 | super().__init__()
164 | self.n_speakers = n_speakers
165 | self.dec = Generator(
166 | inter_channels,
167 | resblock,
168 | resblock_kernel_sizes,
169 | resblock_dilation_sizes,
170 | upsample_rates,
171 | upsample_initial_channel,
172 | upsample_kernel_sizes,
173 | gin_channels=gin_channels,
174 | )
175 | self.flow = ResidualCouplingBlock(
176 | inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels
177 | )
178 | if n_speakers > 1:
179 | self.emb_g = nn.Embedding(n_speakers, gin_channels)
180 |
181 | def remove_weight_norm(self):
182 | self.flow.remove_weight_norm()
183 |
184 | def infer(self, z_p, y_mask, sid=None):
185 | if self.n_speakers > 0:
186 | g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
187 | else:
188 | g = None
189 | z = self.flow(z_p, y_mask, g=g, reverse=True)
190 | o = self.dec((z * y_mask), g=g)
191 | return o.squeeze()
192 |
193 |
194 | class OnnxModel_Encoder(torch.nn.Module):
195 | def __init__(self, model: VITS_Encoder):
196 | super().__init__()
197 | self.model = model
198 |
199 | def forward(self, x, x_lengths, noise_scale=1, length_scale=1):
200 | return self.model.infer(
201 | x=x,
202 | x_lengths=x_lengths,
203 | noise_scale=noise_scale,
204 | length_scale=length_scale,
205 | )
206 |
207 |
208 | class OnnxModel_Decoder(torch.nn.Module):
209 | def __init__(self, model: VITS_Decoder):
210 | super().__init__()
211 | self.model = model
212 |
213 | def forward(self, z_p, y_mask):
214 | return self.model.infer(
215 | z_p=z_p,
216 | y_mask=y_mask,
217 | )
218 |
219 |
220 | def add_meta_data(filename: str, meta_data: Dict[str, Any]):
221 | """Add meta data to an ONNX model. It is changed in-place.
222 |
223 | Args:
224 | filename:
225 | Filename of the ONNX model to be changed.
226 | meta_data:
227 | Key-value pairs.
228 | """
229 | model = onnx.load(filename)
230 | for key, value in meta_data.items():
231 | meta = model.metadata_props.add()
232 | meta.key = key
233 | meta.value = str(value)
234 |
235 | onnx.save(model, filename)
236 |
237 |
238 | @torch.no_grad()
239 | def main():
240 |     parser = argparse.ArgumentParser(description='Export bert vits models to streaming encoder/decoder ONNX')
241 | parser.add_argument('--config', type=str, required=True)
242 | parser.add_argument('--model', type=str, required=True)
243 | args = parser.parse_args()
244 | config_file = args.config
245 | checkpoint = args.model
246 |
247 | hps = utils.get_hparams_from_file(config_file)
248 | print(hps)
249 |
250 | opset_version = 13
251 |
252 | # Encoder
253 | #########################################################################
254 | net_g = VITS_Encoder(
255 | len(symbols),
256 | hps.data.filter_length // 2 + 1,
257 | hps.train.segment_size // hps.data.hop_length,
258 | n_speakers=hps.data.n_speakers,
259 | **hps.model,
260 | )
261 |
262 | _ = net_g.eval()
263 | _ = utils.load_model(checkpoint, net_g)
264 |
265 | x = torch.randint(low=0, high=100, size=(50,), dtype=torch.int64)
266 | x = x.unsqueeze(0)
267 |
268 | x_length = torch.tensor([x.shape[1]], dtype=torch.int64)
269 | noise_scale = torch.tensor([1], dtype=torch.float32)
270 | length_scale = torch.tensor([1], dtype=torch.float32)
271 |
272 | encoder = OnnxModel_Encoder(net_g)
273 |
274 | filename = "vits-chinese-encoder.onnx"
275 |
276 | torch.onnx.export(
277 | encoder,
278 | (x, x_length, noise_scale, length_scale),
279 | filename,
280 | opset_version=opset_version,
281 | input_names=[
282 | "x",
283 | "x_length",
284 | "noise_scale",
285 | "length_scale",
286 | ],
287 | output_names=["z_p", "y_mask"],
288 | dynamic_axes={
289 | "x": {1: "L"},
290 | "z_p": {2: "L"},
291 | "y_mask": {2: "L"},
292 | },
293 | )
294 | meta_data = {
295 |         "model_type": "vits-encoder",
296 | "comment": "onnx@csukuangfj",
297 | "language": "Chinese",
298 | "add_blank": int(hps.data.add_blank),
299 | "n_speakers": int(hps.data.n_speakers),
300 | "sample_rate": hps.data.sampling_rate,
301 | "punctuation": "",
302 | }
303 | print("meta_data", meta_data)
304 | add_meta_data(filename=filename, meta_data=meta_data)
305 |
306 | # Decoder
307 | #########################################################################
308 | net_g = VITS_Decoder(
309 | len(symbols),
310 | hps.data.filter_length // 2 + 1,
311 | hps.train.segment_size // hps.data.hop_length,
312 | n_speakers=hps.data.n_speakers,
313 | **hps.model,
314 | )
315 |
316 | _ = net_g.eval()
317 | _ = utils.load_model(checkpoint, net_g)
318 | net_g.remove_weight_norm()
319 |
320 | z_p = torch.rand(size=(1, hps.model.inter_channels, 200), dtype=torch.float32)
321 |     y_mask = torch.ones(size=(1, 1, 200), dtype=torch.float32)  # all-ones dummy mask for tracing
322 |
323 | decoder = OnnxModel_Decoder(net_g)
324 |
325 | filename = "vits-chinese-decoder.onnx"
326 |
327 | torch.onnx.export(
328 | decoder,
329 | (z_p, y_mask),
330 | filename,
331 | opset_version=opset_version,
332 | input_names=[
333 | "z_p",
334 | "y_mask",
335 | ],
336 | output_names=["y"],
337 | dynamic_axes={
338 | "y": {0: "L"},
339 | "z_p": {2: "L"},
340 | "y_mask": {2: "L"},
341 | },
342 | )
343 | meta_data = {
344 | "model_type": "vits-decoder",
345 | "comment": "onnx@csukuangfj",
346 | "language": "Chinese",
347 | "inter_channels": hps.model.inter_channels,
348 | "hop_length": hps.data.hop_length,
349 | }
350 | print("meta_data", meta_data)
351 | add_meta_data(filename=filename, meta_data=meta_data)
352 |
353 |
354 | if __name__ == "__main__":
355 | main()
356 |
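# Illustrative usage sketch (added for clarity, not part of the original file):
# chaining the exported encoder and decoder with onnxruntime. Assumes numpy and
# onnxruntime are installed; input/output names match the two export calls above.
# For streaming, the decoder can be run on successive chunks of z_p / y_mask
# along the last (time) axis, since that axis is exported as dynamic.
def run_exported_stream_example():
    import numpy as np
    import onnxruntime

    enc = onnxruntime.InferenceSession(
        "vits-chinese-encoder.onnx", providers=["CPUExecutionProvider"]
    )
    dec = onnxruntime.InferenceSession(
        "vits-chinese-decoder.onnx", providers=["CPUExecutionProvider"]
    )

    # placeholder token ids of shape [1, L]; real ids come from the text front end
    x = np.array([[10, 20, 30, 40]], dtype=np.int64)
    z_p, y_mask = enc.run(
        ["z_p", "y_mask"],
        {
            "x": x,
            "x_length": np.array([x.shape[1]], dtype=np.int64),
            "noise_scale": np.array([0.5], dtype=np.float32),
            "length_scale": np.array([1.0], dtype=np.float32),
        },
    )
    audio = dec.run(["y"], {"z_p": z_p, "y_mask": y_mask})[0]
    return audio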
--------------------------------------------------------------------------------
/modules.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import math
3 | import numpy as np
4 | import scipy
5 | import torch
6 | from torch import nn
7 | from torch.nn import functional as F
8 |
9 | from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
10 | from torch.nn.utils import weight_norm, remove_weight_norm
11 |
12 | import commons
13 | from commons import init_weights, get_padding
14 | from transforms import piecewise_rational_quadratic_transform
15 |
16 |
17 | LRELU_SLOPE = 0.1
18 |
19 |
20 | class LayerNorm(nn.Module):
21 | def __init__(self, channels, eps=1e-5):
22 | super().__init__()
23 | self.channels = channels
24 | self.eps = eps
25 |
26 | self.gamma = nn.Parameter(torch.ones(channels))
27 | self.beta = nn.Parameter(torch.zeros(channels))
28 |
29 | def forward(self, x):
30 | x = x.transpose(1, -1)
31 | x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
32 | return x.transpose(1, -1)
33 |
34 |
35 | class ConvReluNorm(nn.Module):
36 | def __init__(
37 | self,
38 | in_channels,
39 | hidden_channels,
40 | out_channels,
41 | kernel_size,
42 | n_layers,
43 | p_dropout,
44 | ):
45 | super().__init__()
46 | self.in_channels = in_channels
47 | self.hidden_channels = hidden_channels
48 | self.out_channels = out_channels
49 | self.kernel_size = kernel_size
50 | self.n_layers = n_layers
51 | self.p_dropout = p_dropout
52 |         assert n_layers > 1, "Number of layers should be larger than 1."
53 |
54 | self.conv_layers = nn.ModuleList()
55 | self.norm_layers = nn.ModuleList()
56 | self.conv_layers.append(
57 | nn.Conv1d(
58 | in_channels, hidden_channels, kernel_size, padding=kernel_size // 2
59 | )
60 | )
61 | self.norm_layers.append(LayerNorm(hidden_channels))
62 | self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout))
63 | for _ in range(n_layers - 1):
64 | self.conv_layers.append(
65 | nn.Conv1d(
66 | hidden_channels,
67 | hidden_channels,
68 | kernel_size,
69 | padding=kernel_size // 2,
70 | )
71 | )
72 | self.norm_layers.append(LayerNorm(hidden_channels))
73 | self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
74 | self.proj.weight.data.zero_()
75 | self.proj.bias.data.zero_()
76 |
77 | def forward(self, x, x_mask):
78 | x_org = x
79 | for i in range(self.n_layers):
80 | x = self.conv_layers[i](x * x_mask)
81 | x = self.norm_layers[i](x)
82 | x = self.relu_drop(x)
83 | x = x_org + self.proj(x)
84 | return x * x_mask
85 |
86 |
87 | class DDSConv(nn.Module):
88 | """
89 |     Dilated and Depth-Separable Convolution
90 | """
91 |
92 | def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0):
93 | super().__init__()
94 | self.channels = channels
95 | self.kernel_size = kernel_size
96 | self.n_layers = n_layers
97 | self.p_dropout = p_dropout
98 |
99 | self.drop = nn.Dropout(p_dropout)
100 | self.convs_sep = nn.ModuleList()
101 | self.convs_1x1 = nn.ModuleList()
102 | self.norms_1 = nn.ModuleList()
103 | self.norms_2 = nn.ModuleList()
104 | for i in range(n_layers):
105 | dilation = kernel_size**i
106 | padding = (kernel_size * dilation - dilation) // 2
107 | self.convs_sep.append(
108 | nn.Conv1d(
109 | channels,
110 | channels,
111 | kernel_size,
112 | groups=channels,
113 | dilation=dilation,
114 | padding=padding,
115 | )
116 | )
117 | self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
118 | self.norms_1.append(LayerNorm(channels))
119 | self.norms_2.append(LayerNorm(channels))
120 |
121 | def forward(self, x, x_mask, g=None):
122 | if g is not None:
123 | x = x + g
124 | for i in range(self.n_layers):
125 | y = self.convs_sep[i](x * x_mask)
126 | y = self.norms_1[i](y)
127 | y = F.gelu(y)
128 | y = self.convs_1x1[i](y)
129 | y = self.norms_2[i](y)
130 | y = F.gelu(y)
131 | y = self.drop(y)
132 | x = x + y
133 | return x * x_mask
134 |
135 |
136 | class WN(torch.nn.Module):
137 | def __init__(
138 | self,
139 | hidden_channels,
140 | kernel_size,
141 | dilation_rate,
142 | n_layers,
143 | gin_channels=0,
144 | p_dropout=0,
145 | ):
146 | super(WN, self).__init__()
147 | assert kernel_size % 2 == 1
148 | self.hidden_channels = hidden_channels
149 |         self.kernel_size = kernel_size
150 | self.dilation_rate = dilation_rate
151 | self.n_layers = n_layers
152 | self.gin_channels = gin_channels
153 | self.p_dropout = p_dropout
154 |
155 | self.in_layers = torch.nn.ModuleList()
156 | self.res_skip_layers = torch.nn.ModuleList()
157 | self.drop = nn.Dropout(p_dropout)
158 |
159 | if gin_channels != 0:
160 | cond_layer = torch.nn.Conv1d(
161 | gin_channels, 2 * hidden_channels * n_layers, 1
162 | )
163 | self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight")
164 |
165 | for i in range(n_layers):
166 | dilation = dilation_rate**i
167 | padding = int((kernel_size * dilation - dilation) / 2)
168 | in_layer = torch.nn.Conv1d(
169 | hidden_channels,
170 | 2 * hidden_channels,
171 | kernel_size,
172 | dilation=dilation,
173 | padding=padding,
174 | )
175 | in_layer = torch.nn.utils.weight_norm(in_layer, name="weight")
176 | self.in_layers.append(in_layer)
177 |
178 |             # the last layer needs no residual output, only the skip part
179 | if i < n_layers - 1:
180 | res_skip_channels = 2 * hidden_channels
181 | else:
182 | res_skip_channels = hidden_channels
183 |
184 | res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
185 | res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight")
186 | self.res_skip_layers.append(res_skip_layer)
187 |
188 | def forward(self, x, x_mask, g=None, **kwargs):
189 | output = torch.zeros_like(x)
190 | n_channels_tensor = torch.IntTensor([self.hidden_channels])
191 |
192 | if g is not None:
193 | g = self.cond_layer(g)
194 |
195 | for i in range(self.n_layers):
196 | x_in = self.in_layers[i](x)
197 | if g is not None:
198 | cond_offset = i * 2 * self.hidden_channels
199 | g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
200 | else:
201 | g_l = torch.zeros_like(x_in)
202 |
203 | acts = commons.fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
204 | acts = self.drop(acts)
205 |
206 | res_skip_acts = self.res_skip_layers[i](acts)
207 | if i < self.n_layers - 1:
208 | res_acts = res_skip_acts[:, : self.hidden_channels, :]
209 | x = (x + res_acts) * x_mask
210 | output = output + res_skip_acts[:, self.hidden_channels :, :]
211 | else:
212 | output = output + res_skip_acts
213 | return output * x_mask
214 |
215 | def remove_weight_norm(self):
216 | if self.gin_channels != 0:
217 | torch.nn.utils.remove_weight_norm(self.cond_layer)
218 | for l in self.in_layers:
219 | torch.nn.utils.remove_weight_norm(l)
220 | for l in self.res_skip_layers:
221 | torch.nn.utils.remove_weight_norm(l)
222 |
223 |
224 | class ResBlock1(torch.nn.Module):
225 | def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
226 | super(ResBlock1, self).__init__()
227 | self.convs1 = nn.ModuleList(
228 | [
229 | weight_norm(
230 | Conv1d(
231 | channels,
232 | channels,
233 | kernel_size,
234 | 1,
235 | dilation=dilation[0],
236 | padding=get_padding(kernel_size, dilation[0]),
237 | )
238 | ),
239 | weight_norm(
240 | Conv1d(
241 | channels,
242 | channels,
243 | kernel_size,
244 | 1,
245 | dilation=dilation[1],
246 | padding=get_padding(kernel_size, dilation[1]),
247 | )
248 | ),
249 | weight_norm(
250 | Conv1d(
251 | channels,
252 | channels,
253 | kernel_size,
254 | 1,
255 | dilation=dilation[2],
256 | padding=get_padding(kernel_size, dilation[2]),
257 | )
258 | ),
259 | ]
260 | )
261 | self.convs1.apply(init_weights)
262 |
263 | self.convs2 = nn.ModuleList(
264 | [
265 | weight_norm(
266 | Conv1d(
267 | channels,
268 | channels,
269 | kernel_size,
270 | 1,
271 | dilation=1,
272 | padding=get_padding(kernel_size, 1),
273 | )
274 | ),
275 | weight_norm(
276 | Conv1d(
277 | channels,
278 | channels,
279 | kernel_size,
280 | 1,
281 | dilation=1,
282 | padding=get_padding(kernel_size, 1),
283 | )
284 | ),
285 | weight_norm(
286 | Conv1d(
287 | channels,
288 | channels,
289 | kernel_size,
290 | 1,
291 | dilation=1,
292 | padding=get_padding(kernel_size, 1),
293 | )
294 | ),
295 | ]
296 | )
297 | self.convs2.apply(init_weights)
298 |
299 | def forward(self, x, x_mask=None):
300 | for c1, c2 in zip(self.convs1, self.convs2):
301 | xt = F.leaky_relu(x, LRELU_SLOPE)
302 | if x_mask is not None:
303 | xt = xt * x_mask
304 | xt = c1(xt)
305 | xt = F.leaky_relu(xt, LRELU_SLOPE)
306 | if x_mask is not None:
307 | xt = xt * x_mask
308 | xt = c2(xt)
309 | x = xt + x
310 | if x_mask is not None:
311 | x = x * x_mask
312 | return x
313 |
314 | def remove_weight_norm(self):
315 | for l in self.convs1:
316 | remove_weight_norm(l)
317 | for l in self.convs2:
318 | remove_weight_norm(l)
319 |
320 |
321 | class ResBlock2(torch.nn.Module):
322 | def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
323 | super(ResBlock2, self).__init__()
324 | self.convs = nn.ModuleList(
325 | [
326 | weight_norm(
327 | Conv1d(
328 | channels,
329 | channels,
330 | kernel_size,
331 | 1,
332 | dilation=dilation[0],
333 | padding=get_padding(kernel_size, dilation[0]),
334 | )
335 | ),
336 | weight_norm(
337 | Conv1d(
338 | channels,
339 | channels,
340 | kernel_size,
341 | 1,
342 | dilation=dilation[1],
343 | padding=get_padding(kernel_size, dilation[1]),
344 | )
345 | ),
346 | ]
347 | )
348 | self.convs.apply(init_weights)
349 |
350 | def forward(self, x, x_mask=None):
351 | for c in self.convs:
352 | xt = F.leaky_relu(x, LRELU_SLOPE)
353 | if x_mask is not None:
354 | xt = xt * x_mask
355 | xt = c(xt)
356 | x = xt + x
357 | if x_mask is not None:
358 | x = x * x_mask
359 | return x
360 |
361 | def remove_weight_norm(self):
362 | for l in self.convs:
363 | remove_weight_norm(l)
364 |
365 |
366 | class Log(nn.Module):
367 | def forward(self, x, x_mask, reverse=False, **kwargs):
368 | if not reverse:
369 | y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask
370 | logdet = torch.sum(-y, [1, 2])
371 | return y, logdet
372 | else:
373 | x = torch.exp(x) * x_mask
374 | return x
375 |
376 |
377 | class Flip(nn.Module):
378 | def forward(self, x, *args, reverse=False, **kwargs):
379 | x = torch.flip(x, [1])
380 | if not reverse:
381 | logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
382 | return x, logdet
383 | else:
384 | return x
385 |
386 |
387 | class ElementwiseAffine(nn.Module):
388 | def __init__(self, channels):
389 | super().__init__()
390 | self.channels = channels
391 | self.m = nn.Parameter(torch.zeros(channels, 1))
392 | self.logs = nn.Parameter(torch.zeros(channels, 1))
393 |
394 | def forward(self, x, x_mask, reverse=False, **kwargs):
395 | if not reverse:
396 | y = self.m + torch.exp(self.logs) * x
397 | y = y * x_mask
398 | logdet = torch.sum(self.logs * x_mask, [1, 2])
399 | return y, logdet
400 | else:
401 | x = (x - self.m) * torch.exp(-self.logs) * x_mask
402 | return x
403 |
404 |
405 | class ResidualCouplingLayer(nn.Module):
406 | def __init__(
407 | self,
408 | channels,
409 | hidden_channels,
410 | kernel_size,
411 | dilation_rate,
412 | n_layers,
413 | p_dropout=0,
414 | gin_channels=0,
415 | mean_only=False,
416 | ):
417 | assert channels % 2 == 0, "channels should be divisible by 2"
418 | super().__init__()
419 | self.channels = channels
420 | self.hidden_channels = hidden_channels
421 | self.kernel_size = kernel_size
422 | self.dilation_rate = dilation_rate
423 | self.n_layers = n_layers
424 | self.half_channels = channels // 2
425 | self.mean_only = mean_only
426 |
427 | self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
428 | self.enc = WN(
429 | hidden_channels,
430 | kernel_size,
431 | dilation_rate,
432 | n_layers,
433 | p_dropout=p_dropout,
434 | gin_channels=gin_channels,
435 | )
436 | self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
437 | self.post.weight.data.zero_()
438 | self.post.bias.data.zero_()
439 |
440 | def forward(self, x, x_mask, g=None, reverse=False):
441 | x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
442 | h = self.pre(x0) * x_mask
443 | h = self.enc(h, x_mask, g=g)
444 | stats = self.post(h) * x_mask
445 | if not self.mean_only:
446 | m, logs = torch.split(stats, [self.half_channels] * 2, 1)
447 | else:
448 | m = stats
449 | logs = torch.zeros_like(m)
450 |
451 | if not reverse:
452 | x1 = m + x1 * torch.exp(logs) * x_mask
453 | x = torch.cat([x0, x1], 1)
454 | logdet = torch.sum(logs, [1, 2])
455 | return x, logdet
456 | else:
457 | x1 = (x1 - m) * torch.exp(-logs) * x_mask
458 | x = torch.cat([x0, x1], 1)
459 | return x
460 |
461 | def remove_weight_norm(self):
462 | self.enc.remove_weight_norm()
463 |
464 |
465 | class ConvFlow(nn.Module):
466 | def __init__(
467 | self,
468 | in_channels,
469 | filter_channels,
470 | kernel_size,
471 | n_layers,
472 | num_bins=10,
473 | tail_bound=5.0,
474 | ):
475 | super().__init__()
476 | self.in_channels = in_channels
477 | self.filter_channels = filter_channels
478 | self.kernel_size = kernel_size
479 | self.n_layers = n_layers
480 | self.num_bins = num_bins
481 | self.tail_bound = tail_bound
482 | self.half_channels = in_channels // 2
483 |
484 | self.pre = nn.Conv1d(self.half_channels, filter_channels, 1)
485 | self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.0)
486 | self.proj = nn.Conv1d(
487 | filter_channels, self.half_channels * (num_bins * 3 - 1), 1
488 | )
489 | self.proj.weight.data.zero_()
490 | self.proj.bias.data.zero_()
491 |
492 | def forward(self, x, x_mask, g=None, reverse=False):
493 | x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
494 | h = self.pre(x0)
495 | h = self.convs(h, x_mask, g=g)
496 | h = self.proj(h) * x_mask
497 |
498 | b, c, t = x0.shape
499 |         h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2)  # [b, c*(3*num_bins-1), t] -> [b, c, t, 3*num_bins-1]
500 |
501 | unnormalized_widths = h[..., : self.num_bins] / math.sqrt(self.filter_channels)
502 | unnormalized_heights = h[..., self.num_bins : 2 * self.num_bins] / math.sqrt(
503 | self.filter_channels
504 | )
505 | unnormalized_derivatives = h[..., 2 * self.num_bins :]
506 |
507 | x1, logabsdet = piecewise_rational_quadratic_transform(
508 | x1,
509 | unnormalized_widths,
510 | unnormalized_heights,
511 | unnormalized_derivatives,
512 | inverse=reverse,
513 | tails="linear",
514 | tail_bound=self.tail_bound,
515 | )
516 |
517 | x = torch.cat([x0, x1], 1) * x_mask
518 | logdet = torch.sum(logabsdet * x_mask, [1, 2])
519 | if not reverse:
520 | return x, logdet
521 | else:
522 | return x
523 |
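524 | 
525 | 
526 | if __name__ == "__main__":
527 |     # Minimal usage sketch (illustrative shapes only, not part of the model
528 |     # code above): a ResidualCouplingLayer maps x -> y with a log-determinant
529 |     # in the forward direction and inverts the mapping when reverse=True.
530 |     layer = ResidualCouplingLayer(
531 |         channels=4, hidden_channels=8, kernel_size=5, dilation_rate=1, n_layers=2
532 |     )
533 |     x = torch.randn(1, 4, 16)      # [batch, channels, time]
534 |     x_mask = torch.ones(1, 1, 16)  # [batch, 1, time]
535 |     y, logdet = layer(x, x_mask)                # forward pass of the flow
536 |     x_rec = layer(y, x_mask, reverse=True)      # inverse pass recovers x
537 |     print(y.shape, logdet.shape, torch.allclose(x, x_rec, atol=1e-5))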
--------------------------------------------------------------------------------
/monotonic_align/__init__.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from .monotonic_align.core import maximum_path_c
4 |
5 |
6 | def maximum_path(neg_cent, mask):
7 | """Cython optimized version.
8 | neg_cent: [b, t_t, t_s]
9 | mask: [b, t_t, t_s]
10 | """
11 | device = neg_cent.device
12 | dtype = neg_cent.dtype
13 | neg_cent = neg_cent.data.cpu().numpy().astype(np.float32)
14 | path = np.zeros(neg_cent.shape, dtype=np.int32)
15 |
16 | t_t_max = mask.sum(1)[:, 0].data.cpu().numpy().astype(np.int32)
17 | t_s_max = mask.sum(2)[:, 0].data.cpu().numpy().astype(np.int32)
18 | maximum_path_c(path, neg_cent, t_t_max, t_s_max)
19 | return torch.from_numpy(path).to(device=device, dtype=dtype)
20 |
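21 | 
22 | # Usage sketch (illustrative only; requires the compiled Cython extension, see
23 | # setup.py). Given per-pair alignment scores and a validity mask, maximum_path
24 | # returns a hard 0/1 monotonic alignment of the same shape:
25 | #
26 | #     neg_cent = torch.randn(2, 8, 5)       # [b, t_t, t_s], with t_t >= t_s
27 | #     mask = torch.ones(2, 8, 5)
28 | #     path = maximum_path(neg_cent, mask)   # -> [2, 8, 5] tensor of 0s and 1s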
--------------------------------------------------------------------------------
/monotonic_align/core.pyx:
--------------------------------------------------------------------------------
1 | cimport cython
2 | from cython.parallel import prange
3 |
4 |
5 | @cython.boundscheck(False)
6 | @cython.wraparound(False)
7 | cdef void maximum_path_each(int[:,::1] path, float[:,::1] value, int t_y, int t_x, float max_neg_val=-1e9) nogil:
8 | cdef int x
9 | cdef int y
10 | cdef float v_prev
11 | cdef float v_cur
12 | cdef float tmp
13 | cdef int index = t_x - 1
14 |
15 | for y in range(t_y):
16 | for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)):
17 | if x == y:
18 | v_cur = max_neg_val
19 | else:
20 | v_cur = value[y-1, x]
21 | if x == 0:
22 | if y == 0:
23 | v_prev = 0.
24 | else:
25 | v_prev = max_neg_val
26 | else:
27 | v_prev = value[y-1, x-1]
28 | value[y, x] += max(v_prev, v_cur)
29 |
30 | for y in range(t_y - 1, -1, -1):
31 | path[y, index] = 1
32 | if index != 0 and (index == y or value[y-1, index] < value[y-1, index-1]):
33 | index = index - 1
34 |
35 |
36 | @cython.boundscheck(False)
37 | @cython.wraparound(False)
38 | cpdef void maximum_path_c(int[:,:,::1] paths, float[:,:,::1] values, int[::1] t_ys, int[::1] t_xs) nogil:
39 | cdef int b = paths.shape[0]
40 | cdef int i
41 | for i in prange(b, nogil=True):
42 | maximum_path_each(paths[i], values[i], t_ys[i], t_xs[i])
43 |
44 |
45 |
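46 | # Note: maximum_path_each fills value[y, x] with the best cumulative score via
47 | # value[y, x] += max(value[y-1, x-1], value[y-1, x]) (scores outside the valid
48 | # band are treated as -1e9), then backtracks from the last column, marking the
49 | # chosen monotonic path with 1s in `path`.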
--------------------------------------------------------------------------------
/monotonic_align/setup.py:
--------------------------------------------------------------------------------
1 | from distutils.core import setup
2 | from Cython.Build import cythonize
3 | import numpy
4 |
5 | setup(
6 | name="monotonic_align",
7 | ext_modules=cythonize("core.pyx"),
8 | include_dirs=[numpy.get_include()],
9 | )
10 |
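11 | # Typical in-place build, run from inside this directory; note that
12 | # monotonic_align/__init__.py imports the compiled extension as
13 | # .monotonic_align.core, so a nested monotonic_align/ package directory may be
14 | # needed before building, depending on your environment:
15 | #
16 | #     python setup.py build_ext --inplace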
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.9.0
2 | Cython
3 | transformers
4 | tensorboard
5 | WeTextProcessing
6 |
--------------------------------------------------------------------------------
/text/__init__.py:
--------------------------------------------------------------------------------
1 | from text.symbols import symbols
2 |
3 |
4 | # Mappings from symbol to numeric ID and vice versa:
5 | _symbol_to_id = {s: i for i, s in enumerate(symbols)}
6 | _id_to_symbol = {i: s for i, s in enumerate(symbols)}
7 |
8 |
9 | def cleaned_text_to_sequence(cleaned_text):
10 |     """Converts a string of cleaned text to a sequence of IDs corresponding to the symbols in the text.
11 |     Args:
12 |       cleaned_text: space-separated symbol string to convert to a sequence
13 |     Returns:
14 |       List of integers corresponding to the symbols in the text
15 |     """
16 | sequence = [_symbol_to_id[symbol] for symbol in cleaned_text.split()]
17 | return sequence
18 |
19 |
20 | def sequence_to_text(sequence):
21 | """Converts a sequence of IDs back to a string"""
22 | result = ""
23 | for symbol_id in sequence:
24 | s = _id_to_symbol[symbol_id]
25 | result += s
26 | return result
27 |
28 |
29 | pinyin_dict = {
30 | "a": ("^", "a"),
31 | "ai": ("^", "ai"),
32 | "an": ("^", "an"),
33 | "ang": ("^", "ang"),
34 | "ao": ("^", "ao"),
35 | "ba": ("b", "a"),
36 | "bai": ("b", "ai"),
37 | "ban": ("b", "an"),
38 | "bang": ("b", "ang"),
39 | "bao": ("b", "ao"),
40 | "be": ("b", "e"),
41 | "bei": ("b", "ei"),
42 | "ben": ("b", "en"),
43 | "beng": ("b", "eng"),
44 | "bi": ("b", "i"),
45 | "bian": ("b", "ian"),
46 | "biao": ("b", "iao"),
47 | "bie": ("b", "ie"),
48 | "bin": ("b", "in"),
49 | "bing": ("b", "ing"),
50 | "bo": ("b", "o"),
51 | "bu": ("b", "u"),
52 | "ca": ("c", "a"),
53 | "cai": ("c", "ai"),
54 | "can": ("c", "an"),
55 | "cang": ("c", "ang"),
56 | "cao": ("c", "ao"),
57 | "ce": ("c", "e"),
58 | "cen": ("c", "en"),
59 | "ceng": ("c", "eng"),
60 | "cha": ("ch", "a"),
61 | "chai": ("ch", "ai"),
62 | "chan": ("ch", "an"),
63 | "chang": ("ch", "ang"),
64 | "chao": ("ch", "ao"),
65 | "che": ("ch", "e"),
66 | "chen": ("ch", "en"),
67 | "cheng": ("ch", "eng"),
68 | "chi": ("ch", "iii"),
69 | "chong": ("ch", "ong"),
70 | "chou": ("ch", "ou"),
71 | "chu": ("ch", "u"),
72 | "chua": ("ch", "ua"),
73 | "chuai": ("ch", "uai"),
74 | "chuan": ("ch", "uan"),
75 | "chuang": ("ch", "uang"),
76 | "chui": ("ch", "uei"),
77 | "chun": ("ch", "uen"),
78 | "chuo": ("ch", "uo"),
79 | "ci": ("c", "ii"),
80 | "cong": ("c", "ong"),
81 | "cou": ("c", "ou"),
82 | "cu": ("c", "u"),
83 | "cuan": ("c", "uan"),
84 | "cui": ("c", "uei"),
85 | "cun": ("c", "uen"),
86 | "cuo": ("c", "uo"),
87 | "da": ("d", "a"),
88 | "dai": ("d", "ai"),
89 | "dan": ("d", "an"),
90 | "dang": ("d", "ang"),
91 | "dao": ("d", "ao"),
92 | "de": ("d", "e"),
93 | "dei": ("d", "ei"),
94 | "den": ("d", "en"),
95 | "deng": ("d", "eng"),
96 | "di": ("d", "i"),
97 | "dia": ("d", "ia"),
98 | "dian": ("d", "ian"),
99 | "diao": ("d", "iao"),
100 | "die": ("d", "ie"),
101 | "ding": ("d", "ing"),
102 | "diu": ("d", "iou"),
103 | "dong": ("d", "ong"),
104 | "dou": ("d", "ou"),
105 | "du": ("d", "u"),
106 | "duan": ("d", "uan"),
107 | "dui": ("d", "uei"),
108 | "dun": ("d", "uen"),
109 | "duo": ("d", "uo"),
110 | "e": ("^", "e"),
111 | "ei": ("^", "ei"),
112 | "en": ("^", "en"),
113 | "ng": ("^", "en"),
114 | "eng": ("^", "eng"),
115 | "er": ("^", "er"),
116 | "fa": ("f", "a"),
117 | "fan": ("f", "an"),
118 | "fang": ("f", "ang"),
119 | "fei": ("f", "ei"),
120 | "fen": ("f", "en"),
121 | "feng": ("f", "eng"),
122 | "fo": ("f", "o"),
123 | "fou": ("f", "ou"),
124 | "fu": ("f", "u"),
125 | "ga": ("g", "a"),
126 | "gai": ("g", "ai"),
127 | "gan": ("g", "an"),
128 | "gang": ("g", "ang"),
129 | "gao": ("g", "ao"),
130 | "ge": ("g", "e"),
131 | "gei": ("g", "ei"),
132 | "gen": ("g", "en"),
133 | "geng": ("g", "eng"),
134 | "gong": ("g", "ong"),
135 | "gou": ("g", "ou"),
136 | "gu": ("g", "u"),
137 | "gua": ("g", "ua"),
138 | "guai": ("g", "uai"),
139 | "guan": ("g", "uan"),
140 | "guang": ("g", "uang"),
141 | "gui": ("g", "uei"),
142 | "gun": ("g", "uen"),
143 | "guo": ("g", "uo"),
144 | "ha": ("h", "a"),
145 | "hai": ("h", "ai"),
146 | "han": ("h", "an"),
147 | "hang": ("h", "ang"),
148 | "hao": ("h", "ao"),
149 | "he": ("h", "e"),
150 | "hei": ("h", "ei"),
151 | "hen": ("h", "en"),
152 | "heng": ("h", "eng"),
153 | "hong": ("h", "ong"),
154 | "hou": ("h", "ou"),
155 | "hu": ("h", "u"),
156 | "hua": ("h", "ua"),
157 | "huai": ("h", "uai"),
158 | "huan": ("h", "uan"),
159 | "huang": ("h", "uang"),
160 | "hui": ("h", "uei"),
161 | "hun": ("h", "uen"),
162 | "huo": ("h", "uo"),
163 | "ji": ("j", "i"),
164 | "jia": ("j", "ia"),
165 | "jian": ("j", "ian"),
166 | "jiang": ("j", "iang"),
167 | "jiao": ("j", "iao"),
168 | "jie": ("j", "ie"),
169 | "jin": ("j", "in"),
170 | "jing": ("j", "ing"),
171 | "jiong": ("j", "iong"),
172 | "jiu": ("j", "iou"),
173 | "ju": ("j", "v"),
174 | "juan": ("j", "van"),
175 | "jue": ("j", "ve"),
176 | "jun": ("j", "vn"),
177 | "ka": ("k", "a"),
178 | "kai": ("k", "ai"),
179 | "kan": ("k", "an"),
180 | "kang": ("k", "ang"),
181 | "kao": ("k", "ao"),
182 | "ke": ("k", "e"),
183 | "kei": ("k", "ei"),
184 | "ken": ("k", "en"),
185 | "keng": ("k", "eng"),
186 | "kong": ("k", "ong"),
187 | "kou": ("k", "ou"),
188 | "ku": ("k", "u"),
189 | "kua": ("k", "ua"),
190 | "kuai": ("k", "uai"),
191 | "kuan": ("k", "uan"),
192 | "kuang": ("k", "uang"),
193 | "kui": ("k", "uei"),
194 | "kun": ("k", "uen"),
195 | "kuo": ("k", "uo"),
196 | "la": ("l", "a"),
197 | "lai": ("l", "ai"),
198 | "lan": ("l", "an"),
199 | "lang": ("l", "ang"),
200 | "lao": ("l", "ao"),
201 | "le": ("l", "e"),
202 | "lei": ("l", "ei"),
203 | "leng": ("l", "eng"),
204 | "li": ("l", "i"),
205 | "lia": ("l", "ia"),
206 | "lian": ("l", "ian"),
207 | "liang": ("l", "iang"),
208 | "liao": ("l", "iao"),
209 | "lie": ("l", "ie"),
210 | "lin": ("l", "in"),
211 | "ling": ("l", "ing"),
212 | "liu": ("l", "iou"),
213 | "lo": ("l", "o"),
214 | "long": ("l", "ong"),
215 | "lou": ("l", "ou"),
216 | "lu": ("l", "u"),
217 | "lv": ("l", "v"),
218 | "luan": ("l", "uan"),
219 | "lve": ("l", "ve"),
220 | "lue": ("l", "ve"),
221 | "lun": ("l", "uen"),
222 | "luo": ("l", "uo"),
223 | "ma": ("m", "a"),
224 | "mai": ("m", "ai"),
225 | "man": ("m", "an"),
226 | "mang": ("m", "ang"),
227 | "mao": ("m", "ao"),
228 | "me": ("m", "e"),
229 | "mei": ("m", "ei"),
230 | "men": ("m", "en"),
231 | "meng": ("m", "eng"),
232 | "mi": ("m", "i"),
233 | "mian": ("m", "ian"),
234 | "miao": ("m", "iao"),
235 | "mie": ("m", "ie"),
236 | "min": ("m", "in"),
237 | "ming": ("m", "ing"),
238 | "miu": ("m", "iou"),
239 | "mo": ("m", "o"),
240 | "mou": ("m", "ou"),
241 | "mu": ("m", "u"),
242 | "n": ("^", "en"),
243 | "na": ("n", "a"),
244 | "nai": ("n", "ai"),
245 | "nan": ("n", "an"),
246 | "nang": ("n", "ang"),
247 | "nao": ("n", "ao"),
248 | "ne": ("n", "e"),
249 | "nei": ("n", "ei"),
250 | "nen": ("n", "en"),
251 | "neng": ("n", "eng"),
252 | "ni": ("n", "i"),
253 | "nia": ("n", "ia"),
254 | "nian": ("n", "ian"),
255 | "niang": ("n", "iang"),
256 | "niao": ("n", "iao"),
257 | "nie": ("n", "ie"),
258 | "nin": ("n", "in"),
259 | "ning": ("n", "ing"),
260 | "niu": ("n", "iou"),
261 | "nong": ("n", "ong"),
262 | "nou": ("n", "ou"),
263 | "nu": ("n", "u"),
264 | "nv": ("n", "v"),
265 | "nuan": ("n", "uan"),
266 | "nve": ("n", "ve"),
267 | "nue": ("n", "ve"),
268 | "nuo": ("n", "uo"),
269 | "o": ("^", "o"),
270 | "ou": ("^", "ou"),
271 | "pa": ("p", "a"),
272 | "pai": ("p", "ai"),
273 | "pan": ("p", "an"),
274 | "pang": ("p", "ang"),
275 | "pao": ("p", "ao"),
276 | "pe": ("p", "e"),
277 | "pei": ("p", "ei"),
278 | "pen": ("p", "en"),
279 | "peng": ("p", "eng"),
280 | "pi": ("p", "i"),
281 | "pian": ("p", "ian"),
282 | "piao": ("p", "iao"),
283 | "pie": ("p", "ie"),
284 | "pin": ("p", "in"),
285 | "ping": ("p", "ing"),
286 | "po": ("p", "o"),
287 | "pou": ("p", "ou"),
288 | "pu": ("p", "u"),
289 | "qi": ("q", "i"),
290 | "qia": ("q", "ia"),
291 | "qian": ("q", "ian"),
292 | "qiang": ("q", "iang"),
293 | "qiao": ("q", "iao"),
294 | "qie": ("q", "ie"),
295 | "qin": ("q", "in"),
296 | "qing": ("q", "ing"),
297 | "qiong": ("q", "iong"),
298 | "qiu": ("q", "iou"),
299 | "qu": ("q", "v"),
300 | "quan": ("q", "van"),
301 | "que": ("q", "ve"),
302 | "qun": ("q", "vn"),
303 | "ran": ("r", "an"),
304 | "rang": ("r", "ang"),
305 | "rao": ("r", "ao"),
306 | "re": ("r", "e"),
307 | "ren": ("r", "en"),
308 | "reng": ("r", "eng"),
309 | "ri": ("r", "iii"),
310 | "rong": ("r", "ong"),
311 | "rou": ("r", "ou"),
312 | "ru": ("r", "u"),
313 | "rua": ("r", "ua"),
314 | "ruan": ("r", "uan"),
315 | "rui": ("r", "uei"),
316 | "run": ("r", "uen"),
317 | "ruo": ("r", "uo"),
318 | "sa": ("s", "a"),
319 | "sai": ("s", "ai"),
320 | "san": ("s", "an"),
321 | "sang": ("s", "ang"),
322 | "sao": ("s", "ao"),
323 | "se": ("s", "e"),
324 | "sen": ("s", "en"),
325 | "seng": ("s", "eng"),
326 | "sha": ("sh", "a"),
327 | "shai": ("sh", "ai"),
328 | "shan": ("sh", "an"),
329 | "shang": ("sh", "ang"),
330 | "shao": ("sh", "ao"),
331 | "she": ("sh", "e"),
332 | "shei": ("sh", "ei"),
333 | "shen": ("sh", "en"),
334 | "sheng": ("sh", "eng"),
335 | "shi": ("sh", "iii"),
336 | "shou": ("sh", "ou"),
337 | "shu": ("sh", "u"),
338 | "shua": ("sh", "ua"),
339 | "shuai": ("sh", "uai"),
340 | "shuan": ("sh", "uan"),
341 | "shuang": ("sh", "uang"),
342 | "shui": ("sh", "uei"),
343 | "shun": ("sh", "uen"),
344 | "shuo": ("sh", "uo"),
345 | "si": ("s", "ii"),
346 | "song": ("s", "ong"),
347 | "sou": ("s", "ou"),
348 | "su": ("s", "u"),
349 | "suan": ("s", "uan"),
350 | "sui": ("s", "uei"),
351 | "sun": ("s", "uen"),
352 | "suo": ("s", "uo"),
353 | "ta": ("t", "a"),
354 | "tai": ("t", "ai"),
355 | "tan": ("t", "an"),
356 | "tang": ("t", "ang"),
357 | "tao": ("t", "ao"),
358 | "te": ("t", "e"),
359 | "tei": ("t", "ei"),
360 | "teng": ("t", "eng"),
361 | "ti": ("t", "i"),
362 | "tian": ("t", "ian"),
363 | "tiao": ("t", "iao"),
364 | "tie": ("t", "ie"),
365 | "ting": ("t", "ing"),
366 | "tong": ("t", "ong"),
367 | "tou": ("t", "ou"),
368 | "tu": ("t", "u"),
369 | "tuan": ("t", "uan"),
370 | "tui": ("t", "uei"),
371 | "tun": ("t", "uen"),
372 | "tuo": ("t", "uo"),
373 | "wa": ("^", "ua"),
374 | "wai": ("^", "uai"),
375 | "wan": ("^", "uan"),
376 | "wang": ("^", "uang"),
377 | "wei": ("^", "uei"),
378 | "wen": ("^", "uen"),
379 | "weng": ("^", "ueng"),
380 | "wo": ("^", "uo"),
381 | "wu": ("^", "u"),
382 | "xi": ("x", "i"),
383 | "xia": ("x", "ia"),
384 | "xian": ("x", "ian"),
385 | "xiang": ("x", "iang"),
386 | "xiao": ("x", "iao"),
387 | "xie": ("x", "ie"),
388 | "xin": ("x", "in"),
389 | "xing": ("x", "ing"),
390 | "xiong": ("x", "iong"),
391 | "xiu": ("x", "iou"),
392 | "xu": ("x", "v"),
393 | "xuan": ("x", "van"),
394 | "xue": ("x", "ve"),
395 | "xun": ("x", "vn"),
396 | "ya": ("^", "ia"),
397 | "yan": ("^", "ian"),
398 | "yang": ("^", "iang"),
399 | "yao": ("^", "iao"),
400 | "ye": ("^", "ie"),
401 | "yi": ("^", "i"),
402 | "yin": ("^", "in"),
403 | "ying": ("^", "ing"),
404 | "yo": ("^", "iou"),
405 | "yong": ("^", "iong"),
406 | "you": ("^", "iou"),
407 | "yu": ("^", "v"),
408 | "yuan": ("^", "van"),
409 | "yue": ("^", "ve"),
410 | "yun": ("^", "vn"),
411 | "za": ("z", "a"),
412 | "zai": ("z", "ai"),
413 | "zan": ("z", "an"),
414 | "zang": ("z", "ang"),
415 | "zao": ("z", "ao"),
416 | "ze": ("z", "e"),
417 | "zei": ("z", "ei"),
418 | "zen": ("z", "en"),
419 | "zeng": ("z", "eng"),
420 | "zha": ("zh", "a"),
421 | "zhai": ("zh", "ai"),
422 | "zhan": ("zh", "an"),
423 | "zhang": ("zh", "ang"),
424 | "zhao": ("zh", "ao"),
425 | "zhe": ("zh", "e"),
426 | "zhei": ("zh", "ei"),
427 | "zhen": ("zh", "en"),
428 | "zheng": ("zh", "eng"),
429 | "zhi": ("zh", "iii"),
430 | "zhong": ("zh", "ong"),
431 | "zhou": ("zh", "ou"),
432 | "zhu": ("zh", "u"),
433 | "zhua": ("zh", "ua"),
434 | "zhuai": ("zh", "uai"),
435 | "zhuan": ("zh", "uan"),
436 | "zhuang": ("zh", "uang"),
437 | "zhui": ("zh", "uei"),
438 | "zhun": ("zh", "uen"),
439 | "zhuo": ("zh", "uo"),
440 | "zi": ("z", "ii"),
441 | "zong": ("z", "ong"),
442 | "zou": ("z", "ou"),
443 | "zu": ("z", "u"),
444 | "zuan": ("z", "uan"),
445 | "zui": ("z", "uei"),
446 | "zun": ("z", "uen"),
447 | "zuo": ("z", "uo"),
448 | }
449 |
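450 | 
451 | # Usage sketch (illustrative input): the front end emits space-separated
452 | # symbols such as "sil n i3 h ao3 sp" (initial plus toned final per syllable,
453 | # plus pause marks), which map to integer IDs here:
454 | #
455 | #     seq = cleaned_text_to_sequence("sil n i3 h ao3 sp")
456 | #     text = sequence_to_text(seq)   # symbols concatenated back without spaces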
--------------------------------------------------------------------------------
/text/pinyin-local.txt:
--------------------------------------------------------------------------------
1 | 浅浅 qian3 qian3
--------------------------------------------------------------------------------
/text/symbols.py:
--------------------------------------------------------------------------------
1 | _pause = ["sil", "eos", "sp", "#0", "#1", "#2", "#3"]
2 |
3 | _initials = [
4 | "^",
5 | "b",
6 | "c",
7 | "ch",
8 | "d",
9 | "f",
10 | "g",
11 | "h",
12 | "j",
13 | "k",
14 | "l",
15 | "m",
16 | "n",
17 | "p",
18 | "q",
19 | "r",
20 | "s",
21 | "sh",
22 | "t",
23 | "x",
24 | "z",
25 | "zh",
26 | ]
27 |
28 | _tones = ["1", "2", "3", "4", "5"]
29 |
30 | _finals = [
31 | "a",
32 | "ai",
33 | "an",
34 | "ang",
35 | "ao",
36 | "e",
37 | "ei",
38 | "en",
39 | "eng",
40 | "er",
41 | "i",
42 | "ia",
43 | "ian",
44 | "iang",
45 | "iao",
46 | "ie",
47 | "ii",
48 | "iii",
49 | "in",
50 | "ing",
51 | "iong",
52 | "iou",
53 | "o",
54 | "ong",
55 | "ou",
56 | "u",
57 | "ua",
58 | "uai",
59 | "uan",
60 | "uang",
61 | "uei",
62 | "uen",
63 | "ueng",
64 | "uo",
65 | "v",
66 | "van",
67 | "ve",
68 | "vn",
69 | ]
70 |
71 | symbols = _pause + _initials + [i + j for i in _finals for j in _tones]
72 |
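73 | # With the lists above, tones attach only to finals, so the total symbol count
74 | # works out to 7 pauses + 22 initials + 38 finals x 5 tones = 219.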
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | logging.getLogger('numba').setLevel(logging.WARNING)
4 | logging.getLogger('matplotlib').setLevel(logging.WARNING)
5 |
6 | import os
7 | import json
8 | import argparse
9 | import itertools
10 | import math
11 | import torch
12 | import tqdm
13 | from torch import nn, optim
14 | from torch.nn import functional as F
15 | from torch.utils.data import DataLoader
16 | from torch.utils.tensorboard import SummaryWriter
17 | import torch.multiprocessing as mp
18 | import torch.distributed as dist
19 | from torch.nn.parallel import DistributedDataParallel as DDP
20 | from torch.cuda.amp import autocast, GradScaler
21 |
22 | import commons
23 | import utils
24 | from data_utils import TextAudioLoader, TextAudioCollate, DistributedBucketSampler
25 | from models import MultiPeriodDiscriminator
26 | from losses import generator_loss, discriminator_loss, feature_loss, kl_loss
27 | from mel_processing import mel_spectrogram_torch, spec_to_mel_torch
28 | from text.symbols import symbols
29 | import platform
30 |
31 | torch.backends.cudnn.benchmark = True
32 | global_step = 0
33 |
34 |
35 | def main():
36 |     """Assumes single-node, multi-GPU training only."""
37 | assert torch.cuda.is_available(), "CPU training is not allowed."
38 |
39 | n_gpus = torch.cuda.device_count()
40 | os.environ["MASTER_ADDR"] = "localhost"
41 | os.environ["MASTER_PORT"] = "40000"
42 |
43 | hps = utils.get_hparams()
44 | mp.spawn(
45 | run,
46 | nprocs=n_gpus,
47 | args=(
48 | n_gpus,
49 | hps,
50 | ),
51 | )
52 |
53 |
54 | def run(rank, n_gpus, hps):
55 | global global_step
56 | if rank == 0:
57 | logger = utils.get_logger(hps.model_dir)
58 | logger.info(hps)
59 | utils.check_git_hash(hps.model_dir)
60 | writer = SummaryWriter(log_dir=hps.model_dir)
61 | writer_eval = SummaryWriter(log_dir=os.path.join(hps.model_dir, "eval"))
62 |
63 |     backend_str = "gloo" if platform.system().lower() == "windows" else "nccl"
64 | dist.init_process_group(
65 | backend=backend_str, init_method="env://", world_size=n_gpus, rank=rank
66 | )
67 | torch.manual_seed(hps.train.seed)
68 | torch.cuda.set_device(rank)
69 |
70 | train_dataset = TextAudioLoader(hps.data.training_files, hps.data)
71 | train_sampler = DistributedBucketSampler(
72 | train_dataset,
73 | hps.train.batch_size,
74 | [32, 300, 400, 500, 600, 700, 800, 900, 1000],
75 | num_replicas=n_gpus,
76 | rank=rank,
77 | shuffle=True,
78 | )
79 |     # If the dataloader workers run out of shared memory, raise the shared-memory
80 |     # limit or reduce num_workers below (e.g. from 8 to 4).
81 | collate_fn = TextAudioCollate()
82 | train_loader = DataLoader(
83 | train_dataset,
84 | num_workers=8,
85 | shuffle=False,
86 | pin_memory=True,
87 | collate_fn=collate_fn,
88 | batch_sampler=train_sampler,
89 | )
90 | if rank == 0:
91 | eval_dataset = TextAudioLoader(hps.data.validation_files, hps.data)
92 | eval_loader = DataLoader(
93 | eval_dataset,
94 | num_workers=8,
95 | shuffle=False,
96 | batch_size=hps.train.batch_size,
97 | pin_memory=True,
98 | drop_last=False,
99 | collate_fn=collate_fn,
100 | )
101 |
102 | net_g = utils.load_class(hps.train.train_class)(
103 | len(symbols),
104 | hps.data.filter_length // 2 + 1,
105 | hps.train.segment_size // hps.data.hop_length,
106 | **hps.model,
107 | ).cuda(rank)
108 | net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(rank)
109 | optim_g = torch.optim.AdamW(
110 | net_g.parameters(),
111 | hps.train.learning_rate,
112 | betas=hps.train.betas,
113 | eps=hps.train.eps,
114 | )
115 | optim_d = torch.optim.AdamW(
116 | net_d.parameters(),
117 | hps.train.learning_rate,
118 | betas=hps.train.betas,
119 | eps=hps.train.eps,
120 | )
121 |
122 | try:
123 | teacher = getattr(hps.train, "teacher")
124 | if rank == 0:
125 | logger.info(f"Has teacher model: {teacher}")
126 |
127 | net_g = DDP(net_g, device_ids=[rank], find_unused_parameters=True)
128 | utils.load_teacher(teacher, net_g)
129 |     except AttributeError:
130 |         # hps.train has no "teacher" entry: train from scratch without one
131 |         net_g = DDP(net_g, device_ids=[rank])
132 |         if rank == 0:
133 |             logger.info("no teacher model.")
134 |
135 | net_d = DDP(net_d, device_ids=[rank])
136 |
137 | try:
138 | _, _, _, epoch_str = utils.load_checkpoint(
139 | utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g
140 | )
141 | _, _, _, epoch_str = utils.load_checkpoint(
142 | utils.latest_checkpoint_path(hps.model_dir, "D_*.pth"), net_d, optim_d
143 | )
144 | global_step = (epoch_str - 1) * len(train_loader)
145 |     except Exception:  # no existing checkpoint found; start training from scratch
146 | epoch_str = 1
147 | global_step = 0
148 |
149 | scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
150 | optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2
151 | )
152 | scheduler_d = torch.optim.lr_scheduler.ExponentialLR(
153 | optim_d, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2
154 | )
155 |
156 | scaler = GradScaler(enabled=hps.train.fp16_run)
157 |
158 | for epoch in range(epoch_str, hps.train.epochs + 1):
159 | if rank == 0:
160 | train_and_evaluate(
161 | rank,
162 | epoch,
163 | hps,
164 | [net_g, net_d],
165 | [optim_g, optim_d],
166 | [scheduler_g, scheduler_d],
167 | scaler,
168 | [train_loader, eval_loader],
169 | logger,
170 | [writer, writer_eval],
171 | )
172 | else:
173 | train_and_evaluate(
174 | rank,
175 | epoch,
176 | hps,
177 | [net_g, net_d],
178 | [optim_g, optim_d],
179 | [scheduler_g, scheduler_d],
180 | scaler,
181 | [train_loader, None],
182 | None,
183 | None,
184 | )
185 | scheduler_g.step()
186 | scheduler_d.step()
187 |
188 |
189 | def train_and_evaluate(
190 | rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers
191 | ):
192 | net_g, net_d = nets
193 | optim_g, optim_d = optims
194 | scheduler_g, scheduler_d = schedulers
195 | train_loader, eval_loader = loaders
196 | if writers is not None:
197 | writer, writer_eval = writers
198 |
199 | train_loader.batch_sampler.set_epoch(epoch)
200 | global global_step
201 |
202 | net_g.train()
203 | net_d.train()
204 | if rank == 0:
205 | loader = tqdm.tqdm(train_loader, desc='Loading train data')
206 | else:
207 | loader = train_loader
208 | for batch_idx, (x, x_lengths, bert, spec, spec_lengths, y, y_lengths) in enumerate(loader):
209 | x, x_lengths = x.cuda(rank, non_blocking=True), x_lengths.cuda(
210 | rank, non_blocking=True
211 | )
212 | spec, spec_lengths = spec.cuda(rank, non_blocking=True), spec_lengths.cuda(
213 | rank, non_blocking=True
214 | )
215 | y, y_lengths = y.cuda(rank, non_blocking=True), y_lengths.cuda(
216 | rank, non_blocking=True
217 | )
218 | bert = bert.cuda(rank, non_blocking=True)
219 |
220 | with autocast(enabled=hps.train.fp16_run):
221 | y_hat, l_length, attn, ids_slice, x_mask, z_mask, \
222 | (z, z_p, z_r, m_p, logs_p, m_q, logs_q) = net_g(x, x_lengths, bert, spec, spec_lengths)
223 |
224 | mel = spec_to_mel_torch(
225 | spec,
226 | hps.data.filter_length,
227 | hps.data.n_mel_channels,
228 | hps.data.sampling_rate,
229 | hps.data.mel_fmin,
230 | hps.data.mel_fmax,
231 | )
232 | y_mel = commons.slice_segments(
233 | mel, ids_slice, hps.train.segment_size // hps.data.hop_length
234 | )
235 | y_hat_mel = mel_spectrogram_torch(
236 | y_hat.squeeze(1),
237 | hps.data.filter_length,
238 | hps.data.n_mel_channels,
239 | hps.data.sampling_rate,
240 | hps.data.hop_length,
241 | hps.data.win_length,
242 | hps.data.mel_fmin,
243 | hps.data.mel_fmax,
244 | )
245 |
246 | y = commons.slice_segments(
247 | y, ids_slice * hps.data.hop_length, hps.train.segment_size
248 | ) # slice
249 |
250 | # Discriminator
251 | y_d_hat_r, y_d_hat_g, _, _ = net_d(y, y_hat.detach())
252 | with autocast(enabled=False):
253 | loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(
254 | y_d_hat_r, y_d_hat_g
255 | )
256 | loss_disc_all = loss_disc
257 | optim_d.zero_grad()
258 | scaler.scale(loss_disc_all).backward()
259 | scaler.unscale_(optim_d)
260 | grad_norm_d = commons.clip_grad_value_(net_d.parameters(), None)
261 | scaler.step(optim_d)
262 |
263 | with autocast(enabled=hps.train.fp16_run):
264 | # Generator
265 | y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(y, y_hat)
266 | with autocast(enabled=False):
267 | loss_dur = torch.sum(l_length.float())
268 | loss_mel = F.l1_loss(y_mel, y_hat_mel) * hps.train.c_mel
269 | loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * hps.train.c_kl
270 |                 if z_r is None:
271 | loss_kl_r = 0
272 | else:
273 | loss_kl_r = kl_loss(z_r, logs_p, m_q, logs_q, z_mask) * hps.train.c_kl
274 | loss_fm = feature_loss(fmap_r, fmap_g)
275 | loss_gen, losses_gen = generator_loss(y_d_hat_g)
276 | loss_gen_all = loss_gen + loss_fm + loss_mel + loss_dur + loss_kl + loss_kl_r
277 | optim_g.zero_grad()
278 | scaler.scale(loss_gen_all).backward()
279 | scaler.unscale_(optim_g)
280 | grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None)
281 | scaler.step(optim_g)
282 | scaler.update()
283 |
284 | if rank == 0:
285 | if global_step % hps.train.log_interval == 0:
286 | lr = optim_g.param_groups[0]["lr"]
287 | losses = [
288 | loss_disc,
289 | loss_gen,
290 | loss_fm,
291 | loss_mel,
292 | loss_dur,
293 | loss_kl,
294 | loss_kl_r,
295 | ]
296 | logger.info(
297 | "Train Epoch: {} [{:.0f}%]".format(
298 | epoch, 100.0 * batch_idx / len(train_loader)
299 | )
300 | )
301 | logger.info([global_step, lr])
302 | logger.info(
303 | f"loss_disc={loss_disc:.3f}, loss_gen={loss_gen:.3f}, loss_fm={loss_fm:.3f}"
304 | )
305 | logger.info(
306 | f"loss_mel={loss_mel:.3f}, loss_dur={loss_dur:.3f}, loss_kl={loss_kl:.3f}"
307 | )
308 | logger.info(
309 | f"loss_kl_r={loss_kl_r:.3f}"
310 | )
311 |
312 | scalar_dict = {
313 | "loss/g/total": loss_gen_all,
314 | "loss/d/total": loss_disc_all,
315 | "learning_rate": lr,
316 | "grad_norm_d": grad_norm_d,
317 | "grad_norm_g": grad_norm_g,
318 | }
319 | scalar_dict.update(
320 | {
321 | "loss/g/fm": loss_fm,
322 | "loss/g/mel": loss_mel,
323 | "loss/g/dur": loss_dur,
324 | "loss/g/kl": loss_kl,
325 | "loss/g/kl_r": loss_kl_r,
326 | }
327 | )
328 |
329 | scalar_dict.update(
330 | {"loss/g/{}".format(i): v for i, v in enumerate(losses_gen)}
331 | )
332 | scalar_dict.update(
333 | {"loss/d_r/{}".format(i): v for i, v in enumerate(losses_disc_r)}
334 | )
335 | scalar_dict.update(
336 | {"loss/d_g/{}".format(i): v for i, v in enumerate(losses_disc_g)}
337 | )
338 | image_dict = {
339 | "slice/mel_org": utils.plot_spectrogram_to_numpy(
340 | y_mel[0].data.cpu().numpy()
341 | ),
342 | "slice/mel_gen": utils.plot_spectrogram_to_numpy(
343 | y_hat_mel[0].data.cpu().numpy()
344 | ),
345 | "all/mel": utils.plot_spectrogram_to_numpy(
346 | mel[0].data.cpu().numpy()
347 | ),
348 | "all/attn": utils.plot_alignment_to_numpy(
349 | attn[0, 0].data.cpu().numpy()
350 | ),
351 | }
352 | utils.summarize(
353 | writer=writer,
354 | global_step=global_step,
355 | images=image_dict,
356 | scalars=scalar_dict,
357 | )
358 |
359 | if global_step % hps.train.eval_interval == 0:
360 | evaluate(hps, net_g, eval_loader, writer_eval)
361 | utils.save_checkpoint(
362 | net_g,
363 | optim_g,
364 | hps.train.learning_rate,
365 | epoch,
366 | os.path.join(hps.model_dir, "G_{}.pth".format(global_step)),
367 | )
368 | utils.save_checkpoint(
369 | net_d,
370 | optim_d,
371 | hps.train.learning_rate,
372 | epoch,
373 | os.path.join(hps.model_dir, "D_{}.pth".format(global_step)),
374 | )
375 | global_step += 1
376 |
377 | if rank == 0:
378 | logger.info("====> Epoch: {}".format(epoch))
379 |
380 |
381 | def evaluate(hps, generator, eval_loader, writer_eval):
382 | generator.eval()
383 | with torch.no_grad():
384 | for batch_idx, (x, x_lengths, bert, spec, spec_lengths, y, y_lengths) in enumerate(eval_loader):
385 | x, x_lengths = x.cuda(0), x_lengths.cuda(0)
386 | spec, spec_lengths = spec.cuda(0), spec_lengths.cuda(0)
387 | y, y_lengths = y.cuda(0), y_lengths.cuda(0)
388 | bert = bert.cuda(0)
389 |
390 | # remove else
391 | x = x[:1]
392 | x_lengths = x_lengths[:1]
393 | spec = spec[:1]
394 | spec_lengths = spec_lengths[:1]
395 | y = y[:1]
396 | y_lengths = y_lengths[:1]
397 | break
398 | y_hat, attn, mask, *_ = generator.module.infer(x, x_lengths, bert, max_len=1000)
399 | y_hat_lengths = mask.sum([1, 2]).long() * hps.data.hop_length
400 |
401 | mel = spec_to_mel_torch(
402 | spec,
403 | hps.data.filter_length,
404 | hps.data.n_mel_channels,
405 | hps.data.sampling_rate,
406 | hps.data.mel_fmin,
407 | hps.data.mel_fmax,
408 | )
409 | y_hat_mel = mel_spectrogram_torch(
410 | y_hat.squeeze(1).float(),
411 | hps.data.filter_length,
412 | hps.data.n_mel_channels,
413 | hps.data.sampling_rate,
414 | hps.data.hop_length,
415 | hps.data.win_length,
416 | hps.data.mel_fmin,
417 | hps.data.mel_fmax,
418 | )
419 | image_dict = {
420 | f"gen/mel_{global_step}": utils.plot_spectrogram_to_numpy(
421 | y_hat_mel[0].cpu().numpy()
422 | )
423 | }
424 | audio_dict = {f"gen/audio_{global_step}": y_hat[0, :, : y_hat_lengths[0]]}
425 | if global_step == 0:
426 | image_dict.update(
427 | {"gt/mel": utils.plot_spectrogram_to_numpy(mel[0].cpu().numpy())}
428 | )
429 | audio_dict.update({"gt/audio": y[0, :, : y_lengths[0]]})
430 |
431 | utils.summarize(
432 | writer=writer_eval,
433 | global_step=global_step,
434 | images=image_dict,
435 | audios=audio_dict,
436 | audio_sampling_rate=hps.data.sampling_rate,
437 | )
438 | generator.train()
439 |
440 |
441 | if __name__ == "__main__":
442 | main()
443 |
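444 | 
445 | # Typical launch (arguments are parsed in utils.get_hparams; the run name is
446 | # arbitrary and becomes the ./logs/<name> output directory):
447 | #
448 | #     python train.py -c configs/bert_vits.json -m bert_vits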
--------------------------------------------------------------------------------
/transforms.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn import functional as F
3 |
4 | import numpy as np
5 |
6 |
7 | DEFAULT_MIN_BIN_WIDTH = 1e-3
8 | DEFAULT_MIN_BIN_HEIGHT = 1e-3
9 | DEFAULT_MIN_DERIVATIVE = 1e-3
10 |
11 |
12 | def piecewise_rational_quadratic_transform(
13 | inputs,
14 | unnormalized_widths,
15 | unnormalized_heights,
16 | unnormalized_derivatives,
17 | inverse=False,
18 | tails=None,
19 | tail_bound=1.0,
20 | min_bin_width=DEFAULT_MIN_BIN_WIDTH,
21 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
22 | min_derivative=DEFAULT_MIN_DERIVATIVE,
23 | ):
24 |
25 | if tails is None:
26 | spline_fn = rational_quadratic_spline
27 | spline_kwargs = {}
28 | else:
29 | spline_fn = unconstrained_rational_quadratic_spline
30 | spline_kwargs = {"tails": tails, "tail_bound": tail_bound}
31 |
32 | outputs, logabsdet = spline_fn(
33 | inputs=inputs,
34 | unnormalized_widths=unnormalized_widths,
35 | unnormalized_heights=unnormalized_heights,
36 | unnormalized_derivatives=unnormalized_derivatives,
37 | inverse=inverse,
38 | min_bin_width=min_bin_width,
39 | min_bin_height=min_bin_height,
40 | min_derivative=min_derivative,
41 | **spline_kwargs
42 | )
43 | return outputs, logabsdet
44 |
45 |
46 | def searchsorted(bin_locations, inputs, eps=1e-6):
47 | bin_locations[..., -1] += eps
48 | return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1
49 |
50 |
51 | def unconstrained_rational_quadratic_spline(
52 | inputs,
53 | unnormalized_widths,
54 | unnormalized_heights,
55 | unnormalized_derivatives,
56 | inverse=False,
57 | tails="linear",
58 | tail_bound=1.0,
59 | min_bin_width=DEFAULT_MIN_BIN_WIDTH,
60 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
61 | min_derivative=DEFAULT_MIN_DERIVATIVE,
62 | ):
63 | inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
64 | outside_interval_mask = ~inside_interval_mask
65 |
66 | outputs = torch.zeros_like(inputs)
67 | logabsdet = torch.zeros_like(inputs)
68 |
69 | if tails == "linear":
70 | unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1))
71 | constant = np.log(np.exp(1 - min_derivative) - 1)
72 | unnormalized_derivatives[..., 0] = constant
73 | unnormalized_derivatives[..., -1] = constant
74 |
75 | outputs[outside_interval_mask] = inputs[outside_interval_mask]
76 | logabsdet[outside_interval_mask] = 0
77 | else:
78 | raise RuntimeError("{} tails are not implemented.".format(tails))
79 |
80 | (
81 | outputs[inside_interval_mask],
82 | logabsdet[inside_interval_mask],
83 | ) = rational_quadratic_spline(
84 | inputs=inputs[inside_interval_mask],
85 | unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
86 | unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
87 | unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :],
88 | inverse=inverse,
89 | left=-tail_bound,
90 | right=tail_bound,
91 | bottom=-tail_bound,
92 | top=tail_bound,
93 | min_bin_width=min_bin_width,
94 | min_bin_height=min_bin_height,
95 | min_derivative=min_derivative,
96 | )
97 |
98 | return outputs, logabsdet
99 |
100 |
101 | def rational_quadratic_spline(
102 | inputs,
103 | unnormalized_widths,
104 | unnormalized_heights,
105 | unnormalized_derivatives,
106 | inverse=False,
107 | left=0.0,
108 | right=1.0,
109 | bottom=0.0,
110 | top=1.0,
111 | min_bin_width=DEFAULT_MIN_BIN_WIDTH,
112 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
113 | min_derivative=DEFAULT_MIN_DERIVATIVE,
114 | ):
115 | if torch.min(inputs) < left or torch.max(inputs) > right:
116 | raise ValueError("Input to a transform is not within its domain")
117 |
118 | num_bins = unnormalized_widths.shape[-1]
119 |
120 | if min_bin_width * num_bins > 1.0:
121 | raise ValueError("Minimal bin width too large for the number of bins")
122 | if min_bin_height * num_bins > 1.0:
123 | raise ValueError("Minimal bin height too large for the number of bins")
124 |
125 | widths = F.softmax(unnormalized_widths, dim=-1)
126 | widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
127 | cumwidths = torch.cumsum(widths, dim=-1)
128 | cumwidths = F.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0)
129 | cumwidths = (right - left) * cumwidths + left
130 | cumwidths[..., 0] = left
131 | cumwidths[..., -1] = right
132 | widths = cumwidths[..., 1:] - cumwidths[..., :-1]
133 |
134 | derivatives = min_derivative + F.softplus(unnormalized_derivatives)
135 |
136 | heights = F.softmax(unnormalized_heights, dim=-1)
137 | heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
138 | cumheights = torch.cumsum(heights, dim=-1)
139 | cumheights = F.pad(cumheights, pad=(1, 0), mode="constant", value=0.0)
140 | cumheights = (top - bottom) * cumheights + bottom
141 | cumheights[..., 0] = bottom
142 | cumheights[..., -1] = top
143 | heights = cumheights[..., 1:] - cumheights[..., :-1]
144 |
145 | if inverse:
146 | bin_idx = searchsorted(cumheights, inputs)[..., None]
147 | else:
148 | bin_idx = searchsorted(cumwidths, inputs)[..., None]
149 |
150 | input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0]
151 | input_bin_widths = widths.gather(-1, bin_idx)[..., 0]
152 |
153 | input_cumheights = cumheights.gather(-1, bin_idx)[..., 0]
154 | delta = heights / widths
155 | input_delta = delta.gather(-1, bin_idx)[..., 0]
156 |
157 | input_derivatives = derivatives.gather(-1, bin_idx)[..., 0]
158 | input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0]
159 |
160 | input_heights = heights.gather(-1, bin_idx)[..., 0]
161 |
162 | if inverse:
163 | a = (inputs - input_cumheights) * (
164 | input_derivatives + input_derivatives_plus_one - 2 * input_delta
165 | ) + input_heights * (input_delta - input_derivatives)
166 | b = input_heights * input_derivatives - (inputs - input_cumheights) * (
167 | input_derivatives + input_derivatives_plus_one - 2 * input_delta
168 | )
169 | c = -input_delta * (inputs - input_cumheights)
170 |
171 | discriminant = b.pow(2) - 4 * a * c
172 | assert (discriminant >= 0).all()
173 |
174 | root = (2 * c) / (-b - torch.sqrt(discriminant))
175 | outputs = root * input_bin_widths + input_cumwidths
176 |
177 | theta_one_minus_theta = root * (1 - root)
178 | denominator = input_delta + (
179 | (input_derivatives + input_derivatives_plus_one - 2 * input_delta)
180 | * theta_one_minus_theta
181 | )
182 | derivative_numerator = input_delta.pow(2) * (
183 | input_derivatives_plus_one * root.pow(2)
184 | + 2 * input_delta * theta_one_minus_theta
185 | + input_derivatives * (1 - root).pow(2)
186 | )
187 | logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
188 |
189 | return outputs, -logabsdet
190 | else:
191 | theta = (inputs - input_cumwidths) / input_bin_widths
192 | theta_one_minus_theta = theta * (1 - theta)
193 |
194 | numerator = input_heights * (
195 | input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta
196 | )
197 | denominator = input_delta + (
198 | (input_derivatives + input_derivatives_plus_one - 2 * input_delta)
199 | * theta_one_minus_theta
200 | )
201 | outputs = input_cumheights + numerator / denominator
202 |
203 | derivative_numerator = input_delta.pow(2) * (
204 | input_derivatives_plus_one * theta.pow(2)
205 | + 2 * input_delta * theta_one_minus_theta
206 | + input_derivatives * (1 - theta).pow(2)
207 | )
208 | logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
209 |
210 | return outputs, logabsdet
211 |
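212 | 
213 | if __name__ == "__main__":
214 |     # Round-trip sketch (illustrative shapes): with tails="linear" the spline
215 |     # expects num_bins widths, num_bins heights and num_bins - 1 interior
216 |     # derivatives per element, matching what ConvFlow in modules.py provides.
217 |     num_bins = 10
218 |     x = torch.randn(2, 3)
219 |     w = torch.randn(2, 3, num_bins)
220 |     h = torch.randn(2, 3, num_bins)
221 |     d = torch.randn(2, 3, num_bins - 1)
222 |     y, _ = piecewise_rational_quadratic_transform(
223 |         x, w, h, d, tails="linear", tail_bound=5.0
224 |     )
225 |     x_rec, _ = piecewise_rational_quadratic_transform(
226 |         y, w, h, d, inverse=True, tails="linear", tail_bound=5.0
227 |     )
228 |     print(torch.allclose(x, x_rec, atol=1e-4))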
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import sys
4 | import argparse
5 | import logging
6 | import json
7 | import subprocess
8 | import numpy as np
9 | from scipy.io.wavfile import read
10 | import torch
11 |
12 | MATPLOTLIB_FLAG = False
13 |
14 | logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
15 | logger = logging
16 |
17 |
18 | def load_class(full_class_name):
19 | cls = None
20 | if full_class_name in globals():
21 | cls = globals()[full_class_name]
22 | else:
23 | if "." in full_class_name:
24 | import importlib
25 | module_name, cls_name = full_class_name.rsplit('.', 1)
26 | mod = importlib.import_module(module_name)
27 | cls = (getattr(mod, cls_name))
28 | return cls
29 |
30 |
31 | def load_teacher(checkpoint_path, model):
32 | assert os.path.isfile(checkpoint_path)
33 | checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
34 | saved_state_dict = checkpoint_dict['model']
35 | if hasattr(model, 'module'):
36 | state_dict = model.module.state_dict()
37 | else:
38 | state_dict = model.state_dict()
39 | new_state_dict = {}
40 | for k, v in state_dict.items():
41 | if k.startswith('enc_q') or k.startswith('flow'):
42 | new_state_dict[k] = saved_state_dict[k]
43 | else:
44 | new_state_dict[k] = v
45 | if hasattr(model, 'module'):
46 | model.module.load_state_dict(new_state_dict)
47 | else:
48 | model.load_state_dict(new_state_dict)
49 | return model
50 |
51 |
52 | def load_checkpoint(checkpoint_path, model, optimizer=None):
53 | assert os.path.isfile(checkpoint_path)
54 | checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
55 | iteration = checkpoint_dict["iteration"]
56 | learning_rate = checkpoint_dict["learning_rate"]
57 | if optimizer is not None:
58 | optimizer.load_state_dict(checkpoint_dict["optimizer"])
59 | saved_state_dict = checkpoint_dict["model"]
60 | if hasattr(model, "module"):
61 | state_dict = model.module.state_dict()
62 | else:
63 | state_dict = model.state_dict()
64 | new_state_dict = {}
65 | for k, v in state_dict.items():
66 | try:
67 | new_state_dict[k] = saved_state_dict[k]
68 |         except KeyError:
69 | logger.info("%s is not in the checkpoint" % k)
70 | new_state_dict[k] = v
71 | if hasattr(model, "module"):
72 | model.module.load_state_dict(new_state_dict)
73 | else:
74 | model.load_state_dict(new_state_dict)
75 | logger.info(
76 | "Loaded checkpoint '{}' (iteration {})".format(checkpoint_path, iteration)
77 | )
78 | return model, optimizer, learning_rate, iteration
79 |
80 |
81 | def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):
82 | logger.info(
83 | "Saving model and optimizer state at iteration {} to {}".format(
84 | iteration, checkpoint_path
85 | )
86 | )
87 | if hasattr(model, "module"):
88 | state_dict = model.module.state_dict()
89 | else:
90 | state_dict = model.state_dict()
91 | torch.save(
92 | {
93 | "model": state_dict,
94 | "iteration": iteration,
95 | "optimizer": optimizer.state_dict(),
96 | "learning_rate": learning_rate,
97 | },
98 | checkpoint_path,
99 | )
100 |
101 |
102 | def load_model(checkpoint_path, model):
103 | assert os.path.isfile(checkpoint_path)
104 | checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
105 | saved_state_dict = checkpoint_dict["model"]
106 | if hasattr(model, "module"):
107 | state_dict = model.module.state_dict()
108 | else:
109 | state_dict = model.state_dict()
110 | new_state_dict = {}
111 | for k, v in state_dict.items():
112 | try:
113 | new_state_dict[k] = saved_state_dict[k]
114 |         except KeyError:
115 | logger.info("%s is not in the checkpoint" % k)
116 | new_state_dict[k] = v
117 | if hasattr(model, "module"):
118 | model.module.load_state_dict(new_state_dict)
119 | else:
120 | model.load_state_dict(new_state_dict)
121 | return model
122 |
123 |
124 | def save_model(model, checkpoint_path):
125 | if hasattr(model, 'module'):
126 | state_dict = model.module.state_dict()
127 | else:
128 | state_dict = model.state_dict()
129 | torch.save({'model': state_dict}, checkpoint_path)
130 |
131 |
132 | def summarize(
133 | writer,
134 | global_step,
135 | scalars={},
136 | histograms={},
137 | images={},
138 | audios={},
139 | audio_sampling_rate=22050,
140 | ):
141 | for k, v in scalars.items():
142 | writer.add_scalar(k, v, global_step)
143 | for k, v in histograms.items():
144 | writer.add_histogram(k, v, global_step)
145 | for k, v in images.items():
146 | writer.add_image(k, v, global_step, dataformats="HWC")
147 | for k, v in audios.items():
148 | writer.add_audio(k, v, global_step, audio_sampling_rate)
149 |
150 |
151 | def latest_checkpoint_path(dir_path, regex="G_*.pth"):
152 | f_list = glob.glob(os.path.join(dir_path, regex))
153 | f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f))))
154 | x = f_list[-1]
155 | print(x)
156 | return x
157 |
158 |
159 | def plot_spectrogram_to_numpy(spectrogram):
160 | global MATPLOTLIB_FLAG
161 | if not MATPLOTLIB_FLAG:
162 | import matplotlib
163 |
164 | matplotlib.use("Agg")
165 | MATPLOTLIB_FLAG = True
166 | mpl_logger = logging.getLogger("matplotlib")
167 | mpl_logger.setLevel(logging.WARNING)
168 | import matplotlib.pylab as plt
169 | import numpy as np
170 |
171 | fig, ax = plt.subplots(figsize=(10, 2))
172 | im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
173 | plt.colorbar(im, ax=ax)
174 | plt.xlabel("Frames")
175 | plt.ylabel("Channels")
176 | plt.tight_layout()
177 |
178 | fig.canvas.draw()
179 |     data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
180 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
181 | plt.close()
182 | return data
183 |
184 |
185 | def plot_alignment_to_numpy(alignment, info=None):
186 | global MATPLOTLIB_FLAG
187 | if not MATPLOTLIB_FLAG:
188 | import matplotlib
189 |
190 | matplotlib.use("Agg")
191 | MATPLOTLIB_FLAG = True
192 | mpl_logger = logging.getLogger("matplotlib")
193 | mpl_logger.setLevel(logging.WARNING)
194 | import matplotlib.pylab as plt
195 | import numpy as np
196 |
197 | fig, ax = plt.subplots(figsize=(6, 4))
198 | im = ax.imshow(
199 | alignment.transpose(), aspect="auto", origin="lower", interpolation="none"
200 | )
201 | fig.colorbar(im, ax=ax)
202 | xlabel = "Decoder timestep"
203 | if info is not None:
204 | xlabel += "\n\n" + info
205 | plt.xlabel(xlabel)
206 | plt.ylabel("Encoder timestep")
207 | plt.tight_layout()
208 |
209 | fig.canvas.draw()
210 |     data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
211 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
212 | plt.close()
213 | return data
214 |
215 |
216 | def load_wav_to_torch(full_path):
217 | sampling_rate, data = read(full_path)
218 | return torch.FloatTensor(data.astype(np.float32)), sampling_rate
219 |
220 |
221 | def load_filepaths_and_text(filename, split="|"):
222 | with open(filename, encoding="utf-8") as f:
223 | filepaths_and_text = []
224 | for line in f:
225 | path_text = line.strip().split(split)
226 | filepaths_and_text.append(path_text)
227 | return filepaths_and_text
228 |
229 |
230 | def get_hparams(init=True):
231 | parser = argparse.ArgumentParser()
232 | parser.add_argument(
233 | "-c",
234 | "--config",
235 | type=str,
236 | default="./configs/bert_vits.json",
237 | help="JSON file for configuration",
238 | )
239 | parser.add_argument("-m", "--model", type=str, required=True, help="Model name")
240 |
241 | args = parser.parse_args()
242 | model_dir = os.path.join("./logs", args.model)
243 |
244 | if not os.path.exists(model_dir):
245 | os.makedirs(model_dir)
246 |
247 | config_path = args.config
248 | config_save_path = os.path.join(model_dir, "config.json")
249 | if init:
250 | with open(config_path, "r") as f:
251 | data = f.read()
252 | with open(config_save_path, "w") as f:
253 | f.write(data)
254 | else:
255 | with open(config_save_path, "r") as f:
256 | data = f.read()
257 | config = json.loads(data)
258 |
259 | hparams = HParams(**config)
260 | hparams.model_dir = model_dir
261 | return hparams
262 |
263 |
264 | def get_hparams_from_dir(model_dir):
265 | config_save_path = os.path.join(model_dir, "config.json")
266 | with open(config_save_path, "r") as f:
267 | data = f.read()
268 | config = json.loads(data)
269 |
270 | hparams = HParams(**config)
271 | hparams.model_dir = model_dir
272 | return hparams
273 |
274 |
275 | def get_hparams_from_file(config_path):
276 | with open(config_path, "r") as f:
277 | data = f.read()
278 | config = json.loads(data)
279 |
280 | hparams = HParams(**config)
281 | return hparams
282 |
283 |
284 | def check_git_hash(model_dir):
285 | source_dir = os.path.dirname(os.path.realpath(__file__))
286 | if not os.path.exists(os.path.join(source_dir, ".git")):
287 |         logger.warning(
288 | "{} is not a git repository, therefore hash value comparison will be ignored.".format(
289 | source_dir
290 | )
291 | )
292 | return
293 |
294 | cur_hash = subprocess.getoutput("git rev-parse HEAD")
295 |
296 | path = os.path.join(model_dir, "githash")
297 | if os.path.exists(path):
298 | saved_hash = open(path).read()
299 | if saved_hash != cur_hash:
300 |             logger.warning(
301 | "git hash values are different. {}(saved) != {}(current)".format(
302 | saved_hash[:8], cur_hash[:8]
303 | )
304 | )
305 | else:
306 | open(path, "w").write(cur_hash)
307 |
308 |
309 | def get_logger(model_dir, filename="train.log"):
310 | global logger
311 | logger = logging.getLogger(os.path.basename(model_dir))
312 | logger.setLevel(logging.DEBUG)
313 |
314 | formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s")
315 | if not os.path.exists(model_dir):
316 | os.makedirs(model_dir)
317 | h = logging.FileHandler(os.path.join(model_dir, filename))
318 | h.setLevel(logging.DEBUG)
319 | h.setFormatter(formatter)
320 | logger.addHandler(h)
321 | return logger
322 |
323 |
324 | class HParams:
325 | def __init__(self, **kwargs):
326 | for k, v in kwargs.items():
327 |             if isinstance(v, dict):
328 | v = HParams(**v)
329 | self[k] = v
330 |
331 | def keys(self):
332 | return self.__dict__.keys()
333 |
334 | def items(self):
335 | return self.__dict__.items()
336 |
337 | def values(self):
338 | return self.__dict__.values()
339 |
340 | def __len__(self):
341 | return len(self.__dict__)
342 |
343 | def __getitem__(self, key):
344 | return getattr(self, key)
345 |
346 | def __setitem__(self, key, value):
347 | return setattr(self, key, value)
348 |
349 | def __contains__(self, key):
350 | return key in self.__dict__
351 |
352 | def __repr__(self):
353 | return self.__dict__.__repr__()
354 |
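355 | # Minimal usage sketch (mirrors the inference scripts; the config path is the repo default):
356 | #   hps = get_hparams_from_file("./configs/bert_vits.json")
357 | #   print(hps.data.sampling_rate, hps.train.segment_size)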
--------------------------------------------------------------------------------
/vits_infer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import numpy as np
4 |
5 | import torch
6 | import utils
7 | import argparse
8 |
9 | from scipy.io import wavfile
10 | from text.symbols import symbols
11 | from text import cleaned_text_to_sequence
12 | from vits_pinyin import VITS_PinYin
13 |
14 | parser = argparse.ArgumentParser(description='Inference code for bert vits models')
15 | parser.add_argument('--config', type=str, required=True)
16 | parser.add_argument('--model', type=str, required=True)
17 | args = parser.parse_args()
18 |
19 | def save_wav(wav, path, rate):
20 | wav *= 32767 / max(0.01, np.max(np.abs(wav))) * 0.6
21 | wavfile.write(path, rate, wav.astype(np.int16))
22 |
23 | # device
24 | # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
25 | device = torch.device("cpu")
26 |
27 | # pinyin
28 | tts_front = VITS_PinYin("./bert", device)
29 |
30 | # config
31 | hps = utils.get_hparams_from_file(args.config)
32 |
33 | # model
34 | net_g = utils.load_class(hps.train.eval_class)(
35 | len(symbols),
36 | hps.data.filter_length // 2 + 1,
37 | hps.train.segment_size // hps.data.hop_length,
38 | **hps.model)
39 |
40 | # model_path = "logs/bert_vits/G_200000.pth"
41 | # utils.save_model(net_g, "vits_bert_model.pth")
42 | # model_path = "vits_bert_model.pth"
43 | utils.load_model(args.model, net_g)
44 | net_g.eval()
45 | net_g.to(device)
46 |
47 | os.makedirs("./vits_infer_out/", exist_ok=True)
48 | if __name__ == "__main__":
49 | n = 0
50 | fo = open("vits_infer_item.txt", "r+", encoding='utf-8')
51 | while (True):
52 | try:
53 | item = fo.readline().strip()
54 |         except Exception as e:
55 |             print('error while reading item:', e)
56 |             break
57 |         if item is None or item == "":
58 |             break
59 | n = n + 1
60 | phonemes, char_embeds = tts_front.chinese_to_phonemes(item)
61 | input_ids = cleaned_text_to_sequence(phonemes)
62 | with torch.no_grad():
63 | x_tst = torch.LongTensor(input_ids).unsqueeze(0).to(device)
64 | x_tst_lengths = torch.LongTensor([len(input_ids)]).to(device)
65 | x_tst_prosody = torch.FloatTensor(char_embeds).unsqueeze(0).to(device)
66 | audio = net_g.infer(x_tst, x_tst_lengths, x_tst_prosody, noise_scale=0.5,
67 | length_scale=1)[0][0, 0].data.cpu().float().numpy()
68 | save_wav(audio, f"./vits_infer_out/bert_vits_{n}.wav", hps.data.sampling_rate)
69 | fo.close()
70 |
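71 | # Example invocation (paths are illustrative):
72 | #   python vits_infer.py --config ./configs/bert_vits.json --model vits_bert_model.pth
73 | # Reads one sentence per line from vits_infer_item.txt and writes
74 | # ./vits_infer_out/bert_vits_<n>.wav for each line.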
--------------------------------------------------------------------------------
/vits_infer_item.txt:
--------------------------------------------------------------------------------
1 | 遥望星空作文独自坐在乡间的小丘上,看着阳光渐渐变暗,听着鸟鸣渐渐变弱,触着清风渐渐变凉
2 | 时光总是慢慢地偷走我们的容颜,渐渐地有些人终将离我们而远去
3 | 白色的樱花纯洁高尚,红色的樱花热烈奔放,绿色的樱花清晰澹雅,花开的美丽与快乐,花落的烂漫与潇洒都蕴藏着樱花的人生智慧
--------------------------------------------------------------------------------
/vits_infer_no_bert.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import numpy as np
4 |
5 | import torch
6 | import utils
7 | import argparse
8 |
9 | from scipy.io import wavfile
10 | from text.symbols import symbols
11 | from text import cleaned_text_to_sequence
12 | from vits_pinyin import VITS_PinYin
13 |
14 | parser = argparse.ArgumentParser(description='Inference code for bert vits models')
15 | parser.add_argument('--config', type=str, required=True)
16 | parser.add_argument('--model', type=str, required=True)
17 | args = parser.parse_args()
18 |
19 | def save_wav(wav, path, rate):
20 | wav *= 32767 / max(0.01, np.max(np.abs(wav))) * 0.6
21 | wavfile.write(path, rate, wav.astype(np.int16))
22 |
23 | # device
24 | # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
25 | device = torch.device("cpu")
26 |
27 | # pinyin
28 | tts_front = VITS_PinYin("./bert", device, hasBert=False)
29 |
30 | # config
31 | hps = utils.get_hparams_from_file(args.config)
32 |
33 | # model
34 | net_g = utils.load_class(hps.train.eval_class)(
35 | len(symbols),
36 | hps.data.filter_length // 2 + 1,
37 | hps.train.segment_size // hps.data.hop_length,
38 | **hps.model)
39 |
40 | # model_path = "logs/bert_vits/G_200000.pth"
41 | # utils.save_model(net_g, "vits_bert_model.pth")
42 | # model_path = "vits_bert_model.pth"
43 | utils.load_model(args.model, net_g)
44 | net_g.eval()
45 | net_g.to(device)
46 |
47 | os.makedirs("./vits_infer_out/", exist_ok=True)
48 | if __name__ == "__main__":
49 | n = 0
50 | fo = open("vits_infer_item.txt", "r+", encoding='utf-8')
51 | while (True):
52 | try:
53 | item = fo.readline().strip()
54 |         except Exception as e:
55 |             print('error while reading item:', e)
56 |             break
57 |         if item is None or item == "":
58 |             break
59 | n = n + 1
60 | phonemes, _ = tts_front.chinese_to_phonemes(item)
61 | input_ids = cleaned_text_to_sequence(phonemes)
62 | with torch.no_grad():
63 | x_tst = torch.LongTensor(input_ids).unsqueeze(0).to(device)
64 | x_tst_lengths = torch.LongTensor([len(input_ids)]).to(device)
65 | audio = net_g.infer(x_tst, x_tst_lengths, bert=None, noise_scale=0.5,
66 | length_scale=1)[0][0, 0].data.cpu().float().numpy()
67 | save_wav(audio, f"./vits_infer_out/bert_vits_no_bert_{n}.wav", hps.data.sampling_rate)
68 | fo.close()
69 |
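70 | # Example invocation (paths are illustrative); this variant builds the front end with
71 | # hasBert=False and calls net_g.infer with bert=None, so no BERT prosody features are used:
72 | #   python vits_infer_no_bert.py --config ./configs/bert_vits.json --model vits_bert_model.pth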
--------------------------------------------------------------------------------
/vits_infer_onnx.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
3 |
4 | import argparse
5 | import os
6 |
7 | import onnxruntime
8 | import soundfile
9 | import torch
10 |
11 | from text import cleaned_text_to_sequence
12 | from vits_pinyin import VITS_PinYin
13 |
14 |
15 | def display(sess):
16 | for i in sess.get_inputs():
17 | print(i)
18 |
19 | print("-" * 10)
20 | for o in sess.get_outputs():
21 | print(o)
22 |
23 |
24 | class OnnxModel:
25 | def __init__(
26 | self,
27 | model: str,
28 | ):
29 | session_opts = onnxruntime.SessionOptions()
30 | session_opts.inter_op_num_threads = 1
31 | session_opts.intra_op_num_threads = 4
32 |
33 | self.session_opts = session_opts
34 |
35 | self.model = onnxruntime.InferenceSession(
36 | model,
37 | sess_options=self.session_opts,
38 | )
39 | display(self.model)
40 |
41 | meta = self.model.get_modelmeta().custom_metadata_map
42 | self.add_blank = int(meta["add_blank"])
43 | self.sample_rate = int(meta["sample_rate"])
44 | print(meta)
45 |
46 | def __call__(self, x: torch.Tensor):
47 | """
48 | Args:
49 | x:
50 |             An int64 tensor of shape (L,)
51 | """
52 | x = x.unsqueeze(0)
53 | x_length = torch.tensor([x.shape[1]], dtype=torch.int64)
54 | noise_scale = torch.tensor([1], dtype=torch.float32)
55 | length_scale = torch.tensor([1], dtype=torch.float32)
56 |
57 | y = self.model.run(
58 | [
59 | self.model.get_outputs()[0].name,
60 | ],
61 | {
62 | self.model.get_inputs()[0].name: x.numpy(),
63 | self.model.get_inputs()[1].name: x_length.numpy(),
64 | self.model.get_inputs()[2].name: noise_scale.numpy(),
65 | self.model.get_inputs()[3].name: length_scale.numpy(),
66 | },
67 | )[0]
68 | return y
69 |
70 |
71 | def main():
72 | parser = argparse.ArgumentParser(
73 | description='Inference code for bert vits models')
74 | parser.add_argument('--model', type=str, required=True)
75 | args = parser.parse_args()
76 | print("Onnx model path:", args.model)
77 | model = OnnxModel(args.model)
78 |
79 | tts_front = VITS_PinYin(None, None, hasBert=False)
80 |
81 | os.makedirs("./vits_infer_out/", exist_ok=True)
82 |
83 | n = 0
84 | fo = open("vits_infer_item.txt", "r+", encoding='utf-8')
85 | while (True):
86 | try:
87 | item = fo.readline().strip()
88 |         except Exception as e:
89 |             print('error while reading item:', e)
90 |             break
91 |         if item is None or item == "":
92 |             break
93 | n = n + 1
94 | phonemes, _ = tts_front.chinese_to_phonemes(item)
95 | input_ids = cleaned_text_to_sequence(phonemes)
96 |
97 | x = torch.tensor(input_ids, dtype=torch.int64)
98 | y = model(x)
99 |
100 | soundfile.write(
101 | f"./vits_infer_out/onnx_{n}.wav", y, model.sample_rate)
102 |
103 | fo.close()
104 |
105 |
106 | if __name__ == "__main__":
107 | main()
108 |
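109 | # Example invocation (the .onnx file name is illustrative):
110 | #   python vits_infer_onnx.py --model vits-chinese.onnx
111 | # The exported model carries add_blank and sample_rate in its metadata, so no config file is needed.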
--------------------------------------------------------------------------------
/vits_infer_onnx_stream.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
3 |
4 | import argparse
5 | import datetime
6 | import os
7 |
8 | import numpy
9 | import onnxruntime
10 | import soundfile
11 | import torch
12 |
13 | from text import cleaned_text_to_sequence
14 | from vits_pinyin import VITS_PinYin
15 |
16 |
17 | def display(sess):
18 | for i in sess.get_inputs():
19 | print(i)
20 |
21 | print("-" * 10)
22 | for o in sess.get_outputs():
23 | print(o)
24 |
25 |
26 | class OnnxModel_Encoder:
27 | def __init__(
28 | self,
29 | model: str,
30 | ):
31 | session_opts = onnxruntime.SessionOptions()
32 | session_opts.inter_op_num_threads = 1
33 | session_opts.intra_op_num_threads = 4
34 |
35 | self.session_opts = session_opts
36 |
37 | self.model = onnxruntime.InferenceSession(
38 | model,
39 | sess_options=self.session_opts,
40 | )
41 | display(self.model)
42 |
43 | meta = self.model.get_modelmeta().custom_metadata_map
44 | self.add_blank = int(meta["add_blank"])
45 | self.sample_rate = int(meta["sample_rate"])
46 | print(meta)
47 |
48 | def __call__(self, x: torch.Tensor):
49 | """
50 | Args:
51 | x:
52 |             An int64 tensor of shape (L,)
53 | """
54 | x = x.unsqueeze(0)
55 | x_length = torch.tensor([x.shape[1]], dtype=torch.int64)
56 | noise_scale = torch.tensor([1], dtype=torch.float32)
57 | length_scale = torch.tensor([1], dtype=torch.float32)
58 |
59 | z_p, y_mask = self.model.run(
60 | [
61 | self.model.get_outputs()[0].name,
62 | self.model.get_outputs()[1].name,
63 | ],
64 | {
65 | self.model.get_inputs()[0].name: x.numpy(),
66 | self.model.get_inputs()[1].name: x_length.numpy(),
67 | self.model.get_inputs()[2].name: noise_scale.numpy(),
68 | self.model.get_inputs()[3].name: length_scale.numpy(),
69 | },
70 | )
71 | return z_p, y_mask
72 |
73 |
74 | class OnnxModel_Decoder:
75 | def __init__(
76 | self,
77 | model: str,
78 | ):
79 | session_opts = onnxruntime.SessionOptions()
80 | session_opts.inter_op_num_threads = 1
81 | session_opts.intra_op_num_threads = 4
82 |
83 | self.session_opts = session_opts
84 |
85 | self.model = onnxruntime.InferenceSession(
86 | model,
87 | sess_options=self.session_opts,
88 | )
89 | display(self.model)
90 |
91 | meta = self.model.get_modelmeta().custom_metadata_map
92 | self.hop_length = int(meta["hop_length"])
93 | print(meta)
94 |
95 | def __call__(self, z_p, y_mask):
96 | y = self.model.run(
97 | [
98 | self.model.get_outputs()[0].name,
99 | ],
100 | {
101 | self.model.get_inputs()[0].name: z_p,
102 | self.model.get_inputs()[1].name: y_mask,
103 | },
104 | )[0]
105 | return y
106 |
107 |
108 | def main_debug():
109 | parser = argparse.ArgumentParser(
110 | description='Inference code for bert vits models')
111 | parser.add_argument('--encoder', type=str, required=True)
112 | parser.add_argument('--decoder', type=str, required=True)
113 | args = parser.parse_args()
114 | print("Onnx model path:", args.encoder)
115 | print("Onnx model path:", args.decoder)
116 |
117 | encoder = OnnxModel_Encoder(args.encoder)
118 | decoder = OnnxModel_Decoder(args.decoder)
119 |
120 | tts_front = VITS_PinYin(None, None, hasBert=False)
121 |
122 | os.makedirs("./vits_infer_out/", exist_ok=True)
123 |
124 | n = 0
125 | fo = open("vits_infer_item.txt", "r+", encoding='utf-8')
126 | while (True):
127 | try:
128 | item = fo.readline().strip()
129 |         except Exception as e:
130 |             print('error while reading item:', e)
131 |             break
132 |         if item is None or item == "":
133 |             break
134 | n = n + 1
135 | print(n)
136 | print(datetime.datetime.now())
137 | phonemes, _ = tts_front.chinese_to_phonemes(item)
138 | input_ids = cleaned_text_to_sequence(phonemes)
139 |
140 | x = torch.tensor(input_ids, dtype=torch.int64)
141 | z_p, y_mask = encoder(x)
142 | y = decoder(z_p, y_mask)
143 | print(datetime.datetime.now())
144 |
145 | soundfile.write(
146 | f"./vits_infer_out/onnx_stream_{n}.wav", y, encoder.sample_rate)
147 |
148 | fo.close()
149 |
150 |
151 | def main():
152 | parser = argparse.ArgumentParser(
153 | description='Inference code for bert vits models')
154 | parser.add_argument('--encoder', type=str, required=True)
155 | parser.add_argument('--decoder', type=str, required=True)
156 | args = parser.parse_args()
157 | print("Onnx model path:", args.encoder)
158 | print("Onnx model path:", args.decoder)
159 |
160 | encoder = OnnxModel_Encoder(args.encoder)
161 | decoder = OnnxModel_Decoder(args.decoder)
162 |
163 | tts_front = VITS_PinYin(None, None, hasBert=False)
164 |
165 | os.makedirs("./vits_infer_out/", exist_ok=True)
166 |
167 | n = 0
168 | fo = open("vits_infer_item.txt", "r+", encoding='utf-8')
169 | while (True):
170 | try:
171 | item = fo.readline().strip()
172 |         except Exception as e:
173 |             print('error while reading item:', e)
174 |             break
175 |         if item is None or item == "":
176 |             break
177 | n = n + 1
178 | print(n)
179 | print(datetime.datetime.now())
180 | phonemes, _ = tts_front.chinese_to_phonemes(item)
181 | input_ids = cleaned_text_to_sequence(phonemes)
182 |
183 | x = torch.tensor(input_ids, dtype=torch.int64)
184 | z_p, y_mask = encoder(x)
185 | print(datetime.datetime.now())
186 | len_z = z_p.shape[2]
187 | print('frame size is: ', len_z)
188 | print('hop_length is: ', decoder.hop_length)
189 | # can not change these parameters
190 | hop_length = decoder.hop_length
191 | hop_frame = 12
192 | hop_sample = hop_frame * hop_length
193 | stream_chunk = 50
194 | stream_index = 0
195 | stream_out_wav = []
196 |
197 | while (stream_index < len_z):
198 | if (stream_index == 0): # start frame
199 | cut_s = stream_index
200 | cut_s_wav = 0
201 | else:
202 | cut_s = stream_index - hop_frame
203 | cut_s_wav = hop_sample
204 |
205 | if (stream_index + stream_chunk > len_z - hop_frame): # end frame
206 | cut_e = stream_index + stream_chunk
207 | cut_e_wav = -1
208 | else:
209 | cut_e = stream_index + stream_chunk + hop_frame
210 | cut_e_wav = -1 * hop_sample
211 |
212 | z_chunk = z_p[:, :, cut_s:cut_e]
213 | m_chunk = y_mask[:, :, cut_s:cut_e]
214 | o_chunk = decoder(z_chunk, m_chunk)
215 | o_chunk = o_chunk[cut_s_wav:cut_e_wav]
216 | stream_out_wav.extend(o_chunk)
217 | stream_index = stream_index + stream_chunk
218 | print(datetime.datetime.now())
219 |
220 | stream_out_wav = numpy.asarray(stream_out_wav)
221 | soundfile.write(
222 | f"./vits_infer_out/onnx_stream_{n}.wav", stream_out_wav, encoder.sample_rate)
223 |
224 | fo.close()
225 |
226 |
227 | if __name__ == "__main__":
228 | main()
229 |
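230 | # Example invocation (the .onnx file names are illustrative):
231 | #   python vits_infer_onnx_stream.py --encoder encoder.onnx --decoder decoder.onnx
232 | # main() runs the encoder once to get the full latent z_p, then decodes it in chunks of
233 | # stream_chunk frames, adding hop_frame frames of context on each side and trimming the
234 | # corresponding hop_sample samples from the decoded audio before concatenation.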
--------------------------------------------------------------------------------
/vits_infer_out/bert_vits_1.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PlayVoice/vits_chinese/c5ab514ede52ea1b8b6f0b32eb6b039c53937a22/vits_infer_out/bert_vits_1.wav
--------------------------------------------------------------------------------
/vits_infer_out/bert_vits_2.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PlayVoice/vits_chinese/c5ab514ede52ea1b8b6f0b32eb6b039c53937a22/vits_infer_out/bert_vits_2.wav
--------------------------------------------------------------------------------
/vits_infer_out/bert_vits_3.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PlayVoice/vits_chinese/c5ab514ede52ea1b8b6f0b32eb6b039c53937a22/vits_infer_out/bert_vits_3.wav
--------------------------------------------------------------------------------
/vits_infer_pause.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import numpy as np
4 |
5 | import torch
6 | import utils
7 | import argparse
8 |
9 | from scipy.io import wavfile
10 | from text.symbols import symbols
11 | from text import cleaned_text_to_sequence
12 | from vits_pinyin import VITS_PinYin
13 |
14 | parser = argparse.ArgumentParser(description='Inference code for bert vits models')
15 | parser.add_argument('--config', type=str, required=True)
16 | parser.add_argument('--model', type=str, required=True)
17 | parser.add_argument('--pause', type=int, required=True)
18 | args = parser.parse_args()
19 |
20 | def save_wav(wav, path, rate):
21 | wav *= 32767 / max(0.01, np.max(np.abs(wav))) * 0.6
22 | wavfile.write(path, rate, wav.astype(np.int16))
23 |
24 | # device
25 | # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
26 | device = torch.device("cpu")
27 |
28 | # pinyin
29 | tts_front = VITS_PinYin("./bert", device)
30 |
31 | # config
32 | hps = utils.get_hparams_from_file(args.config)
33 |
34 | # model
35 | net_g = utils.load_class(hps.train.eval_class)(
36 | len(symbols),
37 | hps.data.filter_length // 2 + 1,
38 | hps.train.segment_size // hps.data.hop_length,
39 | **hps.model)
40 |
41 | # model_path = "logs/bert_vits/G_200000.pth"
42 | # utils.save_model(net_g, "vits_bert_model.pth")
43 | # model_path = "vits_bert_model.pth"
44 | utils.load_model(args.model, net_g)
45 | net_g.eval()
46 | net_g.to(device)
47 |
48 | os.makedirs("./vits_infer_out/", exist_ok=True)
49 | if __name__ == "__main__":
50 | n = 0
51 | fo = open("vits_infer_item.txt", "r+", encoding='utf-8')
52 | while (True):
53 | try:
54 | item = fo.readline().strip()
55 |         except Exception as e:
56 |             print('error while reading item:', e)
57 |             break
58 |         if item is None or item == "":
59 |             break
60 | n = n + 1
61 | phonemes, char_embeds = tts_front.chinese_to_phonemes(item)
62 | input_ids = cleaned_text_to_sequence(phonemes)
63 | pause_tmpt = np.array(input_ids)
64 | pause_mask = np.where(pause_tmpt == 2, 0, 1)
65 | pause_valu = np.where(pause_tmpt == 2, 1, 0)
66 | assert args.pause > 1
67 | pause_valu = pause_valu * ((args.pause * 16) // 256)
68 | with torch.no_grad():
69 | x_tst = torch.LongTensor(input_ids).unsqueeze(0).to(device)
70 | x_tst_lengths = torch.LongTensor([len(input_ids)]).to(device)
71 | x_tst_prosody = torch.FloatTensor(char_embeds).unsqueeze(0).to(device)
72 | audio = net_g.infer_pause(x_tst, x_tst_lengths, x_tst_prosody, pause_mask, pause_valu, noise_scale=0.5,
73 | length_scale=1)[0][0, 0].data.cpu().float().numpy()
74 | save_wav(audio, f"./vits_infer_out/bert_vits_{n}.wav", hps.data.sampling_rate)
75 | fo.close()
76 |
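77 | # Example invocation (paths are illustrative; --pause appears to be a pause length in
78 | # milliseconds, converted to frames via (pause * 16) // 256, i.e. assuming 16 kHz audio
79 | # and a hop size of 256):
80 | #   python vits_infer_pause.py --config ./configs/bert_vits.json --model vits_bert_model.pth --pause 300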
--------------------------------------------------------------------------------
/vits_infer_stream.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import numpy as np
4 |
5 | import torch
6 | import utils
7 | import argparse
8 |
9 | from scipy.io import wavfile
10 | from text.symbols import symbols
11 | from text import cleaned_text_to_sequence
12 | from vits_pinyin import VITS_PinYin
13 |
14 | parser = argparse.ArgumentParser(description='Inference code for bert vits models')
15 | parser.add_argument('--config', type=str, required=True)
16 | parser.add_argument('--model', type=str, required=True)
17 | args = parser.parse_args()
18 |
19 | def save_wav(wav, path, rate):
20 | wav *= 32767 / max(0.01, np.max(np.abs(wav))) * 0.6
21 | wavfile.write(path, rate, wav.astype(np.int16))
22 |
23 | # device
24 | # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
25 | device = torch.device("cpu")
26 |
27 | # pinyin
28 | tts_front = VITS_PinYin("./bert", device)
29 |
30 | # config
31 | hps = utils.get_hparams_from_file(args.config)
32 |
33 | # model
34 | net_g = utils.load_class(hps.train.eval_class)(
35 | len(symbols),
36 | hps.data.filter_length // 2 + 1,
37 | hps.train.segment_size // hps.data.hop_length,
38 | **hps.model)
39 |
40 | # model_path = "logs/bert_vits/G_200000.pth"
41 | # utils.save_model(net_g, "vits_bert_model.pth")
42 | # model_path = "vits_bert_model.pth"
43 | utils.load_model(args.model, net_g)
44 | net_g.eval()
45 | net_g.to(device)
46 |
47 | os.makedirs("./vits_infer_out/", exist_ok=True)
48 | if __name__ == "__main__":
49 | n = 0
50 | fo = open("vits_infer_item.txt", "r+", encoding='utf-8')
51 | while (True):
52 | try:
53 | item = fo.readline().strip()
54 |         except Exception as e:
55 |             print('error while reading item:', e)
56 |             break
57 |         if item is None or item == "":
58 |             break
59 | n = n + 1
60 | phonemes, char_embeds = tts_front.chinese_to_phonemes(item)
61 | input_ids = cleaned_text_to_sequence(phonemes)
62 | with torch.no_grad():
63 | x_tst = torch.LongTensor(input_ids).unsqueeze(0).to(device)
64 | x_tst_lengths = torch.LongTensor([len(input_ids)]).to(device)
65 | x_tst_prosody = torch.FloatTensor(char_embeds).unsqueeze(0).to(device)
66 | audio = net_g.infer_stream(x_tst, x_tst_lengths, x_tst_prosody, noise_scale=0.5,length_scale=1)
67 | save_wav(audio, f"./vits_infer_out/bert_vits_stream_{n}.wav", hps.data.sampling_rate)
68 | fo.close()
69 |
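70 | # Example invocation (paths are illustrative); uses net_g.infer_stream and writes
71 | # ./vits_infer_out/bert_vits_stream_<n>.wav:
72 | #   python vits_infer_stream.py --config ./configs/bert_vits.json --model vits_bert_model.pth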
--------------------------------------------------------------------------------
/vits_pinyin.py:
--------------------------------------------------------------------------------
1 | import re
2 |
3 | from tn.chinese.normalizer import Normalizer
4 |
5 | from pypinyin import lazy_pinyin, Style
6 | from pypinyin.core import load_phrases_dict
7 |
8 | from text import pinyin_dict
9 | from bert import TTSProsody
10 |
11 |
12 | def is_chinese(uchar):
13 | if uchar >= u'\u4e00' and uchar <= u'\u9fa5':
14 | return True
15 | else:
16 | return False
17 |
18 |
19 | def clean_chinese(text: str):
20 | text = text.strip()
21 | text_clean = []
22 | for char in text:
23 | if (is_chinese(char)):
24 | text_clean.append(char)
25 | else:
26 | if len(text_clean) > 1 and is_chinese(text_clean[-1]):
27 | text_clean.append(',')
28 | text_clean = ''.join(text_clean).strip(',')
29 | return text_clean
30 |
31 |
32 | def load_pinyin_dict():
33 | my_dict={}
34 | with open("./text/pinyin-local.txt", "r", encoding='utf-8') as f:
35 | content = f.readlines()
36 | for line in content:
37 | cuts = line.strip().split()
38 | hanzi = cuts[0]
39 | phone = cuts[1:]
40 | tmp = []
41 | for p in phone:
42 | tmp.append([p])
43 | my_dict[hanzi] = tmp
44 | load_phrases_dict(my_dict)
45 |
46 |
47 | class VITS_PinYin:
48 | def __init__(self, bert_path, device, hasBert=True):
49 | load_pinyin_dict()
50 | self.hasBert = hasBert
51 | if self.hasBert:
52 | self.prosody = TTSProsody(bert_path, device)
53 | self.normalizer = Normalizer()
54 |
55 | def get_phoneme4pinyin(self, pinyins):
56 | result = []
57 | count_phone = []
58 | for pinyin in pinyins:
59 | if pinyin[:-1] in pinyin_dict:
60 | tone = pinyin[-1]
61 | a = pinyin[:-1]
62 | a1, a2 = pinyin_dict[a]
63 | result += [a1, a2 + tone]
64 | count_phone.append(2)
65 | return result, count_phone
66 |
67 | def chinese_to_phonemes(self, text):
68 | text = self.normalizer.normalize(text)
69 | text = clean_chinese(text)
70 | phonemes = ["sil"]
71 | chars = ['[PAD]']
72 | count_phone = []
73 | count_phone.append(1)
74 | for subtext in text.split(","):
75 | if (len(subtext) == 0):
76 | continue
77 | pinyins = self.correct_pinyin_tone3(subtext)
78 | sub_p, sub_c = self.get_phoneme4pinyin(pinyins)
79 | phonemes.extend(sub_p)
80 | phonemes.append("sp")
81 | count_phone.extend(sub_c)
82 | count_phone.append(1)
83 | chars.append(subtext)
84 | chars.append(',')
85 | phonemes.append("sil")
86 | count_phone.append(1)
87 | chars.append('[PAD]')
88 | chars = "".join(chars)
89 | char_embeds = None
90 | if self.hasBert:
91 | char_embeds = self.prosody.get_char_embeds(chars)
92 | char_embeds = self.prosody.expand_for_phone(char_embeds, count_phone)
93 | return " ".join(phonemes), char_embeds
94 |
95 | def correct_pinyin_tone3(self, text):
96 | pinyin_list = lazy_pinyin(text,
97 | style=Style.TONE3,
98 | strict=False,
99 | neutral_tone_with_five=True,
100 | tone_sandhi=True)
101 |         # tone_sandhi=True enables third-tone (3-3) sandhi
102 | return pinyin_list
103 |
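104 | # Minimal usage sketch (mirrors vits_infer.py; pass hasBert=False to skip loading BERT):
105 | #   front = VITS_PinYin("./bert", device)
106 | #   phonemes, char_embeds = front.chinese_to_phonemes("遥望星空")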
--------------------------------------------------------------------------------
/vits_prepare.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 | import argparse
5 | import utils
6 |
7 | from bert import TTSProsody
8 | from bert.prosody_tool import is_chinese, pinyin_dict
9 | from utils import load_wav_to_torch
10 | from mel_processing import spectrogram_torch
11 |
12 |
13 | os.makedirs("./data/waves", exist_ok=True)
14 | os.makedirs("./data/berts", exist_ok=True)
15 | os.makedirs("./data/temps", exist_ok=True)
16 |
17 |
18 | def log(info: str):
19 | with open(f'./data/prepare.log', "a", encoding='utf-8') as flog:
20 | print(info, file=flog)
21 |
22 |
23 | def get_spec(hps, filename):
24 | audio, sampling_rate = load_wav_to_torch(filename)
25 | assert sampling_rate == hps.data.sampling_rate, f"{sampling_rate} is not {hps.data.sampling_rate}"
26 | audio_norm = audio / hps.data.max_wav_value
27 | audio_norm = audio_norm.unsqueeze(0)
28 | spec = spectrogram_torch(
29 | audio_norm,
30 | hps.data.filter_length,
31 | hps.data.sampling_rate,
32 | hps.data.hop_length,
33 | hps.data.win_length,
34 | center=False,
35 | )
36 | spec = torch.squeeze(spec, 0)
37 | return spec
38 |
39 |
40 | if __name__ == "__main__":
41 | parser = argparse.ArgumentParser()
42 | parser.add_argument(
43 | "-c",
44 | "--config",
45 | type=str,
46 | default="./configs/bert_vits.json",
47 | help="JSON file for configuration",
48 | )
49 | args = parser.parse_args()
50 | hps = utils.get_hparams_from_file(args.config)
51 |
52 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
53 | # device = torch.device("cpu")
54 | prosody = TTSProsody("./bert", device)
55 |
56 | fo = open(f"./data/000001-010000.txt", "r+", encoding='utf-8')
57 | scrips = []
58 | while (True):
59 | try:
60 | message = fo.readline().strip()
61 | pinyins = fo.readline().strip()
62 |         except Exception as e:
63 |             print('error while reading transcript:', e)
64 |             break
65 |         if message is None:
66 |             break
67 |         if message == "":
68 |             break
69 | infosub = message.split("\t")
70 | fileidx = infosub[0]
71 | message = infosub[1]
72 | message = message.replace("#1", "")
73 | message = message.replace("#2", "")
74 | message = message.replace("#3", "")
75 | message = message.replace("#4", "")
76 | log(f"{fileidx}\t{message}")
77 | log(f"\t{pinyins}")
78 |
79 | try:
80 | phone_index = 0
81 | phone_items = []
82 | phone_items.append('sil')
83 | count_phone = []
84 | count_phone.append(1)
85 |
86 | pinyins = pinyins.split()
87 | len_pys = len(pinyins)
88 | for word in message:
89 | if is_chinese(word):
90 | count_phone.append(2)
91 | if (phone_index >= len_pys):
92 | print(len_pys)
93 | print(phone_index)
94 | pinyin = pinyins[phone_index]
95 | phone_index = phone_index + 1
96 | if pinyin[:-1] in pinyin_dict:
97 | tone = pinyin[-1]
98 | a = pinyin[:-1]
99 | a1, a2 = pinyin_dict[a]
100 | phone_items += [a1, a2 + tone]
101 | else:
102 |                         raise IndexError(f'Unknown PinYin: {pinyin}')
103 | else:
104 | count_phone.append(1)
105 | phone_items.append('sp')
106 | count_phone.append(1)
107 | phone_items.append('sil')
108 | phone_items_str = ' '.join(phone_items)
109 | log(f"\t{phone_items_str}")
110 | except IndexError as e:
111 | print(f"{fileidx}\t{message}")
112 | print('except:', e)
113 | continue
114 |
115 | text = f'[PAD]{message}[PAD]'
116 | char_embeds = prosody.get_char_embeds(text)
117 | char_embeds = prosody.expand_for_phone(char_embeds, count_phone)
118 | char_embeds_path = f"./data/berts/{fileidx}.npy"
119 | np.save(char_embeds_path, char_embeds, allow_pickle=False)
120 |
121 | wave_path = f"./data/waves/{fileidx}.wav"
122 | spec_path = f"./data/temps/{fileidx}.spec.pt"
123 | spec = get_spec(hps, wave_path)
124 |
125 | torch.save(spec, spec_path)
126 | scrips.append(
127 | f"./data/waves/{fileidx}.wav|./data/temps/{fileidx}.spec.pt|./data/berts/{fileidx}.npy|{phone_items_str}")
128 |
129 | fo.close()
130 |
131 | fout = open(f'./filelists/all.txt', 'w', encoding='utf-8')
132 | for item in scrips:
133 | print(item, file=fout)
134 | fout.close()
135 | fout = open(f'./filelists/valid.txt', 'w', encoding='utf-8')
136 | for item in scrips[:100]:
137 | print(item, file=fout)
138 | fout.close()
139 | fout = open(f'./filelists/train.txt', 'w', encoding='utf-8')
140 | for item in scrips[100:]:
141 | print(item, file=fout)
142 | fout.close()
143 |
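144 | # Example invocation (paths follow the defaults used above: wavs in ./data/waves and the
145 | # transcript ./data/000001-010000.txt; writes BERT features, spectrograms and filelists):
146 | #   python vits_prepare.py -c ./configs/bert_vits.json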
--------------------------------------------------------------------------------
/vits_resample.py:
--------------------------------------------------------------------------------
1 | import os
2 | import librosa
3 | import argparse
4 | import numpy as np
5 | from tqdm import tqdm
6 | from concurrent.futures import ThreadPoolExecutor, as_completed
7 | from scipy.io import wavfile
8 |
9 |
10 | def resample_wave(wav_in, wav_out, sample_rate):
11 | wav, _ = librosa.load(wav_in, sr=sample_rate)
12 |     # peak-normalize to 0.6 of int16 full scale; max(0.01, ...) guards against near-silent clips
13 |     wav = wav / max(0.01, np.max(np.abs(wav))) * 32767 * 0.6
14 | wavfile.write(wav_out, sample_rate, wav.astype(np.int16))
15 |
16 |
17 | def process_file(file, wavPath, outPath, sr):
18 | if file.endswith(".wav"):
19 | file = file[:-4]
20 | resample_wave(f"{wavPath}/{file}.wav", f"{outPath}/{file}.wav", sr)
21 |
22 |
23 | def process_files_with_thread_pool(wavPath, outPath, sr, thread_num=None):
24 | files = [f for f in os.listdir(f"./{wavPath}") if f.endswith(".wav")]
25 |
26 | with ThreadPoolExecutor(max_workers=thread_num) as executor:
27 | futures = {executor.submit(process_file, file, wavPath, outPath, sr): file for file in files}
28 |
29 | for future in tqdm(as_completed(futures), total=len(futures), desc=f'Processing {sr}'):
30 | future.result()
31 |
32 |
33 | if __name__ == "__main__":
34 | parser = argparse.ArgumentParser()
35 | parser.add_argument("-w", "--wav", help="wav", dest="wav", required=True)
36 | parser.add_argument("-o", "--out", help="out", dest="out", required=True)
37 | parser.add_argument("-s", "--sr", help="sample rate", dest="sr", type=int, required=True)
38 | parser.add_argument("-t", "--thread_count", help="thread count to process, set 0 to use all cpu cores", dest="thread_count", type=int, default=1)
39 |
40 | args = parser.parse_args()
41 | print(args.wav)
42 | print(args.out)
43 | print(args.sr)
44 |
45 | os.makedirs(args.out, exist_ok=True)
46 | wavPath = args.wav
47 | outPath = args.out
48 |
49 | if args.thread_count == 0:
50 | process_num = os.cpu_count() // 2 + 1
51 | else:
52 | process_num = args.thread_count
53 | process_files_with_thread_pool(wavPath, outPath, args.sr, process_num)
54 |
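55 | # Example invocation (directory names and sample rate are illustrative):
56 | #   python vits_resample.py -w ./data/waves_raw -o ./data/waves -s 16000 -t 0
57 | # With -t 0 the thread count defaults to os.cpu_count() // 2 + 1.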
--------------------------------------------------------------------------------