├── .gitattributes
├── LICENSE
├── README.md
├── __init__.py
├── conf
│   ├── glow_tts_baseline.yaml
│   ├── glow_tts_std.yaml
│   ├── glow_tts_stdp.yaml
│   └── hifigan16k_ft.yaml
├── dataset
│   ├── dev_tts_common_voice.json
│   └── train_tts_common_voice.json
├── glow_tts_with_pitch.py
├── media
│   └── glow_tts_stdp.png
├── models
│   ├── __init__.py
│   └── glow_tts_with_pitch.py
├── modules
│   ├── glow_tts_modules
│   │   ├── __init__.py
│   │   ├── glow_tts_submodules.py
│   │   ├── glow_tts_submodules_with_pitch.py
│   │   ├── stocpred_modules.py
│   │   └── transforms.py
│   └── glow_tts_with_pitch.py
├── notebook
│   └── inference_glowtts_stdp.ipynb
├── train_glowtts_baseline.sh
├── train_glowtts_std.sh
├── train_glowtts_stdp.sh
└── utils
    ├── __init__.py
    ├── data.py
    ├── glow_tts_loss.py
    ├── helpers.py
    ├── tts_data_types.py
    └── tts_tokenizers.py
/.gitattributes:
--------------------------------------------------------------------------------
1 | *.ipynb linguist-vendored
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # Stochastic Pitch Prediction Improves the Diversity and Naturalness of Speech in Glow-TTS
3 |
4 | ### Sewade Ogun, Vincent Colotte, Emmanuel Vincent
5 |
6 | In our recent [paper](https://arxiv.org/abs/2305.17724), we proposed GlowTTS-STDP, a flow-based TTS model that improves the naturalness and diversity of generated utterances.
7 |
8 |
9 | **Glow-TTS-STDP at Inference**
10 |
11 | ![Glow-TTS-STDP at inference](media/glow_tts_stdp.png)
12 |
18 |
19 | ## 1. Requirements
20 |
21 | Our model is based on the GlowTTS architecture, which we implemented in the NeMo toolkit. You must install the same version of NeMo used in our experiments to ensure all dependencies work correctly.
22 |
23 | ```bash
24 | apt-get update && apt-get install -y libsndfile1 ffmpeg
25 | git clone https://github.com/NVIDIA/NeMo
26 | cd NeMo
27 | git checkout v1.5.0
28 | ./reinstall.sh
29 | ```
30 |
31 | After installation, you should have:
32 | 1) NeMo toolkit (version 1.5.0), [https://github.com/NVIDIA/NeMo](https://github.com/NVIDIA/NeMo)
33 | 2) PyTorch 1.10.0 or above
34 | 3) PyTorch Lightning
35 |
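A quick way to confirm that the installed versions match these requirements (a minimal sketch; run it in the environment where NeMo was installed):

```python
# Print the versions of the main dependencies; nothing is enforced here.
import nemo
import pytorch_lightning as pl
import torch

print("NeMo:", nemo.__version__)               # expected: 1.5.0
print("PyTorch:", torch.__version__)           # expected: 1.10.0 or above
print("PyTorch Lightning:", pl.__version__)
print("CUDA available:", torch.cuda.is_available())  # GPUs are required for training
```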
36 |
37 | GPUs are required for model training. Note that we used mixed-precision training for all our experiments.
38 |
39 | PS: Check out the NeMo GitHub page if you run into problems with the library installation.
40 |
41 | ## 2. Model setup
42 |
43 | Clone this GitHub repo after installing NeMo and checking out the correct branch. This repo contains:
44 | i. the model,
45 | ii. the dataset manifests (without the audio files),
46 | iii. the training scripts,
47 | iv. the configuration files for all the experiments.
48 |
49 | ```bash
50 | git clone https://github.com/ogunlao/glowtts_stdp
51 | ```
52 |
53 | ## 3. Pre-requisites
54 |
55 | a) Download and extract [the English subset of Common Voice Version 7.0](https://commonvoice.mozilla.org/en/datasets) into the `dataset` directory. Convert the files from mp3 to wav and resample them to 16 kHz for faster data loading (a sketch of this step is shown below). The training and validation JSON manifests, which contain the [CommonVoice WV-MOS-4.0-all](https://arxiv.org/abs/2210.06370) subset, are provided.
56 |
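The exact folder layout is up to you; the following is a minimal conversion sketch, assuming the extracted clips live in `dataset/clips/` and the 16 kHz wav files go to `dataset/wavs/` (both paths are illustrative and should match your manifests):

```python
# Convert Common Voice mp3 clips to 16 kHz mono wav files (paths are illustrative).
from pathlib import Path

import librosa
import soundfile as sf

SRC_DIR = Path("dataset/clips")   # extracted Common Voice mp3 files
DST_DIR = Path("dataset/wavs")    # resampled wav files used for training
DST_DIR.mkdir(parents=True, exist_ok=True)

for mp3_path in sorted(SRC_DIR.glob("*.mp3")):
    # librosa decodes the mp3 (via ffmpeg/audioread) and resamples to 16 kHz mono
    audio, sr = librosa.load(mp3_path, sr=16000, mono=True)
    sf.write(DST_DIR / f"{mp3_path.stem}.wav", audio, sr)
```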
57 | b) A HiFi-GAN vocoder trained on 16 kHz multi-speaker speech utterances is required. We trained a HiFi-GAN v1 on [LibriTTS](http://www.openslr.org/60). HiFi-GAN can be trained using the NeMo toolkit.
58 | The config file for HiFi-GAN is provided in `glowtts_stdp/conf/hifigan16k_ft.yaml`.
59 |
60 | c) A speaker embedding file is required, either as a pickle or a JSON file. We extract embedding vectors using the open-source library [Resemblyzer](https://github.com/resemble-ai/Resemblyzer).
61 |
62 | Embeddings should be saved as a lookup table (dictionary) using the structure:
63 |
64 | ```
65 | {
66 | audio1: [[embedding vector1]],
67 | audio2: [[embedding vector2]],
68 | }
69 | ```
70 |
71 | Note that the audio file names are used without their extension. The lookup table can be saved to disk as either a pickle or a JSON file; a sketch of the extraction step is shown below.
72 |
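A minimal extraction sketch with Resemblyzer, assuming the 16 kHz wav files from step (a) and an output path of your choosing:

```python
# Build a {utterance_id: [[256-d embedding]]} lookup table with Resemblyzer
# and save it as a pickle file (paths are illustrative).
import pickle
from pathlib import Path

from resemblyzer import VoiceEncoder, preprocess_wav

WAV_DIR = Path("dataset/wavs")
encoder = VoiceEncoder()  # pretrained speaker encoder, 256-dimensional output

speaker_embs = {}
for wav_path in sorted(WAV_DIR.glob("*.wav")):
    wav = preprocess_wav(wav_path)                # load, resample and trim silences
    emb = encoder.embed_utterance(wav)            # numpy array of shape (256,)
    speaker_embs[wav_path.stem] = [emb.tolist()]  # key is the file name without extension

with open("speaker_embeddings.pkl", "wb") as f:
    pickle.dump(speaker_embs, f)
```

The path of the resulting file is what the `speaker_emb_path` field in the configs should point to.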
73 |
74 | ## 4. Training Example
75 |
76 | To train the baseline GlowTTS model
77 | ```sh
78 | cd glowtts_stdp
79 | sh train_glowtts_baseline.sh
80 | ```
81 |
82 | To train the GlowTTS-STD model (model with stochastic duration prediction)
83 | ```sh
84 | cd glowtts_stdp
85 | sh train_glowtts_std.sh
86 | ```
87 |
88 | To train the GlowTTS-STDP model (model with stochastic duration prediction and stochastic pitch prediction)
89 | ```sh
90 | cd glowtts_stdp
91 | sh train_glowtts_stdp.sh
92 | ```
93 |
94 | NeMo uses [Hydra](https://hydra.cc/) for hyperparameter configuration, so hyperparameters can be changed either in the respective config files or as overrides in the training scripts. In particular, the fields marked `???` in the configs (`train_dataset`, `validation_datasets`, `sup_data_path`, `speaker_emb_path`) must be set before training.
95 |
97 | ## 5. Inference Example
97 |
98 | See [inference_glowtts_stdp.ipynb](notebook/inference_glowtts_stdp.ipynb). A minimal programmatic sketch is shown below.
99 |
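For a quick test outside the notebook, the following sketch generates a mel spectrogram from text (the checkpoint path and speaker embedding are placeholders; the call mirrors the usage in `models/glow_tts_with_pitch.py`):

```python
# Generate a mel spectrogram from text with a trained checkpoint (paths are placeholders).
import torch

from models.glow_tts_with_pitch import GlowTTSModel

model = GlowTTSModel.restore_from("GlowTTS_stdp.nemo").eval()

tokens = model.parse("Stochastic pitch prediction improves diversity.")
token_lens = torch.tensor([tokens.shape[1]], device=tokens.device)
# Replace with a real 256-d Resemblyzer embedding of the target speaker.
spk_emb = torch.randn(1, 256, device=tokens.device)

with torch.no_grad():
    spect, attn = model(
        x=tokens,
        x_lengths=token_lens,
        gen=True,
        noise_scale=0.667,
        stoch_dur_noise_scale=1.0,
        stoch_pitch_noise_scale=1.0,
        speaker_embeddings=spk_emb,
        pitch=None,  # let the stochastic pitch predictor sample a contour
    )
# `spect` is a mel spectrogram; vocode it with the 16 kHz HiFi-GAN to obtain a waveform.
```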
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ogunlao/glowtts_stdp/07f71bbfce405018f43db7c086f46c6b91defa28/__init__.py
--------------------------------------------------------------------------------
/conf/glow_tts_baseline.yaml:
--------------------------------------------------------------------------------
1 | name: "GlowTTS_baseline"
2 | gin_channels: 256
3 | use_stoch_dur_pred: false
4 | use_stoch_pitch_pred: false
5 | use_log_pitch: false
6 | sup_data_path: ???
7 | sup_data_types: ["speaker_emb"]
8 |
9 | train_dataset: ???
10 | validation_datasets: ???
11 | test_datasets: null
12 |
13 | phoneme_dict_path: "../NeMo/scripts/tts_dataset_files/cmudict-0.7b_nv22.10"
14 | heteronyms_path: "../NeMo/scripts/tts_dataset_files/heteronyms-052722"
15 | whitelist_path: "../NeMo/nemo_text_processing/text_normalization/en/data/whitelist/lj_speech.tsv"
16 |
17 | speaker_emb_path: ???
18 |
19 | # Default values from librosa.pyin
20 | pitch_fmin: null
21 | pitch_fmax: null
22 | pitch_mean: null
23 | pitch_std: null
24 |
25 | # Spectrogram settings for dataset with sample_rate=16000
26 | sample_rate: 16000
27 | n_mel_channels: 80
28 | n_window_size: 1024
29 | n_window_stride: 256
30 | n_fft: 1024
31 | lowfreq: 0
32 | highfreq: 8000
33 | window: hann
34 | pad_value: 0.0
35 |
36 | model:
37 | n_speakers: 4469
38 | gin_channels: ${gin_channels}
39 | use_external_speaker_emb: true
40 | speaker_emb_path: ${speaker_emb_path}
41 | use_stoch_dur_pred: ${use_stoch_dur_pred}
42 | use_stoch_pitch_pred: ${use_stoch_pitch_pred}
43 | pitch_loss_scale: 0.0
44 |
45 | max_token_duration: 75
46 | symbols_embedding_dim: 256 #384
47 | pitch_embedding_kernel_size: 3
48 |
49 | pitch_fmin: ${pitch_fmin}
50 | pitch_fmax: ${pitch_fmax}
51 |
52 | pitch_mean: ${pitch_mean}
53 | pitch_std: ${pitch_std}
54 |
55 | sample_rate: ${sample_rate}
56 | n_mel_channels: ${n_mel_channels}
57 | n_window_size: ${n_window_size}
58 | n_window_stride: ${n_window_stride}
59 | n_fft: ${n_fft}
60 | lowfreq: ${lowfreq}
61 | highfreq: ${highfreq}
62 | window: ${window}
63 | pad_value: ${pad_value}
64 |
65 | # use_log_pitch: ${use_log_pitch}
66 |
67 | # text_normalizer:
68 | # _target_: nemo_text_processing.text_normalization.normalize.Normalizer
69 | # lang: en
70 | # input_case: cased
71 | # whitelist: ${whitelist_path}
72 |
73 | # text_normalizer_call_kwargs:
74 | # verbose: false
75 | # punct_pre_process: true
76 | # punct_post_process: true
77 |
78 | text_tokenizer:
79 | _target_: utils.tts_tokenizers.EnglishPhonemesTokenizer
80 | punct: true
81 | stresses: false # true
82 | chars: true
83 | apostrophe: true
84 | pad_with_space: false
85 | add_blank_at: "last"
86 | add_blank_to_text: true
87 | g2p:
88 | _target_: nemo_text_processing.g2p.modules.EnglishG2p
89 | phoneme_dict: ${phoneme_dict_path}
90 | heteronyms: ${heteronyms_path}
91 | phoneme_probability: 0.8
92 |
93 | train_ds:
94 | dataset:
95 | _target_: utils.data.TTSDataset
96 | manifest_filepath: ${train_dataset}
97 | sample_rate: ${model.sample_rate}
98 | sup_data_path: ${sup_data_path}
99 | sup_data_types: ${sup_data_types}
100 | n_fft: ${n_fft}
101 | win_length: ${model.n_fft}
102 | hop_length: ${model.n_window_stride}
103 | window: ${model.window}
104 | n_mels: ${model.n_mel_channels}
105 | lowfreq: ${model.lowfreq}
106 | highfreq: ${model.highfreq}
107 | max_duration: 16.7
108 | min_duration: 0.1
109 | ignore_file: null
110 | trim: false
111 | # pitch_fmin: ${model.pitch_fmin}
112 | # pitch_fmax: ${model.pitch_fmax}
113 | # pitch_norm: true
114 | # pitch_mean: ${model.pitch_mean}
115 | # pitch_std: ${model.pitch_std}
116 | # use_log_pitch: ${model.use_log_pitch}
117 | speaker_emb_path: ${speaker_emb_path}
118 |
119 | dataloader_params:
120 | drop_last: false
121 | shuffle: true
122 | batch_size: 32
123 | num_workers: 32
124 | pin_memory: true
125 |
126 | validation_ds:
127 | dataset:
128 | _target_: utils.data.TTSDataset
129 | manifest_filepath: ${validation_datasets}
130 | sample_rate: ${model.sample_rate}
131 | sup_data_path: ${sup_data_path}
132 | sup_data_types: ${sup_data_types}
133 | n_fft: ${model.n_fft}
134 | win_length: ${model.n_window_size}
135 | hop_length: ${model.n_window_stride}
136 | window: ${model.window}
137 | n_mels: ${model.n_mel_channels}
138 | lowfreq: ${model.lowfreq}
139 | highfreq: ${model.highfreq}
140 | max_duration: null
141 | min_duration: null
142 | ignore_file: null
143 | trim: false
144 | # pitch_fmin: ${model.pitch_fmin}
145 | # pitch_fmax: ${model.pitch_fmax}
146 | # pitch_norm: true
147 | # pitch_mean: ${model.pitch_mean}
148 | # pitch_std: ${model.pitch_std}
149 | # use_log_pitch: ${model.use_log_pitch}
150 | speaker_emb_path: ${speaker_emb_path}
151 |
152 | dataloader_params:
153 | drop_last: false
154 | shuffle: false
155 | batch_size: 2
156 | num_workers: 2
157 | pin_memory: true
158 |
159 | preprocessor:
160 | _target_: nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor
161 | features: ${model.n_mel_channels}
162 | lowfreq: ${model.lowfreq}
163 | highfreq: ${model.highfreq}
164 | n_fft: ${model.n_fft}
165 | n_window_size: ${model.n_window_size}
166 | window_size: false
167 | n_window_stride: ${model.n_window_stride}
168 | window_stride: false
169 | pad_to: 1
170 | pad_value: ${model.pad_value}
171 | sample_rate: ${model.sample_rate}
172 | window: ${model.window}
173 | normalize: null
174 | preemph: null
175 | dither: 0.0
176 | frame_splicing: 1
177 | log: true
178 | log_zero_guard_type: add
179 | log_zero_guard_value: 1e-05
180 | mag_power: 1.0
181 |
182 | encoder:
183 | _target_: modules.glow_tts_with_pitch.TextEncoder
184 | n_vocab: 148
185 | out_channels: ${model.n_mel_channels}
186 | hidden_channels: 192
187 | filter_channels: 768
188 | filter_channels_dp: 256
189 | kernel_size: 3
190 | p_dropout: 0.1
191 | n_layers: 6
192 | n_heads: 2
193 | window_size: 4
194 | prenet: true
195 | mean_only: true
196 | gin_channels: ${gin_channels}
197 | use_stoch_dur_pred: ${use_stoch_dur_pred}
198 |
199 | decoder:
200 | _target_: modules.glow_tts_with_pitch.FlowSpecDecoder
201 | in_channels: ${model.n_mel_channels}
202 | hidden_channels: 192
203 | kernel_size: 5
204 | n_blocks: 12
205 | n_layers: 4
206 | n_sqz: 2
207 | n_split: 4
208 | sigmoid_scale: false
209 | p_dropout: 0.05
210 | dilation_rate: 1
211 | gin_channels: ${gin_channels}
212 |
213 | optim:
214 | name: radam
215 | lr: 2e-4
216 | # optimizer arguments
217 | betas: [0.9, 0.98]
218 | weight_decay: 0.0
219 |
220 | # scheduler setup
221 | sched:
222 | name: CosineAnnealing
223 |
224 | # Scheduler params
225 | warmup_steps: 6000
226 | min_lr: 1e-5
227 | last_epoch: -1
228 |
229 |
230 | trainer:
231 | accelerator: auto
232 | devices: -1 # number of gpus (-1 uses all available)
233 | strategy: ddp
234 | num_nodes: 1
235 | enable_checkpointing: false # Provided by exp_manager
236 | logger: false # Provided by exp_manager
237 | max_epochs: 1000
238 | max_steps: -1 # computed at runtime if not set
239 | accumulate_grad_batches: 2
240 | log_every_n_steps: 100 # Interval of logging.
241 | check_val_every_n_epoch: 2
242 | amp_backend: native
243 | precision: 16 # mixed-precision training
244 | gradient_clip_val: 5.0
245 |
246 |
247 | exp_manager:
248 | exp_dir: null
249 | name: ${name}
250 | resume_if_exists: False
251 | resume_ignore_no_checkpoint: True
252 | create_tensorboard_logger: True
253 | create_checkpoint_callback: True
254 | checkpoint_callback_params:
255 | always_save_nemo: True
256 | save_top_k: 1
257 | monitor: "val_loss"
258 | mode: "min"
259 | create_early_stopping_callback: False
260 | early_stopping_params:
261 | monitor: "val_loss"
262 | patience: 10
263 | verbose: True
264 | mode: "min"
265 | create_wandb_logger: False
266 | wandb_logger_kwargs:
267 | name: null
268 | project: null
269 |
270 | hydra:
271 | run:
272 | dir: .
273 | job_logging:
274 | root:
275 | handlers: null
--------------------------------------------------------------------------------
/conf/glow_tts_std.yaml:
--------------------------------------------------------------------------------
1 | name: "GlowTTS_std"
2 | gin_channels: 256
3 | use_stoch_dur_pred: true
4 | use_stoch_pitch_pred: false
5 | use_log_pitch: false
6 | sup_data_path: ???
7 | sup_data_types: ["speaker_emb"]
8 |
9 | train_dataset: ???
10 | validation_datasets: ???
11 | test_datasets: null
12 |
13 | phoneme_dict_path: "../NeMo/scripts/tts_dataset_files/cmudict-0.7b_nv22.10"
14 | heteronyms_path: "../NeMo/scripts/tts_dataset_files/heteronyms-052722"
15 | whitelist_path: "../NeMo/nemo_text_processing/text_normalization/en/data/whitelist/lj_speech.tsv"
16 |
17 | speaker_emb_path: ???
18 |
19 | # Default values from librosa.pyin
20 | pitch_fmin: null
21 | pitch_fmax: null
22 | pitch_mean: null
23 | pitch_std: null
24 |
25 | # Spectrogram settings for dataset with sample_rate=16000
26 | sample_rate: 16000
27 | n_mel_channels: 80
28 | n_window_size: 1024
29 | n_window_stride: 256
30 | n_fft: 1024
31 | lowfreq: 0
32 | highfreq: 8000
33 | window: hann
34 | pad_value: 0.0
35 |
36 | model:
37 | n_speakers: 4469
38 | gin_channels: ${gin_channels}
39 | use_external_speaker_emb: true
40 | speaker_emb_path: ${speaker_emb_path}
41 | use_stoch_dur_pred: ${use_stoch_dur_pred}
42 | use_stoch_pitch_pred: ${use_stoch_pitch_pred}
43 | pitch_loss_scale: 0.0
44 |
45 | max_token_duration: 75
46 | symbols_embedding_dim: 256 #384
47 | pitch_embedding_kernel_size: 3
48 |
49 | pitch_fmin: ${pitch_fmin}
50 | pitch_fmax: ${pitch_fmax}
51 |
52 | pitch_mean: ${pitch_mean}
53 | pitch_std: ${pitch_std}
54 |
55 | sample_rate: ${sample_rate}
56 | n_mel_channels: ${n_mel_channels}
57 | n_window_size: ${n_window_size}
58 | n_window_stride: ${n_window_stride}
59 | n_fft: ${n_fft}
60 | lowfreq: ${lowfreq}
61 | highfreq: ${highfreq}
62 | window: ${window}
63 | pad_value: ${pad_value}
64 |
65 | # use_log_pitch: ${use_log_pitch}
66 |
67 | # text_normalizer:
68 | # _target_: nemo_text_processing.text_normalization.normalize.Normalizer
69 | # lang: en
70 | # input_case: cased
71 | # whitelist: ${whitelist_path}
72 |
73 | # text_normalizer_call_kwargs:
74 | # verbose: false
75 | # punct_pre_process: true
76 | # punct_post_process: true
77 |
78 | text_tokenizer:
79 | _target_: utils.tts_tokenizers.EnglishPhonemesTokenizer
80 | punct: true
81 | stresses: false # true
82 | chars: true
83 | apostrophe: true
84 | pad_with_space: false
85 | add_blank_at: "last"
86 | add_blank_to_text: true
87 | g2p:
88 | _target_: nemo_text_processing.g2p.modules.EnglishG2p
89 | phoneme_dict: ${phoneme_dict_path}
90 | heteronyms: ${heteronyms_path}
91 | phoneme_probability: 0.8
92 |
93 | train_ds:
94 | dataset:
95 | _target_: utils.data.TTSDataset
96 | manifest_filepath: ${train_dataset}
97 | sample_rate: ${model.sample_rate}
98 | sup_data_path: ${sup_data_path}
99 | sup_data_types: ${sup_data_types}
100 | n_fft: ${n_fft}
101 | win_length: ${model.n_fft}
102 | hop_length: ${model.n_window_stride}
103 | window: ${model.window}
104 | n_mels: ${model.n_mel_channels}
105 | lowfreq: ${model.lowfreq}
106 | highfreq: ${model.highfreq}
107 | max_duration: 16.7
108 | min_duration: 0.1
109 | ignore_file: null
110 | trim: false
111 | # pitch_fmin: ${model.pitch_fmin}
112 | # pitch_fmax: ${model.pitch_fmax}
113 | # pitch_norm: true
114 | # pitch_mean: ${model.pitch_mean}
115 | # pitch_std: ${model.pitch_std}
116 | # use_log_pitch: ${model.use_log_pitch}
117 | speaker_emb_path: ${speaker_emb_path}
118 |
119 | dataloader_params:
120 | drop_last: false
121 | shuffle: true
122 | batch_size: 32
123 | num_workers: 32
124 | pin_memory: true
125 |
126 | validation_ds:
127 | dataset:
128 | _target_: utils.data.TTSDataset
129 | manifest_filepath: ${validation_datasets}
130 | sample_rate: ${model.sample_rate}
131 | sup_data_path: ${sup_data_path}
132 | sup_data_types: ${sup_data_types}
133 | n_fft: ${model.n_fft}
134 | win_length: ${model.n_window_size}
135 | hop_length: ${model.n_window_stride}
136 | window: ${model.window}
137 | n_mels: ${model.n_mel_channels}
138 | lowfreq: ${model.lowfreq}
139 | highfreq: ${model.highfreq}
140 | max_duration: null
141 | min_duration: null
142 | ignore_file: null
143 | trim: false
144 | # pitch_fmin: ${model.pitch_fmin}
145 | # pitch_fmax: ${model.pitch_fmax}
146 | # pitch_norm: true
147 | # pitch_mean: ${model.pitch_mean}
148 | # pitch_std: ${model.pitch_std}
149 | # use_log_pitch: ${model.use_log_pitch}
150 | speaker_emb_path: ${speaker_emb_path}
151 |
152 | dataloader_params:
153 | drop_last: false
154 | shuffle: false
155 | batch_size: 2
156 | num_workers: 2
157 | pin_memory: true
158 |
159 | preprocessor:
160 | _target_: nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor
161 | features: ${model.n_mel_channels}
162 | lowfreq: ${model.lowfreq}
163 | highfreq: ${model.highfreq}
164 | n_fft: ${model.n_fft}
165 | n_window_size: ${model.n_window_size}
166 | window_size: false
167 | n_window_stride: ${model.n_window_stride}
168 | window_stride: false
169 | pad_to: 1
170 | pad_value: ${model.pad_value}
171 | sample_rate: ${model.sample_rate}
172 | window: ${model.window}
173 | normalize: null
174 | preemph: null
175 | dither: 0.0
176 | frame_splicing: 1
177 | log: true
178 | log_zero_guard_type: add
179 | log_zero_guard_value: 1e-05
180 | mag_power: 1.0
181 |
182 | encoder:
183 | _target_: modules.glow_tts_with_pitch.TextEncoder
184 | n_vocab: 148
185 | out_channels: ${model.n_mel_channels}
186 | hidden_channels: 192
187 | filter_channels: 768
188 | filter_channels_dp: 256
189 | kernel_size: 3
190 | p_dropout: 0.1
191 | n_layers: 6
192 | n_heads: 2
193 | window_size: 4
194 | prenet: true
195 | mean_only: true
196 | gin_channels: ${gin_channels}
197 | use_stoch_dur_pred: ${use_stoch_dur_pred}
198 |
199 | decoder:
200 | _target_: modules.glow_tts_with_pitch.FlowSpecDecoder
201 | in_channels: ${model.n_mel_channels}
202 | hidden_channels: 192
203 | kernel_size: 5
204 | n_blocks: 12
205 | n_layers: 4
206 | n_sqz: 2
207 | n_split: 4
208 | sigmoid_scale: false
209 | p_dropout: 0.05
210 | dilation_rate: 1
211 | gin_channels: ${gin_channels}
212 |
213 | optim:
214 | name: radam
215 | lr: 2e-4
216 | # optimizer arguments
217 | betas: [0.9, 0.98]
218 | weight_decay: 0.0
219 |
220 | # scheduler setup
221 | sched:
222 | name: CosineAnnealing
223 |
224 | # Scheduler params
225 | warmup_steps: 6000
226 | min_lr: 1e-5
227 | last_epoch: -1
228 |
229 |
230 | trainer:
231 | accelerator: auto
232 | devices: -1 # number of gpus (-1 uses all available)
233 | strategy: ddp
234 | num_nodes: 1
235 | enable_checkpointing: false # Provided by exp_manager
236 | logger: false # Provided by exp_manager
237 | max_epochs: 1000
238 | max_steps: -1 # computed at runtime if not set
239 | accumulate_grad_batches: 2
240 | log_every_n_steps: 100 # Interval of logging.
241 | check_val_every_n_epoch: 4
242 | amp_backend: native
243 | precision: 16 # mixed-precision training
244 | gradient_clip_val: 5.0
245 |
246 | exp_manager:
247 | exp_dir: null
248 | name: ${name}
249 | resume_if_exists: False
250 | resume_ignore_no_checkpoint: True
251 | create_tensorboard_logger: True
252 | create_checkpoint_callback: True
253 | checkpoint_callback_params:
254 | always_save_nemo: True
255 | save_top_k: 1
256 | monitor: "val_loss"
257 | mode: "min"
258 | create_early_stopping_callback: False
259 | early_stopping_params:
260 | monitor: "val_loss"
261 | patience: 10
262 | verbose: True
263 | mode: "min"
264 | create_wandb_logger: False
265 | wandb_logger_kwargs:
266 | name: null
267 | project: null
268 |
269 | hydra:
270 | run:
271 | dir: .
272 | job_logging:
273 | root:
274 | handlers: null
--------------------------------------------------------------------------------
/conf/glow_tts_stdp.yaml:
--------------------------------------------------------------------------------
1 | name: "GlowTTS_stdp"
2 | gin_channels: 256
3 | use_stoch_dur_pred: true
4 | use_stoch_pitch_pred: true
5 | sup_data_path: ???
6 | sup_data_types: ["pitch", "speaker_emb"]
7 |
8 | use_log_pitch: true
9 | use_normalized_pitch: false
10 | unvoiced_value: null
11 | use_frame_emb_for_pitch: true
12 |
13 |
14 | train_dataset: ???
15 | validation_datasets: ???
16 | test_datasets: null
17 |
18 | phoneme_dict_path: "../NeMo/scripts/tts_dataset_files/cmudict-0.7b_nv22.10"
19 | heteronyms_path: "../NeMo/scripts/tts_dataset_files/heteronyms-052722"
20 | whitelist_path: "../NeMo/nemo_text_processing/text_normalization/en/data/whitelist/lj_speech.tsv"
21 |
22 | speaker_emb_path: ???
23 |
24 | # Default values from librosa.pyin
25 | pitch_fmin: null
26 | pitch_fmax: null
27 | pitch_mean: null
28 | pitch_std: null
29 | pitch_norm: false
30 |
31 | # Spectrogram settings for dataset with sample_rate=16000
32 | sample_rate: 16000
33 | n_mel_channels: 80
34 | n_window_size: 1024
35 | n_window_stride: 256
36 | n_fft: 1024
37 | lowfreq: 0
38 | highfreq: 8000
39 | window: hann
40 | pad_value: 0.0
41 |
42 | model:
43 | n_speakers: 4469
44 | gin_channels: ${gin_channels}
45 | use_external_speaker_emb: true
46 | speaker_emb_path: ${speaker_emb_path}
47 | use_stoch_dur_pred: ${use_stoch_dur_pred}
48 | use_stoch_pitch_pred: ${use_stoch_pitch_pred}
49 | pitch_loss_scale: 0.5
50 |
51 | unvoiced_value: ${unvoiced_value}
52 | use_log_pitch: ${use_log_pitch}
53 | use_normalized_pitch: ${use_normalized_pitch}
54 | use_frame_emb_for_pitch: ${use_frame_emb_for_pitch}
55 |
56 | max_token_duration: 75
57 | symbols_embedding_dim: 256 #384
58 | pitch_embedding_kernel_size: 3
59 |
60 | pitch_fmin: ${pitch_fmin}
61 | pitch_fmax: ${pitch_fmax}
62 |
63 | pitch_mean: ${pitch_mean}
64 | pitch_std: ${pitch_std}
65 | pitch_norm: ${pitch_norm}
66 |
67 | sample_rate: ${sample_rate}
68 | n_mel_channels: ${n_mel_channels}
69 | n_window_size: ${n_window_size}
70 | n_window_stride: ${n_window_stride}
71 | n_fft: ${n_fft}
72 | lowfreq: ${lowfreq}
73 | highfreq: ${highfreq}
74 | window: ${window}
75 | pad_value: ${pad_value}
76 |
77 | # text_normalizer:
78 | # _target_: nemo_text_processing.text_normalization.normalize.Normalizer
79 | # lang: en
80 | # input_case: cased
81 | # whitelist: ${whitelist_path}
82 |
83 | # text_normalizer_call_kwargs:
84 | # verbose: false
85 | # punct_pre_process: true
86 | # punct_post_process: true
87 |
88 | text_tokenizer:
89 | _target_: utils.tts_tokenizers.EnglishPhonemesTokenizer
90 | punct: true
91 | stresses: false #true
92 | chars: true
93 | apostrophe: true
94 | pad_with_space: false
95 | add_blank_at: "last"
96 | add_blank_to_text: true
97 | g2p:
98 | _target_: nemo_text_processing.g2p.modules.EnglishG2p
99 | phoneme_dict: ${phoneme_dict_path}
100 | heteronyms: ${heteronyms_path}
101 | phoneme_probability: 0.8
102 |
103 | train_ds:
104 | dataset:
105 | _target_: utils.data.TTSDataset
106 | manifest_filepath: ${train_dataset}
107 | sample_rate: ${model.sample_rate}
108 | sup_data_path: ${sup_data_path}
109 | sup_data_types: ${sup_data_types}
110 | n_fft: ${n_fft}
111 | win_length: ${model.n_fft}
112 | hop_length: ${model.n_window_stride}
113 | window: ${model.window}
114 | n_mels: ${model.n_mel_channels}
115 | lowfreq: ${model.lowfreq}
116 | highfreq: ${model.highfreq}
117 | max_duration: 16.7
118 | min_duration: 0.1
119 | ignore_file: null
120 | trim: false
121 | pitch_fmin: ${model.pitch_fmin}
122 | pitch_fmax: ${model.pitch_fmax}
123 | pitch_norm: ${model.pitch_norm}
124 | pitch_mean: ${model.pitch_mean}
125 | pitch_std: ${model.pitch_std}
126 | use_log_pitch: ${model.use_log_pitch}
127 | use_beta_binomial_interpolator: true
128 | speaker_emb_path: ${speaker_emb_path}
129 |
130 | dataloader_params:
131 | drop_last: false
132 | shuffle: true
133 | batch_size: 24
134 | num_workers: 24
135 | pin_memory: true
136 |
137 | validation_ds:
138 | dataset:
139 | _target_: utils.data.TTSDataset
140 | manifest_filepath: ${validation_datasets}
141 | sample_rate: ${model.sample_rate}
142 | sup_data_path: ${sup_data_path}
143 | sup_data_types: ${sup_data_types}
144 | n_fft: ${model.n_fft}
145 | win_length: ${model.n_window_size}
146 | hop_length: ${model.n_window_stride}
147 | window: ${model.window}
148 | n_mels: ${model.n_mel_channels}
149 | lowfreq: ${model.lowfreq}
150 | highfreq: ${model.highfreq}
151 | max_duration: null
152 | min_duration: null
153 | ignore_file: null
154 | trim: false
155 | pitch_fmin: ${model.pitch_fmin}
156 | pitch_fmax: ${model.pitch_fmax}
157 | pitch_norm: ${model.pitch_norm}
158 | pitch_mean: ${model.pitch_mean}
159 | pitch_std: ${model.pitch_std}
160 | use_log_pitch: ${model.use_log_pitch}
161 | use_beta_binomial_interpolator: true
162 | speaker_emb_path: ${speaker_emb_path}
163 |
164 | dataloader_params:
165 | drop_last: false
166 | shuffle: false
167 | batch_size: 4
168 | num_workers: 4
169 | pin_memory: true
170 |
171 | preprocessor:
172 | _target_: nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor
173 | features: ${model.n_mel_channels}
174 | lowfreq: ${model.lowfreq}
175 | highfreq: ${model.highfreq}
176 | n_fft: ${model.n_fft}
177 | n_window_size: ${model.n_window_size}
178 | window_size: false
179 | n_window_stride: ${model.n_window_stride}
180 | window_stride: false
181 | pad_to: 1 #16
182 | pad_value: ${model.pad_value}
183 | sample_rate: ${model.sample_rate}
184 | window: ${model.window}
185 | normalize: null
186 | preemph: null
187 | dither: 0.0
188 | frame_splicing: 1
189 | log: true
190 | log_zero_guard_type: add
191 | log_zero_guard_value: 1e-05
192 | mag_power: 1.0
193 |
194 | encoder:
195 | _target_: modules.glow_tts_with_pitch.TextEncoder
196 | n_vocab: 148
197 | out_channels: ${model.n_mel_channels}
198 | hidden_channels: 192 #192 #256 if using stoch_dur_pred
199 | filter_channels: 768
200 | filter_channels_dp: 256 #256 if not using stoch_dur_pred
201 | kernel_size: 3
202 | p_dropout: 0.1
203 | n_layers: 6
204 | n_heads: 2
205 | window_size: 4
206 | prenet: true
207 | mean_only: true
208 | gin_channels: ${gin_channels}
209 | use_stoch_dur_pred: ${use_stoch_dur_pred}
210 |
211 | decoder:
212 | _target_: modules.glow_tts_with_pitch.FlowSpecDecoder
213 | in_channels: ${model.n_mel_channels}
214 | hidden_channels: 192 #256 if using stoch_dur_pred
215 | kernel_size: 5
216 | n_blocks: 12
217 | n_layers: 4
218 | n_sqz: 2
219 | n_split: 4
220 | sigmoid_scale: false
221 | p_dropout: 0.05
222 | dilation_rate: 1
223 | gin_channels: ${gin_channels}
224 |
225 | pitch_predictor:
226 | _target_: modules.glow_tts_modules.StochasticPitchPredictor
227 | in_channels: 192 #80 #192
228 | filter_channels: 256
229 | kernel_size: 3
230 | p_dropout: 0.1
231 | n_flows: 4
232 | gin_channels: ${gin_channels}
233 |
234 | optim:
235 | name: radam
236 | lr: 2e-4
237 | # optimizer arguments
238 | betas: [0.9, 0.98]
239 | weight_decay: 0.0
240 |
241 | # scheduler setup
242 | sched:
243 | name: CosineAnnealing
244 |
245 | # Scheduler params
246 | warmup_steps: 6000
247 | min_lr: 1e-5
248 | last_epoch: -1
249 |
250 |
251 | trainer:
252 | accelerator: auto
253 | devices: -1 # number of gpus
254 | strategy: ddp
255 | num_nodes: 1
256 | enable_checkpointing: false # Provided by exp_manager
257 | logger: false # Provided by exp_manager
258 | max_epochs: 1000
259 | max_steps: -1 # computed at runtime if not set
260 | accumulate_grad_batches: 2
261 | log_every_n_steps: 100 # Interval of logging.
262 | check_val_every_n_epoch: 2
263 | amp_backend: native
264 | precision: 16 # mixed-precision training
265 | gradient_clip_val: 5.0
266 |
267 | exp_manager:
268 | exp_dir: null
269 | name: ${name}
270 | resume_if_exists: False
271 | resume_ignore_no_checkpoint: True
272 | create_tensorboard_logger: True
273 | create_checkpoint_callback: True
274 | checkpoint_callback_params:
275 | always_save_nemo: True
276 | save_top_k: 1
277 | monitor: "val_loss"
278 | mode: "min"
279 | create_early_stopping_callback: False
280 | early_stopping_params:
281 | monitor: "val_loss"
282 | patience: 10
283 | verbose: True
284 | mode: "min"
285 | create_wandb_logger: False
286 | wandb_logger_kwargs:
287 | name: null
288 | project: null
289 |
290 | hydra:
291 | run:
292 | dir: .
293 | job_logging:
294 | root:
295 | handlers: null
--------------------------------------------------------------------------------
/conf/hifigan16k_ft.yaml:
--------------------------------------------------------------------------------
1 | name: "HifiGAN_CV_16k"
2 | train_dataset: ???
3 | validation_datasets: ???
4 |
5 | defaults:
6 | - model/generator: v1
7 | - model/train_ds: train_ds
8 | - model/validation_ds: val_ds
9 |
10 | model:
11 | preprocessor:
12 | _target_: nemo.collections.asr.parts.preprocessing.features.FilterbankFeatures
13 | dither: 0.0
14 | frame_splicing: 1
15 | nfilt: 80
16 | highfreq: 8000
17 | log: true
18 | log_zero_guard_type: clamp
19 | log_zero_guard_value: 1e-05
20 | lowfreq: 0
21 | mag_power: 1.0
22 | n_fft: 1024
23 | n_window_size: 1024
24 | n_window_stride: 256
25 | normalize: null
26 | pad_to: 0
27 | pad_value: -11.52
28 | preemph: null
29 | sample_rate: 16000
30 | window: hann
31 | use_grads: false
32 | exact_pad: true
33 |
34 | optim:
35 | _target_: torch.optim.AdamW
36 | lr: 0.0002
37 | betas: [0.8, 0.99]
38 |
39 | sched:
40 | name: CosineAnnealing
41 | min_lr: 1e-8 # 1e-5
42 | warmup_ratio: 0.02
43 |
44 | max_steps: 2500000
45 | l1_loss_factor: 45
46 | denoise_strength: 0.0025
47 |
48 | trainer:
49 | accelerator: auto
50 | devices: -1 # number of gpus
51 | strategy: ddp
52 | num_nodes: 1
53 | enable_checkpointing: false # Provided by exp_manager
54 | logger: false # Provided by exp_manager
55 | max_steps: ${model.max_steps}
56 | accumulate_grad_batches: 1
57 | log_every_n_steps: 100
58 | check_val_every_n_epoch: 1
59 |
60 | exp_manager:
61 | exp_dir: null
62 | name: ${name}
63 | create_tensorboard_logger: True
64 | create_checkpoint_callback: True
65 | checkpoint_callback_params:
66 | save_top_k: 1
67 | monitor: "val_loss"
68 | mode: "min"
69 | early_stopping_params:
70 | monitor: "val_loss"
71 | patience: 10
72 | verbose: True
73 | mode: "min"
74 |
75 | hydra:
76 | run:
77 | dir: .
78 | job_logging:
79 | root:
80 | handlers: null
81 |
82 |
--------------------------------------------------------------------------------
/glow_tts_with_pitch.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import pytorch_lightning as pl
16 |
17 | from nemo.collections.common.callbacks import LogEpochTimeCallback
18 | from nemo.core.config import hydra_runner
19 | from nemo.utils.exp_manager import exp_manager
20 |
21 | from models.glow_tts_with_pitch import GlowTTSModel
22 |
23 | @hydra_runner(config_path="conf", config_name="glow_tts")
24 | def main(cfg):
25 | trainer = pl.Trainer(**cfg.trainer)
26 | exp_manager(trainer, cfg.get("exp_manager", None))
27 | model = GlowTTSModel(cfg=cfg.model, trainer=trainer)
28 | model.maybe_init_from_pretrained_checkpoint(cfg=cfg)
29 | lr_logger = pl.callbacks.LearningRateMonitor()
30 | epoch_time_logger = LogEpochTimeCallback()
31 | trainer.callbacks.extend([lr_logger, epoch_time_logger])
32 | trainer.fit(model)
33 |
34 |
35 | if __name__ == "__main__":
36 | main() # noqa pylint: disable=no-value-for-parameter
--------------------------------------------------------------------------------
/media/glow_tts_stdp.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ogunlao/glowtts_stdp/07f71bbfce405018f43db7c086f46c6b91defa28/media/glow_tts_stdp.png
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ogunlao/glowtts_stdp/07f71bbfce405018f43db7c086f46c6b91defa28/models/__init__.py
--------------------------------------------------------------------------------
/models/glow_tts_with_pitch.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import contextlib
15 | import math
16 | from copy import deepcopy
17 | from dataclasses import dataclass
18 | from typing import Any, Dict, Optional
19 |
20 | import torch
21 | import torch.utils.data
22 | from hydra.utils import instantiate
23 | from omegaconf import MISSING, DictConfig, OmegaConf
24 | from pytorch_lightning import Trainer
25 | from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger
26 |
27 | from nemo.collections.asr.parts.preprocessing.perturb import process_augmentations
28 | from utils.helpers import (
29 | log_audio_to_tb,
30 | plot_alignment_to_numpy,
31 | plot_spectrogram_to_numpy,
32 | process_batch,
33 | )
34 | from utils.glow_tts_loss import GlowTTSLoss
35 | from utils.data import load_speaker_emb
36 | from nemo.collections.tts.losses.fastpitchloss import PitchLoss
37 | from nemo.collections.tts.models.base import SpectrogramGenerator
38 |
39 | from modules.glow_tts_with_pitch import GlowTTSModule
40 | from modules.glow_tts_modules.glow_tts_submodules import sequence_mask
41 |
42 | from nemo.core.classes.common import PretrainedModelInfo, typecheck
43 | from nemo.core.neural_types.elements import (
44 | AcousticEncodedRepresentation,
45 | LengthsType,
46 | MelSpectrogramType,
47 | TokenIndex,
48 | RegressionValuesType,
49 | )
50 | from nemo.core.neural_types.neural_type import NeuralType
51 | from nemo.utils import logging
52 |
53 |
54 | @dataclass
55 | class GlowTTSConfig:
56 | encoder: Dict[Any, Any] = MISSING
57 | decoder: Dict[Any, Any] = MISSING
58 | parser: Dict[Any, Any] = MISSING
59 | preprocessor: Dict[Any, Any] = MISSING
60 | train_ds: Optional[Dict[Any, Any]] = None
61 | validation_ds: Optional[Dict[Any, Any]] = None
62 | test_ds: Optional[Dict[Any, Any]] = None
63 |
64 |
65 | class GlowTTSModel(SpectrogramGenerator):
66 | """
67 | GlowTTS model used to generate spectrograms from text.
68 | Consists of a text encoder and an invertible spectrogram decoder.
69 | """
70 |
71 | def __init__(self, cfg: DictConfig, trainer: Trainer = None):
72 | if isinstance(cfg, dict):
73 | cfg = OmegaConf.create(cfg)
74 |
75 | # Setup normalizer
76 | self.normalizer = None
77 | self.text_normalizer_call = None
78 | self.text_normalizer_call_kwargs = {}
79 | self._setup_normalizer(cfg)
80 |
81 | # Setup vocabulary (=tokenizer) and input_fft_kwargs (supported only with self.learn_alignment=True)
82 | input_fft_kwargs = {}
83 |
84 | self.vocab = None
85 |
86 | self.ds_class_name = cfg.train_ds.dataset._target_.split(".")[-1]
87 |
88 | if self.ds_class_name == "TTSDataset":
89 | self._setup_tokenizer(cfg)
90 | assert self.vocab is not None
91 | input_fft_kwargs["n_embed"] = len(self.vocab.tokens)
92 | input_fft_kwargs["padding_idx"] = self.vocab.pad
93 |
94 | self._parser = None
95 |
96 | super().__init__(cfg=cfg, trainer=trainer)
97 |
98 | schema = OmegaConf.structured(GlowTTSConfig)
99 | # ModelPT ensures that cfg is a DictConfig, but do this second check in case ModelPT changes
100 | if isinstance(cfg, dict):
101 | cfg = OmegaConf.create(cfg)
102 | elif not isinstance(cfg, DictConfig):
103 | raise ValueError(f"cfg was type: {type(cfg)}. Expected either a dict or a DictConfig")
104 | # Ensure passed cfg is compliant with schema
105 | OmegaConf.merge(cfg, schema)
106 |
107 | self.preprocessor = instantiate(self._cfg.preprocessor)
108 |
109 | # if self._cfg.parser.get("add_blank"):
110 | self._cfg.encoder.n_vocab = len(self.vocab.tokens)
111 | encoder = instantiate(self._cfg.encoder)
112 | decoder = instantiate(self._cfg.decoder)
113 | pitch_stats = None
114 | unvoiced_value = 0.0
115 | if "pitch_predictor" in self._cfg:
116 | pitch_predictor = instantiate(self._cfg.pitch_predictor)
117 | # inject into encoder
118 | if self._cfg.get("unvoiced_value") is not None:
119 | unvoiced_value = self._cfg.get("unvoiced_value")
120 | elif self._cfg.get("use_normalized_pitch"):
121 | # compute unvoiced value
122 |
123 | pitch_mean, pitch_std = self._cfg.get("pitch_mean", 0.0), self._cfg.get("pitch_std", 1.0)
124 | pitch_mean = pitch_mean or 0.0
125 | pitch_std = pitch_std or 1.0
126 | unvoiced_value = -pitch_mean/pitch_std
127 | pitch_stats = {"pitch_mean": pitch_mean, "pitch_std": pitch_std,
128 | # "pitch_fmin": self._cfg.get("pitch_fmin"),
129 | }
130 | else:
131 | pitch_predictor = None
132 | self.glow_tts = GlowTTSModule(
133 | encoder,
134 | decoder,
135 | pitch_predictor,
136 | n_speakers=cfg.n_speakers,
137 | gin_channels=cfg.gin_channels,
138 | use_external_speaker_emb=cfg.get("use_external_speaker_emb", False),
139 | use_stoch_dur_pred=cfg.get("use_stoch_dur_pred", False),
140 | use_stoch_pitch_pred=cfg.get("use_stoch_pitch_pred", False),
141 | unvoiced_value = unvoiced_value,
142 | use_log_pitch = cfg.get("use_log_pitch", False),
143 | use_normalized_pitch = cfg.get("use_normalized_pitch", False),
144 | use_frame_emb_for_pitch = cfg.get("use_frame_emb_for_pitch", False),
145 | pitch_stats=pitch_stats,
146 | )
147 | self.loss = GlowTTSLoss()
148 | if pitch_predictor is not None:
149 | self.pitch_loss_scale = cfg.get("pitch_loss_scale", 0.1)
150 | self.pitch_loss_fn = PitchLoss(loss_scale=self.pitch_loss_scale)
151 |
152 | else:
153 | self.pitch_loss_scale = 0.0
154 | self.pitch_loss_fn = None
155 |
156 |
157 | def parse(self, str_input: str, normalize=True) -> torch.tensor:
158 | if str_input[-1] not in [".", "!", "?"]:
159 | str_input = str_input + "."
160 |
161 | if self.training:
162 | logging.warning("parse() is meant to be called in eval mode.")
163 |
164 | if normalize and self.text_normalizer_call is not None:
165 | str_input = self.text_normalizer_call(str_input, **self.text_normalizer_call_kwargs)
166 |
167 | eval_phon_mode = contextlib.nullcontext()
168 | if hasattr(self.vocab, "set_phone_prob"):
169 | eval_phon_mode = self.vocab.set_phone_prob(prob=1.0)
170 |
171 | # Disable mixed g2p representation if necessary
172 | with eval_phon_mode:
173 | tokens = self.parser(str_input)
174 |
175 | x = torch.tensor(tokens).unsqueeze_(0).long().to(self.device)
176 | return x
177 |
178 | @property
179 | def parser(self):
180 | if self._parser is not None:
181 | return self._parser
182 |
183 | ds_class_name = self._cfg.train_ds.dataset._target_.split(".")[-1]
184 |
185 | if ds_class_name == "TTSDataset":
186 | self._parser = self.vocab.encode
187 | else:
188 | raise ValueError(f"Unknown dataset class: {ds_class_name}")
189 |
190 | return self._parser
191 |
192 | @typecheck(
193 | input_types={
194 | "x": NeuralType(("B", "T"), TokenIndex()),
195 | "x_lengths": NeuralType(("B"), LengthsType()),
196 | "y": NeuralType(("B", "D", "T"), MelSpectrogramType(), optional=True),
197 | "y_lengths": NeuralType(("B"), LengthsType(), optional=True),
198 | "gen": NeuralType(optional=True),
199 | "noise_scale": NeuralType(optional=True),
200 | "length_scale": NeuralType(optional=True),
201 | "speaker": NeuralType(("B"), TokenIndex(), optional=True),
202 | "speaker_embeddings": NeuralType(
203 | ("B", "D"), AcousticEncodedRepresentation(), optional=True
204 | ),
205 | "stoch_dur_noise_scale": NeuralType(optional=True),
206 | "stoch_pitch_noise_scale": NeuralType(optional=True),
207 | "pitch_scale": NeuralType(optional=True),
208 | "pitch": NeuralType(('B', 'T_audio'), RegressionValuesType()),
209 | }
210 | )
211 |
212 | def forward(
213 | self,
214 | *,
215 | x,
216 | x_lengths,
217 | y=None,
218 | y_lengths=None,
219 | speaker=None,
220 | gen=False,
221 | noise_scale=0.0,
222 | length_scale=1.0,
223 | speaker_embeddings=None,
224 | stoch_dur_noise_scale=1.0,
225 | stoch_pitch_noise_scale=1.0,
226 | pitch_scale=0.0,
227 | pitch=None,
228 | ):
229 | if gen:
230 | return self.glow_tts.generate_spect(
231 | text=x,
232 | text_lengths=x_lengths,
233 | noise_scale=noise_scale,
234 | length_scale=length_scale,
235 | speaker=speaker,
236 | speaker_embeddings=speaker_embeddings,
237 | stoch_dur_noise_scale=stoch_dur_noise_scale,
238 | stoch_pitch_noise_scale=stoch_pitch_noise_scale,
239 | pitch_scale=pitch_scale,
240 | )
241 | else:
242 | return self.glow_tts(
243 | text=x,
244 | text_lengths=x_lengths,
245 | spect=y,
246 | spect_lengths=y_lengths,
247 | speaker=speaker,
248 | speaker_embeddings=speaker_embeddings,
249 | pitch=pitch,
250 | )
251 |
252 | def step(
253 | self,
254 | y,
255 | y_lengths,
256 | x,
257 | x_lengths,
258 | speaker,
259 | speaker_embeddings,
260 | pitch,
261 | ):
262 |
263 | z, y_m, y_logs, logdet, logw, logw_, y_lengths, attn, stoch_dur_loss, pitch, pitch_pred, pitch_loss = self(
264 | x=x,
265 | x_lengths=x_lengths,
266 | y=y,
267 | y_lengths=y_lengths,
268 | speaker=speaker,
269 | speaker_embeddings=speaker_embeddings,
270 | pitch=pitch,
271 | )
272 |
273 | l_mle, l_length, logdet = self.loss(
274 | z=z,
275 | y_m=y_m,
276 | y_logs=y_logs,
277 | logdet=logdet,
278 | logw=logw,
279 | logw_=logw_,
280 | x_lengths=x_lengths,
281 | y_lengths=y_lengths,
282 | stoch_dur_loss=stoch_dur_loss,
283 | )
284 |
285 | if self.pitch_loss_fn is not None:
286 | if pitch_loss is None:
287 | pitch_loss = self.pitch_loss_fn(pitch_predicted=pitch_pred, pitch_tgt=pitch, len=x_lengths)
288 | else:
289 | pitch_loss = self.pitch_loss_scale*pitch_loss
290 |
291 | if pitch_loss is None:
292 | loss = sum([l_mle, l_length])
293 | pitch_loss = torch.tensor([0.0]).to(device=l_mle.device)
294 | else:
295 | loss = sum([l_mle, l_length, pitch_loss])
296 |
297 |
298 | return l_mle, l_length, logdet, loss, attn, pitch_loss
299 |
300 | def training_step(self, batch, batch_idx):
301 |
302 | batch_dict = process_batch(batch, self._train_dl.dataset.sup_data_types_set)
303 | y = batch_dict.get("audio")
304 | y_lengths = batch_dict.get("audio_lens")
305 | x = batch_dict.get("text")
306 | x_lengths = batch_dict.get("text_lens")
307 | attn_prior = batch_dict.get("align_prior_matrix", None)
308 | pitch = batch_dict.get("pitch", None)
309 | energy = batch_dict.get("energy", None)
310 | speaker = batch_dict.get("speaker_id", None)
311 | speaker_embeddings = batch_dict.get("speaker_emb", None)
312 |
313 | y, y_lengths = self.preprocessor(input_signal=y, length=y_lengths)
314 |
315 | l_mle, l_length, logdet, loss, _, pitch_loss = self.step(
316 | y, y_lengths, x, x_lengths, speaker, speaker_embeddings, pitch
317 | )
318 |
319 | output = {
320 | "loss": loss, # required
321 | "progress_bar": {"l_mle": l_mle, "l_length": l_length, "logdet": logdet},
322 | "log": {"loss": loss, "l_mle": l_mle, "l_length": l_length, "logdet": logdet, "pitch_loss": pitch_loss},
323 | }
324 |
325 | return output
326 |
327 | @torch.no_grad()
328 | def compute_likelihood(self, batch_dict):
329 | y = batch_dict.get("audio")
330 | y_lengths = batch_dict.get("audio_lens")
331 | x = batch_dict.get("text")
332 | x_lengths = batch_dict.get("text_lens")
333 | attn_prior = batch_dict.get("align_prior_matrix", None)
334 | pitch = batch_dict.get("pitch", None)
335 | energy = batch_dict.get("energy", None)
336 | speaker = batch_dict.get("speaker_id", None)
337 | speaker_embeddings = batch_dict.get("speaker_emb", None)
338 |
339 | y, y_lengths, x, x_lengths, speaker_embeddings = (
340 | y.to(self.device),
341 | y_lengths.to(self.device),
342 | x.to(self.device),
343 | x_lengths.to(self.device),
344 | speaker_embeddings.to(self.device),
345 | )
346 | y, y_lengths = self.preprocessor(input_signal=y, length=y_lengths)
347 |
348 | z, y_m, y_logs, logdet, logw, logw_, y_lengths, attn, stoch_dur_loss, pitch, pitch_pred, pitch_loss = self(
349 | x=x,
350 | x_lengths=x_lengths,
351 | y=y,
352 | y_lengths=y_lengths,
353 | speaker=speaker,
354 | speaker_embeddings=speaker_embeddings,
355 | pitch=pitch,
356 | )
357 |
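# Average negative log-likelihood of the latents z under the predicted Gaussian
# prior N(y_m, exp(y_logs)^2), normalised by valid frames times channels.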
358 | l_mle_normal = 0.5 * math.log(2 * math.pi) + (
359 | torch.sum(y_logs) + 0.5 * torch.sum(torch.exp(-2 * y_logs) * (z - y_m) ** 2)
360 | ) / (torch.sum(y_lengths) * z.shape[1])
361 |
362 | return l_mle_normal
363 |
364 | def validation_step(self, batch, batch_idx):
365 |
366 | batch_dict = process_batch(batch, self._train_dl.dataset.sup_data_types_set)
367 | y = batch_dict.get("audio")
368 | y_lengths = batch_dict.get("audio_lens")
369 | x = batch_dict.get("text")
370 | x_lengths = batch_dict.get("text_lens")
371 | attn_prior = batch_dict.get("align_prior_matrix", None)
372 | pitch = batch_dict.get("pitch", None)
373 | energy = batch_dict.get("energy", None)
374 | speaker = batch_dict.get("speaker_id", None)
375 | speaker_embeddings = batch_dict.get("speaker_emb", None)
376 |
377 | y, y_lengths = self.preprocessor(input_signal=y, length=y_lengths)
378 |
379 | l_mle, l_length, logdet, loss, attn, pitch_loss = self.step(
380 | y,
381 | y_lengths,
382 | x,
383 | x_lengths,
384 | speaker,
385 | speaker_embeddings,
386 | pitch,
387 | )
388 |
389 | y_gen, attn_gen = self(
390 | x=x,
391 | x_lengths=x_lengths,
392 | gen=True,
393 | speaker=speaker,
394 | speaker_embeddings=speaker_embeddings,
395 | pitch=None, # use predicted pitch
396 | noise_scale=0.667,
397 | )
398 |
399 | return {
400 | "loss": loss,
401 | "l_mle": l_mle,
402 | "l_length": l_length,
403 | "logdet": logdet,
404 | "y": y,
405 | "y_gen": y_gen,
406 | "x": x,
407 | "attn": attn,
408 | "attn_gen": attn_gen,
409 | "progress_bar": {"l_mle": l_mle, "l_length": l_length, "logdet": logdet},
410 | "pitch_loss": pitch_loss,
411 | }
412 |
413 | def validation_epoch_end(self, outputs):
414 | avg_loss = torch.stack([x["loss"] for x in outputs]).mean()
415 | avg_mle = torch.stack([x["l_mle"] for x in outputs]).mean()
416 | avg_length_loss = torch.stack([x["l_length"] for x in outputs]).mean()
417 | avg_logdet = torch.stack([x["logdet"] for x in outputs]).mean()
418 | avg_pitch_loss = torch.stack([x["pitch_loss"] for x in outputs]).mean()
419 | tensorboard_logs = {
420 | "val_loss": avg_loss,
421 | "val_mle": avg_mle,
422 | "val_length_loss": avg_length_loss,
423 | "val_logdet": avg_logdet,
424 | "val_pitch_loss": avg_pitch_loss,
425 | }
426 | if self.logger is not None and self.logger.experiment is not None:
427 | tb_logger = self.logger.experiment
428 | if isinstance(self.logger, LoggerCollection):
429 | for logger in self.logger:
430 | if isinstance(logger, TensorBoardLogger):
431 | tb_logger = logger.experiment
432 | break
433 |
434 | separated_tokens = self.vocab.decode(outputs[0]["x"][0])
435 |
436 | tb_logger.add_text("separated tokens", separated_tokens, self.global_step)
437 | tb_logger.add_image(
438 | "real_spectrogram",
439 | plot_spectrogram_to_numpy(outputs[0]["y"][0].data.cpu().numpy()),
440 | self.global_step,
441 | dataformats="HWC",
442 | )
443 | tb_logger.add_image(
444 | "generated_spectrogram",
445 | plot_spectrogram_to_numpy(outputs[0]["y_gen"][0].data.cpu().numpy()),
446 | self.global_step,
447 | dataformats="HWC",
448 | )
449 | tb_logger.add_image(
450 | "alignment_for_real_sp",
451 | plot_alignment_to_numpy(outputs[0]["attn"][0].data.cpu().numpy()),
452 | self.global_step,
453 | dataformats="HWC",
454 | )
455 | tb_logger.add_image(
456 | "alignment_for_generated_sp",
457 | plot_alignment_to_numpy(outputs[0]["attn_gen"][0].data.cpu().numpy()),
458 | self.global_step,
459 | dataformats="HWC",
460 | )
461 | log_audio_to_tb(tb_logger, outputs[0]["y"][0], "true_audio_gf", self.global_step)
462 | log_audio_to_tb(
463 | tb_logger, outputs[0]["y_gen"][0], "generated_audio_gf", self.global_step
464 | )
465 | self.log("val_loss", avg_loss)
466 | return {"val_loss": avg_loss, "log": tensorboard_logs}
467 |
468 | def _setup_normalizer(self, cfg):
469 | if "text_normalizer" in cfg:
470 | normalizer_kwargs = {}
471 |
472 | if "whitelist" in cfg.text_normalizer:
473 | normalizer_kwargs["whitelist"] = self.register_artifact(
474 | 'text_normalizer.whitelist', cfg.text_normalizer.whitelist
475 | )
476 |
477 | try:
478 | self.normalizer = instantiate(cfg.text_normalizer, **normalizer_kwargs)
479 | except Exception as e:
480 | logging.error(e)
481 | raise ImportError(
482 | "`pynini` not installed, please install via NeMo/nemo_text_processing/pynini_install.sh"
483 | )
484 |
485 | self.text_normalizer_call = self.normalizer.normalize
486 | if "text_normalizer_call_kwargs" in cfg:
487 | self.text_normalizer_call_kwargs = cfg.text_normalizer_call_kwargs
488 |
489 | def _setup_tokenizer(self, cfg):
490 | text_tokenizer_kwargs = {}
491 |
492 | if "phoneme_dict" in cfg.text_tokenizer:
493 | text_tokenizer_kwargs["phoneme_dict"] = self.register_artifact(
494 | "text_tokenizer.phoneme_dict", cfg.text_tokenizer.phoneme_dict,
495 | )
496 | if "heteronyms" in cfg.text_tokenizer:
497 | text_tokenizer_kwargs["heteronyms"] = self.register_artifact(
498 | "text_tokenizer.heteronyms", cfg.text_tokenizer.heteronyms,
499 | )
500 |
501 | if "g2p" in cfg.text_tokenizer:
502 | g2p_kwargs = {}
503 |
504 | if "phoneme_dict" in cfg.text_tokenizer.g2p:
505 | g2p_kwargs["phoneme_dict"] = self.register_artifact(
506 | 'text_tokenizer.g2p.phoneme_dict', cfg.text_tokenizer.g2p.phoneme_dict,
507 | )
508 |
509 | if "heteronyms" in cfg.text_tokenizer.g2p:
510 | g2p_kwargs["heteronyms"] = self.register_artifact(
511 | 'text_tokenizer.g2p.heteronyms', cfg.text_tokenizer.g2p.heteronyms,
512 | )
513 |
514 | text_tokenizer_kwargs["g2p"] = instantiate(cfg.text_tokenizer.g2p, **g2p_kwargs)
515 |
516 | self.vocab = instantiate(cfg.text_tokenizer, **text_tokenizer_kwargs)
517 |
518 | def _setup_dataloader_from_config(self, cfg, shuffle_should_be: bool = True, name: str = "train"):
519 | if "dataset" not in cfg or not isinstance(cfg.dataset, DictConfig):
520 | raise ValueError(f"No dataset for {name}")
521 | if "dataloader_params" not in cfg or not isinstance(cfg.dataloader_params, DictConfig):
522 | raise ValueError(f"No dataloder_params for {name}")
523 | if shuffle_should_be:
524 | if 'shuffle' not in cfg.dataloader_params:
525 | logging.warning(
526 | f"Shuffle should be set to True for {self}'s {name} dataloader but was not found in its "
527 | "config. Manually setting to True"
528 | )
529 | with open_dict(cfg.dataloader_params):
530 | cfg.dataloader_params.shuffle = True
531 | elif not cfg.dataloader_params.shuffle:
532 | logging.error(f"The {name} dataloader for {self} has shuffle set to False!!!")
533 | elif cfg.dataloader_params.shuffle:
534 | logging.error(f"The {name} dataloader for {self} has shuffle set to True!!!")
535 |
536 | if cfg.dataset._target_ == "utils.data.TTSDataset":
537 | phon_mode = contextlib.nullcontext()
538 | if hasattr(self.vocab, "set_phone_prob"):
539 | phon_mode = self.vocab.set_phone_prob(prob=None if name == "val" else self.vocab.phoneme_probability)
540 |
541 | with phon_mode:
542 | print("I got here!!!!!!!", self.vocab)
543 | dataset = instantiate(
544 | cfg.dataset,
545 | text_normalizer=self.normalizer,
546 | text_normalizer_call_kwargs=self.text_normalizer_call_kwargs,
547 | text_tokenizer=self.vocab,
548 | )
549 | else:
550 | dataset = instantiate(cfg.dataset)
551 |
552 | return torch.utils.data.DataLoader(dataset, collate_fn=dataset.collate_fn, **cfg.dataloader_params)
553 |
554 | def setup_training_data(self, train_data_config: Optional[DictConfig]):
555 | self._train_dl = self._setup_dataloader_from_config(cfg=train_data_config, name="train")
556 |
557 | def setup_validation_data(self, val_data_config: Optional[DictConfig]):
558 | self._validation_dl = self._setup_dataloader_from_config(cfg=val_data_config, shuffle_should_be=False, name="val")
559 |
560 | def setup_test_data(self, test_data_config: Optional[DictConfig]):
561 | self._test_dl = self._setup_dataloader_from_config(cfg=test_data_config, shuffle_should_be=False, name="test")
562 |
563 |
564 | def generate_spectrogram(
565 | self,
566 | tokens: "torch.tensor",
567 | noise_scale: float = 0.0,
568 | length_scale: float = 1.0,
569 | speaker: int = None,
570 | speaker_embeddings: "torch.tensor" = None,
571 | stoch_dur_noise_scale: float = 1.0,
572 | stoch_pitch_noise_scale: float = 1.0,
573 | pitch_scale: float = 0.0,
574 | ) -> torch.tensor:
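"""Generate a mel-spectrogram from a batch of token ids (inference path, gen=True).
The noise_scale/length_scale and the stoch_dur_noise_scale, stoch_pitch_noise_scale
and pitch_scale arguments are forwarded unchanged to the GlowTTS-with-pitch forward pass.
"""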
575 |
576 | self.eval()
577 |
578 | token_len = torch.tensor([tokens.shape[1]]).to(self.device)
579 |
580 | if isinstance(speaker, int):
581 | speaker = torch.tensor([speaker]).to(self.device)
582 | else:
583 | speaker = None
584 |
585 | if speaker_embeddings is not None:
586 | speaker_embeddings = speaker_embeddings.to(self.device)
587 |
588 | spect, _ = self(
589 | x=tokens,
590 | x_lengths=token_len,
591 | speaker=speaker,
592 | gen=True,
593 | noise_scale=noise_scale,
594 | length_scale=length_scale,
595 | speaker_embeddings=speaker_embeddings,
596 | stoch_dur_noise_scale=stoch_dur_noise_scale,
597 | stoch_pitch_noise_scale=stoch_pitch_noise_scale,
598 | pitch_scale=pitch_scale,
599 | )
600 |
601 | return spect
602 |
603 | @torch.no_grad()
604 | def generate_spectrogram_with_mas(
605 | self,
606 | batch,
607 | noise_scale: float = 0.667,
608 | length_scale: float = 1.0,
609 | external_speaker: int = None,
610 | external_speaker_embeddings: "torch.tensor" = None,
611 | gen: bool = False,
612 | randomize_speaker=False,
613 | ) -> torch.tensor:
614 | """Forced aligned generation of synthetic mel-spectrogram of mel-spectrograms."""
615 | # y is audio or melspec, x is text token ids
616 | self.eval()
617 |
618 | batch_dict = process_batch(batch, self._train_dl.dataset.sup_data_types_set)
619 | y = batch_dict.get("audio")
620 | y_lengths = batch_dict.get("audio_lens")
621 | x = batch_dict.get("text")
622 | x_lengths = batch_dict.get("text_lens")
623 | attn_prior = batch_dict.get("align_prior_matrix", None)
624 | pitch = batch_dict.get("pitch", None)
625 | energy = batch_dict.get("energy", None)
626 | speaker = batch_dict.get("speaker_id", None)
627 | speaker_embeddings = batch_dict.get("speaker_emb", None)
628 |
629 |
630 | y, y_lengths = self.preprocessor(input_signal=y, length=y_lengths)
631 |
632 | if len(speaker_embeddings.shape) == 1:
633 | speaker_embeddings = speaker_embeddings.unsqueeze(0)
634 | speaker = None
635 | if speaker_embeddings is not None:
636 | speaker_embeddings = torch.nn.functional.normalize(speaker_embeddings).unsqueeze(-1).to(self.device)
637 |
638 | y, y_lengths, x, x_lengths, speaker_embeddings = (
639 | y.to(self.device),
640 | y_lengths.to(self.device),
641 | x.to(self.device),
642 | x_lengths.to(self.device),
643 | speaker_embeddings,
644 | )
645 |         (z, y_m, y_logs, logdet, log_durs_predicted, log_durs_extracted,
646 |          spect_lengths, attn, stoch_dur_loss, pitch, pitch_pred, pitch_pred_loss,
647 |         ) = self(
648 | x=x,
649 | x_lengths=x_lengths,
650 | y=y,
651 | y_lengths=y_lengths,
652 | speaker=speaker,
653 | gen=gen,
654 | speaker_embeddings=speaker_embeddings.squeeze(-1),
655 | pitch=pitch,
656 | )
657 |
658 | y_max_length = z.size(2)
659 |
660 | y_mask = torch.unsqueeze(sequence_mask(spect_lengths, y_max_length), 1)
661 |
662 | # predicted aligned feature
663 | z = (y_m + torch.exp(y_logs) * torch.randn_like(y_m) * noise_scale) * y_mask
664 |
665 | if external_speaker_embeddings is not None:
666 | external_speaker_embeddings = torch.nn.functional.normalize(
667 | external_speaker_embeddings
668 | ).unsqueeze(-1)
669 |
670 | if randomize_speaker:
671 | # use a random speaker embedding to decode utterance
672 | idx = torch.randperm(speaker_embeddings.shape[0])
673 | speaker_embeddings = speaker_embeddings[idx].view(speaker_embeddings.size())
674 | if len(speaker_embeddings.shape) == 2:
675 | speaker_embeddings = speaker_embeddings.unsqueeze(-1).to(self.device)
676 |
677 | # invert with same or different speaker through the decoder
678 | y_pred, _ = self.glow_tts.decoder(
679 | spect=z.to(self.device),
680 | spect_mask=y_mask.to(self.device),
681 |             speaker_embeddings=external_speaker_embeddings if external_speaker_embeddings is not None else speaker_embeddings,
682 | reverse=True, pitch=pitch,
683 | )
684 | return y_pred, spect_lengths
685 |
686 |
687 | @classmethod
688 | def list_available_models(cls) -> "List[PretrainedModelInfo]":
689 | """
690 |         This method returns a list of pre-trained models that can be instantiated directly from NVIDIA's NGC cloud.
691 | Returns:
692 | List of available pre-trained models.
693 | """
694 | list_of_models = []
695 | model = PretrainedModelInfo(
696 | pretrained_model_name="tts_en_glowtts",
697 | location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/tts_en_glowtts/versions/1.0.0rc1/files/tts_en_glowtts.nemo",
698 | description="This model is trained on LJSpeech sampled at 22050Hz, and can be used to generate female English voices with an American accent.",
699 | class_=cls,
700 | aliases=["GlowTTS-22050Hz"],
701 | )
702 | list_of_models.append(model)
703 | return list_of_models
--------------------------------------------------------------------------------
/modules/glow_tts_modules/__init__.py:
--------------------------------------------------------------------------------
1 |
2 | from modules.glow_tts_modules.glow_tts_submodules_with_pitch import StochasticPitchPredictor
--------------------------------------------------------------------------------
/modules/glow_tts_modules/glow_tts_submodules.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # MIT License
16 | #
17 | # Copyright (c) 2020 Jaehyeon Kim
18 | #
19 | # Permission is hereby granted, free of charge, to any person obtaining a copy
20 | # of this software and associated documentation files (the "Software"), to deal
21 | # in the Software without restriction, including without limitation the rights
22 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
23 | # copies of the Software, and to permit persons to whom the Software is
24 | # furnished to do so, subject to the following conditions:
25 | #
26 | # The above copyright notice and this permission notice shall be included in all
27 | # copies or substantial portions of the Software.
28 | #
29 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
30 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
31 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
32 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
33 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
34 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
35 | # SOFTWARE.
36 |
37 | import math
38 |
39 | import numpy as np
40 | import torch
41 | from torch import nn
42 | from torch.nn import functional as F
43 |
44 | from modules.glow_tts_modules import stocpred_modules
45 |
46 |
47 | def convert_pad_shape(pad_shape):
48 | """
49 | Used to get arguments for F.pad
50 | """
51 | l = pad_shape[::-1]
52 | pad_shape = [item for sublist in l for item in sublist]
53 | return pad_shape
54 |
55 |
56 | def sequence_mask(length, max_length=None):
57 | """
58 | Get masks for given lengths
59 | """
60 | if max_length is None:
61 | max_length = length.max()
62 | x = torch.arange(max_length, dtype=length.dtype, device=length.device)
63 |
64 | return x.unsqueeze(0) < length.unsqueeze(1)
65 |
66 |
67 | def maximum_path(value, mask, max_neg_val=None):
68 | """
69 | Monotonic alignment search algorithm
70 |     Numpy-friendly version. It's about 4 times faster than the torch version.
71 | value: [b, t_x, t_y]
72 | mask: [b, t_x, t_y]
73 | """
74 | if max_neg_val is None:
75 | max_neg_val = -np.inf # Patch for Sphinx complaint
76 | value = value * mask
77 |
78 | device = value.device
79 | dtype = value.dtype
80 | value = value.cpu().detach().numpy()
81 |     mask = mask.cpu().detach().numpy().astype(bool)  # np.bool was removed in recent NumPy versions
82 |
83 | b, t_x, t_y = value.shape
84 | direction = np.zeros(value.shape, dtype=np.int64)
85 | v = np.zeros((b, t_x), dtype=np.float32)
86 | x_range = np.arange(t_x, dtype=np.float32).reshape(1, -1)
87 | for j in range(t_y):
88 | v0 = np.pad(v, [[0, 0], [1, 0]], mode="constant", constant_values=max_neg_val)[:, :-1]
89 | v1 = v
90 | max_mask = v1 >= v0
91 | v_max = np.where(max_mask, v1, v0)
92 | direction[:, :, j] = max_mask
93 |
94 | index_mask = x_range <= j
95 | v = np.where(index_mask, v_max + value[:, :, j], max_neg_val)
96 | direction = np.where(mask, direction, 1)
97 |
98 | path = np.zeros(value.shape, dtype=np.float32)
99 | index = mask[:, :, 0].sum(1).astype(np.int64) - 1
100 | index_range = np.arange(b)
101 | try:
102 | for j in reversed(range(t_y)):
103 | path[index_range, index, j] = 1
104 | index = index + direction[index_range, index, j] - 1
105 | except Exception as e:
106 | print("index range", index_range)
107 | print(e)
108 | path = path * mask.astype(np.float32)
109 | path = torch.from_numpy(path).to(device=device, dtype=dtype)
110 | return path
111 |
112 |
113 | def generate_path(duration, mask):
114 | """
115 | Generate alignment based on predicted durations
116 | duration: [b, t_x]
117 | mask: [b, t_x, t_y]
118 | """
119 |
120 | b, t_x, t_y = mask.shape
121 | cum_duration = torch.cumsum(duration, 1)
122 |
123 | cum_duration_flat = cum_duration.view(b * t_x)
124 | path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
125 | path = path.view(b, t_x, t_y)
126 | path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
127 | path = path * mask
128 | return path
129 |
130 |
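# WaveNet-style gated activation: tanh and sigmoid halves of (input_a + input_b), multiplied element-wise.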
131 | @torch.jit.script
132 | def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
133 | n_channels_int = n_channels[0]
134 | in_act = input_a + input_b
135 | t_act = torch.tanh(in_act[:, :n_channels_int, :])
136 | s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
137 | acts = t_act * s_act
138 | return acts
139 |
140 |
141 | class LayerNorm(nn.Module):
142 | def __init__(self, channels, eps=1e-4):
143 | super().__init__()
144 | self.channels = channels
145 | self.eps = eps
146 |
147 | self.gamma = nn.Parameter(torch.ones(channels))
148 | self.beta = nn.Parameter(torch.zeros(channels))
149 |
150 | def forward(self, x):
151 | n_dims = len(x.shape)
152 | mean = torch.mean(x, 1, keepdim=True)
153 | variance = torch.mean((x - mean) ** 2, 1, keepdim=True)
154 |
155 | x = (x - mean) * torch.rsqrt(variance + self.eps)
156 |
157 | shape = [1, -1] + [1] * (n_dims - 2)
158 | x = x * self.gamma.view(*shape) + self.beta.view(*shape)
159 | return x
160 |
161 |
162 | class ConvReluNorm(nn.Module):
163 | def __init__(
164 | self,
165 | in_channels,
166 | hidden_channels,
167 | out_channels,
168 | kernel_size,
169 | n_layers,
170 | p_dropout,
171 | ):
172 | super().__init__()
173 | self.in_channels = in_channels
174 | self.hidden_channels = hidden_channels
175 | self.out_channels = out_channels
176 | self.kernel_size = kernel_size
177 | self.n_layers = n_layers
178 | self.p_dropout = p_dropout
179 |         assert n_layers > 1, "Number of layers should be larger than 1."
180 |
181 | self.conv_layers = nn.ModuleList()
182 | self.norm_layers = nn.ModuleList()
183 | self.conv_layers.append(
184 | nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2)
185 | )
186 | self.norm_layers.append(LayerNorm(hidden_channels))
187 | self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout))
188 | for _ in range(n_layers - 1):
189 | self.conv_layers.append(
190 | nn.Conv1d(
191 | hidden_channels,
192 | hidden_channels,
193 | kernel_size,
194 | padding=kernel_size // 2,
195 | )
196 | )
197 | self.norm_layers.append(LayerNorm(hidden_channels))
198 | self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
199 | self.proj.weight.data.zero_()
200 | self.proj.bias.data.zero_()
201 |
202 | def forward(self, x, x_mask):
203 | x_org = x
204 | for i in range(self.n_layers):
205 | x = self.conv_layers[i](x * x_mask)
206 | x = self.norm_layers[i](x)
207 | x = self.relu_drop(x)
208 | x = x_org + self.proj(x)
209 | return x * x_mask
210 |
211 |
212 | class WN(torch.nn.Module):
213 | def __init__(
214 | self, hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=0, gin_channels=0
215 | ):
216 | super(WN, self).__init__()
217 | assert kernel_size % 2 == 1
218 | assert hidden_channels % 2 == 0
219 |
220 | self.hidden_channels = hidden_channels
221 | self.n_layers = n_layers
222 |
223 | self.in_layers = torch.nn.ModuleList()
224 | self.res_skip_layers = torch.nn.ModuleList()
225 | self.drop = nn.Dropout(p_dropout)
226 |
227 | if gin_channels != 0:
228 | cond_layer = torch.nn.Conv1d(gin_channels, 2 * hidden_channels * n_layers, 1)
229 | self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight")
230 |
231 | for i in range(n_layers):
232 | dilation = dilation_rate ** i
233 | padding = int((kernel_size * dilation - dilation) / 2)
234 | in_layer = torch.nn.Conv1d(
235 | hidden_channels,
236 | 2 * hidden_channels,
237 | kernel_size,
238 | dilation=dilation,
239 | padding=padding,
240 | )
241 | in_layer = torch.nn.utils.weight_norm(in_layer, name="weight")
242 | self.in_layers.append(in_layer)
243 |
244 | # last one is not necessary
245 | if i < n_layers - 1:
246 | res_skip_channels = 2 * hidden_channels
247 | else:
248 | res_skip_channels = hidden_channels
249 |
250 | res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
251 | res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight")
252 | self.res_skip_layers.append(res_skip_layer)
253 |
254 | def forward(self, x, x_mask=None, g=None, **kwargs):
255 | output = torch.zeros_like(x)
256 | n_channels_tensor = torch.IntTensor([self.hidden_channels])
257 |
258 | if g is not None:
259 | g = self.cond_layer(g)
260 |
261 | for i in range(self.n_layers):
262 | x_in = self.in_layers[i](x)
263 | x_in = self.drop(x_in)
264 | if g is not None:
265 | cond_offset = i * 2 * self.hidden_channels
266 | g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
267 | else:
268 | g_l = torch.zeros_like(x_in)
269 |
270 | acts = fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
271 |
272 | res_skip_acts = self.res_skip_layers[i](acts)
273 | if i < self.n_layers - 1:
274 | x = (x + res_skip_acts[:, : self.hidden_channels, :]) * x_mask
275 | output = output + res_skip_acts[:, self.hidden_channels :, :]
276 | else:
277 | output = output + res_skip_acts
278 | return output * x_mask
279 |
280 | def remove_weight_norm(self):
281 | for l in self.in_layers:
282 | torch.nn.utils.remove_weight_norm(l)
283 | for l in self.res_skip_layers:
284 | torch.nn.utils.remove_weight_norm(l)
285 |
286 |
287 | class ActNorm(nn.Module):
288 | def __init__(self, channels, ddi=False, **kwargs):
289 | super().__init__()
290 | self.channels = channels
291 | self.initialized = not ddi
292 |
293 | self.logs = nn.Parameter(torch.zeros(1, channels, 1))
294 | self.bias = nn.Parameter(torch.zeros(1, channels, 1))
295 |
296 | def forward(self, x, x_mask=None, reverse=False, **kwargs):
297 | if x_mask is None:
298 | x_mask = torch.ones(x.size(0), 1, x.size(2)).to(device=x.device, dtype=x.dtype)
299 | x_len = torch.sum(x_mask, [1, 2])
300 | if not self.initialized:
301 | self.initialize(x, x_mask)
302 | self.initialized = True
303 |
304 | if reverse:
305 | z = (x - self.bias) * torch.exp(-self.logs) * x_mask
306 | logdet = None
307 | else:
308 | z = (self.bias + torch.exp(self.logs) * x) * x_mask
309 | logdet = torch.sum(self.logs) * x_len # [b]
310 |
311 | return z, logdet
312 |
313 | def store_inverse(self):
314 | pass
315 |
316 | def set_ddi(self, ddi):
317 | self.initialized = not ddi
318 |
319 | def initialize(self, x, x_mask):
320 | with torch.no_grad():
321 | denom = torch.sum(x_mask, [0, 2])
322 | m = torch.sum(x * x_mask, [0, 2]) / denom
323 | m_sq = torch.sum(x * x * x_mask, [0, 2]) / denom
324 | v = m_sq - (m ** 2)
325 | logs = 0.5 * torch.log(torch.clamp_min(v, 1e-6))
326 |
327 | bias_init = (-m * torch.exp(-logs)).view(*self.bias.shape).to(dtype=self.bias.dtype)
328 | logs_init = (-logs).view(*self.logs.shape).to(dtype=self.logs.dtype)
329 |
330 | self.bias.data.copy_(bias_init)
331 | self.logs.data.copy_(logs_init)
332 |
333 |
334 | class InvConvNear(nn.Module):
335 | def __init__(self, channels, n_split=4, no_jacobian=False, **kwargs):
336 | super().__init__()
337 | assert n_split % 2 == 0
338 | self.channels = channels
339 | self.n_split = n_split
340 | self.no_jacobian = no_jacobian
341 |
342 | w_init = torch.linalg.qr(torch.FloatTensor(self.n_split, self.n_split).normal_())[0]
343 | if torch.det(w_init) < 0:
344 | w_init[:, 0] = -1 * w_init[:, 0]
345 | self.weight = nn.Parameter(w_init)
346 |
347 | def forward(self, x, x_mask=None, reverse=False, **kwargs):
348 | b, c, t = x.size()
349 | assert c % self.n_split == 0
350 | if x_mask is None:
351 | x_mask = 1
352 | x_len = torch.ones((b,), dtype=x.dtype, device=x.device) * t
353 | else:
354 | x_len = torch.sum(x_mask, [1, 2])
355 |
356 | x = x.view(b, 2, c // self.n_split, self.n_split // 2, t)
357 | x = x.permute(0, 1, 3, 2, 4).contiguous().view(b, self.n_split, c // self.n_split, t)
358 |
359 | if reverse:
360 | if hasattr(self, "weight_inv"):
361 | weight = self.weight_inv
362 | else:
363 | weight = torch.inverse(self.weight.float()).to(dtype=self.weight.dtype)
364 | logdet = None
365 | else:
366 | weight = self.weight
367 | if self.no_jacobian:
368 | logdet = 0
369 | else:
370 | logdet = torch.logdet(self.weight) * (c / self.n_split) * x_len # [b]
371 |
372 | weight = weight.view(self.n_split, self.n_split, 1, 1)
373 | z = F.conv2d(x, weight)
374 |
375 | z = z.view(b, 2, self.n_split // 2, c // self.n_split, t)
376 | z = z.permute(0, 1, 3, 2, 4).contiguous().view(b, c, t) * x_mask
377 | return z, logdet
378 |
379 | def store_inverse(self):
380 | self.weight_inv = torch.inverse(self.weight.float()).to(dtype=self.weight.dtype)
381 |
382 |
383 | class CouplingBlock(nn.Module):
384 | def __init__(
385 | self,
386 | in_channels,
387 | hidden_channels,
388 | kernel_size,
389 | dilation_rate,
390 | n_layers,
391 | gin_channels=0,
392 | p_dropout=0,
393 | sigmoid_scale=False,
394 | ):
395 | super().__init__()
396 | self.in_channels = in_channels
397 | self.hidden_channels = hidden_channels
398 | self.kernel_size = kernel_size
399 | self.dilation_rate = dilation_rate
400 | self.n_layers = n_layers
401 | self.gin_channels = gin_channels
402 | self.p_dropout = p_dropout
403 | self.sigmoid_scale = sigmoid_scale
404 |
405 | start = torch.nn.Conv1d(in_channels // 2, hidden_channels, 1)
406 | start = torch.nn.utils.weight_norm(start)
407 | self.start = start
408 | # Initializing last layer to 0 makes the affine coupling layers
409 | # do nothing at first. This helps with training stability
410 | end = torch.nn.Conv1d(hidden_channels, in_channels, 1)
411 | end.weight.data.zero_()
412 | end.bias.data.zero_()
413 | self.end = end
414 |
415 | self.wn = WN(
416 | hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout, gin_channels
417 | )
418 |
419 | def forward(self, x, x_mask=None, reverse=False, g=None, **kwargs):
420 | if x_mask is None:
421 | x_mask = 1
422 | x_0, x_1 = x[:, : self.in_channels // 2], x[:, self.in_channels // 2 :]
423 |
424 | x = self.start(x_0) * x_mask
425 | x = self.wn(x, x_mask, g)
426 | out = self.end(x)
427 |
428 | z_0 = x_0
429 | m = out[:, : self.in_channels // 2, :]
430 | logs = out[:, self.in_channels // 2 :, :]
431 | if self.sigmoid_scale:
432 | logs = torch.log(1e-6 + torch.sigmoid(logs + 2))
433 |
434 | if reverse:
435 | z_1 = (x_1 - m) * torch.exp(-logs) * x_mask
436 | logdet = None
437 | else:
438 | z_1 = (m + torch.exp(logs) * x_1) * x_mask
439 | logdet = torch.sum(logs * x_mask, [1, 2])
440 |
441 | z = torch.cat([z_0, z_1], 1)
442 | return z, logdet
443 |
444 | def store_inverse(self):
445 | self.wn.remove_weight_norm()
446 |
447 |
448 | class AttentionBlock(nn.Module):
449 | def __init__(
450 | self,
451 | channels,
452 | out_channels,
453 | n_heads,
454 | window_size=None,
455 | heads_share=True,
456 | p_dropout=0.0,
457 | block_length=None,
458 | proximal_bias=False,
459 | proximal_init=False,
460 | ):
461 | super().__init__()
462 | assert channels % n_heads == 0
463 |
464 | self.channels = channels
465 | self.out_channels = out_channels
466 | self.n_heads = n_heads
467 | self.window_size = window_size
468 | self.heads_share = heads_share
469 | self.block_length = block_length
470 | self.proximal_bias = proximal_bias
471 | self.p_dropout = p_dropout
472 | self.attn = None
473 |
474 | self.k_channels = channels // n_heads
475 | self.conv_q = nn.Conv1d(channels, channels, 1)
476 | self.conv_k = nn.Conv1d(channels, channels, 1)
477 | self.conv_v = nn.Conv1d(channels, channels, 1)
478 | if window_size is not None:
479 | n_heads_rel = 1 if heads_share else n_heads
480 | rel_stddev = self.k_channels ** -0.5
481 | self.emb_rel_k = nn.Parameter(
482 | torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev
483 | )
484 | self.emb_rel_v = nn.Parameter(
485 | torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev
486 | )
487 | self.conv_o = nn.Conv1d(channels, out_channels, 1)
488 | self.drop = nn.Dropout(p_dropout)
489 |
490 | nn.init.xavier_uniform_(self.conv_q.weight)
491 | nn.init.xavier_uniform_(self.conv_k.weight)
492 | if proximal_init:
493 | self.conv_k.weight.data.copy_(self.conv_q.weight.data)
494 | self.conv_k.bias.data.copy_(self.conv_q.bias.data)
495 | nn.init.xavier_uniform_(self.conv_v.weight)
496 |
497 | def forward(self, x, c, attn_mask=None):
498 | q = self.conv_q(x)
499 | k = self.conv_k(c)
500 | v = self.conv_v(c)
501 |
502 | x, self.attn = self.attention(q, k, v, mask=attn_mask)
503 |
504 | x = self.conv_o(x)
505 | return x
506 |
507 | def attention(self, query, key, value, mask=None):
508 | # reshape [b, d, t] -> [b, n_h, t, d_k]
509 | b, d, t_s, t_t = key.size(0), key.size(1), key.size(2), query.size(2)
510 | query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
511 | key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
512 | value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
513 |
514 | scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.k_channels)
515 | if self.window_size is not None:
516 | assert t_s == t_t, "Relative attention is only available for self-attention."
517 | key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
518 | rel_logits = self._matmul_with_relative_keys(query, key_relative_embeddings)
519 | rel_logits = self._relative_position_to_absolute_position(rel_logits)
520 | scores_local = rel_logits / math.sqrt(self.k_channels)
521 | scores = scores + scores_local
522 | if self.proximal_bias:
523 | assert t_s == t_t, "Proximal bias is only available for self-attention."
524 | scores = scores + self._attention_bias_proximal(t_s).to(
525 | device=scores.device, dtype=scores.dtype
526 | )
527 | if mask is not None:
528 | scores = scores.masked_fill(mask == 0, -1e4)
529 | if self.block_length is not None:
530 | block_mask = (
531 | torch.ones_like(scores).triu(-self.block_length).tril(self.block_length)
532 | )
533 | scores = scores * block_mask + -1e4 * (1 - block_mask)
534 | p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s]
535 | p_attn = self.drop(p_attn)
536 | output = torch.matmul(p_attn, value)
537 | if self.window_size is not None:
538 | relative_weights = self._absolute_position_to_relative_position(p_attn)
539 | value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s)
540 | output = output + self._matmul_with_relative_values(
541 | relative_weights, value_relative_embeddings
542 | )
543 | output = (
544 | output.transpose(2, 3).contiguous().view(b, d, t_t)
545 | ) # [b, n_h, t_t, d_k] -> [b, d, t_t]
546 | return output, p_attn
547 |
548 | def _matmul_with_relative_values(self, x, y):
549 | """
550 | x: [b, h, l, m]
551 | y: [h or 1, m, d]
552 | ret: [b, h, l, d]
553 | """
554 | ret = torch.matmul(x, y.unsqueeze(0))
555 | return ret
556 |
557 | def _matmul_with_relative_keys(self, x, y):
558 | """
559 | x: [b, h, l, d]
560 | y: [h or 1, m, d]
561 | ret: [b, h, l, m]
562 | """
563 | ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
564 | return ret
565 |
566 | def _get_relative_embeddings(self, relative_embeddings, length):
567 | # Pad first before slice to avoid using cond ops.
568 | pad_length = max(length - (self.window_size + 1), 0)
569 | slice_start_position = max((self.window_size + 1) - length, 0)
570 | slice_end_position = slice_start_position + 2 * length - 1
571 | if pad_length > 0:
572 | padded_relative_embeddings = F.pad(
573 | relative_embeddings,
574 | convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
575 | )
576 | else:
577 | padded_relative_embeddings = relative_embeddings
578 | used_relative_embeddings = padded_relative_embeddings[
579 | :, slice_start_position:slice_end_position
580 | ]
581 | return used_relative_embeddings
582 |
583 | def _relative_position_to_absolute_position(self, x):
584 | """
585 | x: [b, h, l, 2*l-1]
586 | ret: [b, h, l, l]
587 | """
588 | batch, heads, length, _ = x.size()
589 | # Concat columns of pad to shift from relative to absolute indexing.
590 | x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
591 |
592 | # Concat extra elements so to add up to shape (len+1, 2*len-1).
593 | x_flat = x.view([batch, heads, length * 2 * length])
594 | x_flat = F.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [0, length - 1]]))
595 |
596 | # Reshape and slice out the padded elements.
597 | x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
598 | :, :, :length, length - 1 :
599 | ]
600 | return x_final
601 |
602 | def _absolute_position_to_relative_position(self, x):
603 | """
604 | x: [b, h, l, l]
605 | ret: [b, h, l, 2*l-1]
606 | """
607 | batch, heads, length, _ = x.size()
608 |         # pad along the column dimension
609 | x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]]))
610 | x_flat = x.view([batch, heads, length ** 2 + length * (length - 1)])
611 | # add 0's in the beginning that will skew the elements after reshape
612 | x_flat = F.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
613 | x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
614 | return x_final
615 |
616 | def _attention_bias_proximal(self, length):
617 | """Bias for self-attention to encourage attention to close positions.
618 | Args:
619 | length: an integer scalar.
620 | Returns:
621 | a Tensor with shape [1, 1, length, length]
622 | """
623 | r = torch.arange(length, dtype=torch.float32)
624 | diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
625 | return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
626 |
627 |
628 | class FeedForwardNetwork(nn.Module):
629 | def __init__(
630 | self,
631 | in_channels,
632 | out_channels,
633 | filter_channels,
634 | kernel_size,
635 | p_dropout=0.0,
636 | activation=None,
637 | ):
638 | super().__init__()
639 | self.in_channels = in_channels
640 | self.out_channels = out_channels
641 | self.filter_channels = filter_channels
642 | self.kernel_size = kernel_size
643 | self.p_dropout = p_dropout
644 | self.activation = activation
645 |
646 | self.conv_1 = nn.Conv1d(
647 | in_channels, filter_channels, kernel_size, padding=kernel_size // 2
648 | )
649 | self.conv_2 = nn.Conv1d(
650 | filter_channels, out_channels, kernel_size, padding=kernel_size // 2
651 | )
652 | self.drop = nn.Dropout(p_dropout)
653 |
654 | def forward(self, x, x_mask):
655 | x = self.conv_1(x * x_mask)
656 | if self.activation == "gelu":
657 | x = x * torch.sigmoid(1.702 * x)
658 | else:
659 | x = torch.relu(x)
660 | x = self.drop(x)
661 | x = self.conv_2(x * x_mask)
662 | return x * x_mask
663 |
664 |
665 | class DurationPredictor(nn.Module):
666 | def __init__(self, in_channels, filter_channels, kernel_size, p_dropout):
667 | """
668 | Token duration predictor for the GlowTTS model.
669 | Takes in embeddings of the input tokens and predicts how many frames of
670 | mel-spectrogram are aligned to each text token.
671 | Architecture is the same as the duration predictor in FastSpeech.
672 | Args:
673 | in_channels: Number of channels for the token embeddings
674 | filter_channels: Number of channels in the intermediate layers
675 | kernel_size: Kernels size for the convolution layers
676 | p_dropout: Dropout probability
677 | """
678 |
679 | super().__init__()
680 |
681 | self.drop = nn.Dropout(p_dropout)
682 | self.conv_1 = nn.Conv1d(
683 | in_channels, filter_channels, kernel_size, padding=kernel_size // 2
684 | )
685 | self.norm_1 = LayerNorm(filter_channels)
686 | self.conv_2 = nn.Conv1d(
687 | filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
688 | )
689 | self.norm_2 = LayerNorm(filter_channels)
690 | self.proj = nn.Conv1d(filter_channels, 1, 1)
691 |
692 | def forward(self, spect, mask, **kwargs):
693 | x = self.conv_1(spect * mask)
694 | x = torch.relu(x)
695 | x = self.norm_1(x)
696 | x = self.drop(x)
697 | x = self.conv_2(x * mask)
698 | x = torch.relu(x)
699 | x = self.norm_2(x)
700 | x = self.drop(x)
701 | x = self.proj(x * mask)
702 | durs = x * mask
703 | return durs.squeeze(1)
704 |
705 |
706 | class StochasticDurationPredictor(nn.Module):
707 | """Borrowed from VITS"""
708 | def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, n_flows=4, gin_channels=0):
709 | super().__init__()
710 |
711 |         filter_channels = in_channels  # should be removed in a future version.
712 | self.in_channels = in_channels
713 | self.filter_channels = filter_channels
714 | self.kernel_size = kernel_size
715 | self.p_dropout = p_dropout
716 | self.n_flows = n_flows
717 | self.gin_channels = gin_channels
718 |
719 | self.log_flow = stocpred_modules.Log()
720 | self.flows = nn.ModuleList()
721 | self.flows.append(stocpred_modules.ElementwiseAffine(2))
722 | for i in range(n_flows):
723 | self.flows.append(stocpred_modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
724 | self.flows.append(stocpred_modules.Flip())
725 |
726 | self.post_pre = nn.Conv1d(1, filter_channels, 1)
727 | self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
728 | self.post_convs = stocpred_modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
729 | self.post_flows = nn.ModuleList()
730 | self.post_flows.append(stocpred_modules.ElementwiseAffine(2))
731 | for i in range(4):
732 | self.post_flows.append(stocpred_modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
733 | self.post_flows.append(stocpred_modules.Flip())
734 | self.pre = nn.Conv1d(in_channels, filter_channels, 1)
735 | self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
736 | self.convs = stocpred_modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
737 | if gin_channels != 0:
738 | self.cond = nn.Conv1d(gin_channels, filter_channels, 1)
739 |
740 | def forward(self, spect, x_mask, w=None, g=None, reverse=False, noise_scale=1.0):
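# reverse=False (training): estimate the NLL of the ground-truth durations w through
# the posterior and flow stacks and return it as the stochastic duration loss.
# reverse=True (inference): sample noise, run the flows in reverse and return
# the predicted log-durations logw.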
741 | x = torch.detach(spect)
742 | x = self.pre(x)
743 | if g is not None:
744 | g = torch.detach(g)
745 | x = x + self.cond(g)
746 | x = self.convs(x, x_mask)
747 | x = self.proj(x) * x_mask
748 | if not reverse:
749 | flows = self.flows
750 | assert w is not None
751 |
752 | logdet_tot_q = 0
753 | h_w = self.post_pre(w)
754 | h_w = self.post_convs(h_w, x_mask)
755 | h_w = self.post_proj(h_w) * x_mask
756 | e_q = torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype) * x_mask
757 | z_q = e_q
758 | for flow in self.post_flows:
759 | z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
760 | logdet_tot_q += logdet_q
761 | z_u, z1 = torch.split(z_q, [1, 1], 1)
762 | u = torch.sigmoid(z_u) * x_mask
763 | z0 = (w - u) * x_mask
764 | logdet_tot_q += torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1,2])
765 | logq = torch.sum(-0.5 * (math.log(2*math.pi) + (e_q**2)) * x_mask, [1,2]) - logdet_tot_q
766 |
767 | logdet_tot = 0
768 | z0, logdet = self.log_flow(z0, x_mask)
769 | logdet_tot += logdet
770 | z = torch.cat([z0, z1], 1)
771 | for flow in flows:
772 | z, logdet = flow(z, x_mask, g=x, reverse=reverse)
773 | logdet_tot = logdet_tot + logdet
774 | nll = torch.sum(0.5 * (math.log(2*math.pi) + (z**2)) * x_mask, [1,2]) - logdet_tot
775 | stoch_dur_loss = nll + logq
776 | stoch_dur_loss = stoch_dur_loss / torch.sum(x_mask)
777 | stoch_dur_loss = torch.sum(stoch_dur_loss)
778 |
779 | return stoch_dur_loss, None # [b]
780 |
781 | else:
782 | flows = list(reversed(self.flows))
783 | flows = flows[:-2] + [flows[-1]] # remove a useless vflow
784 | z = torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) * noise_scale
785 | for flow in flows:
786 | z = flow(z, x_mask, g=x, reverse=reverse)
787 | z0, z1 = torch.split(z, [1, 1], 1)
788 | logw = z0
789 | return None, logw
790 |
791 |
--------------------------------------------------------------------------------
/modules/glow_tts_modules/glow_tts_submodules_with_pitch.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # MIT License
16 | #
17 | # Copyright (c) 2020 Jaehyeon Kim
18 | #
19 | # Permission is hereby granted, free of charge, to any person obtaining a copy
20 | # of this software and associated documentation files (the "Software"), to deal
21 | # in the Software without restriction, including without limitation the rights
22 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
23 | # copies of the Software, and to permit persons to whom the Software is
24 | # furnished to do so, subject to the following conditions:
25 | #
26 | # The above copyright notice and this permission notice shall be included in all
27 | # copies or substantial portions of the Software.
28 | #
29 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
30 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
31 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
32 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
33 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
34 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
35 | # SOFTWARE.
36 |
37 | import math
38 |
39 | import numpy as np
40 | import torch
41 | from torch import nn
42 | from torch.nn import functional as F
43 |
44 | from modules.glow_tts_modules import stocpred_modules
45 |
46 |
47 | def convert_pad_shape(pad_shape):
48 | """
49 | Used to get arguments for F.pad
50 | """
51 | l = pad_shape[::-1]
52 | pad_shape = [item for sublist in l for item in sublist]
53 | return pad_shape
54 |
55 |
56 | def sequence_mask(length, max_length=None):
57 | """
58 | Get masks for given lengths
59 | """
60 | if max_length is None:
61 | max_length = length.max()
62 | x = torch.arange(max_length, dtype=length.dtype, device=length.device)
63 |
64 | return x.unsqueeze(0) < length.unsqueeze(1)
65 |
66 |
67 | def maximum_path(value, mask, max_neg_val=None):
68 | """
69 | Monotonic alignment search algorithm
70 |     Numpy-friendly version. It's about 4 times faster than the torch version.
71 | value: [b, t_x, t_y]
72 | mask: [b, t_x, t_y]
73 | """
74 | if max_neg_val is None:
75 | max_neg_val = -np.inf # Patch for Sphinx complaint
76 | value = value * mask
77 |
78 | device = value.device
79 | dtype = value.dtype
80 | value = value.cpu().detach().numpy()
81 |     mask = mask.cpu().detach().numpy().astype(bool)  # np.bool was removed in recent NumPy versions
82 |
83 | b, t_x, t_y = value.shape
84 | direction = np.zeros(value.shape, dtype=np.int64)
85 | v = np.zeros((b, t_x), dtype=np.float32)
86 | x_range = np.arange(t_x, dtype=np.float32).reshape(1, -1)
87 | for j in range(t_y):
88 | v0 = np.pad(v, [[0, 0], [1, 0]], mode="constant", constant_values=max_neg_val)[:, :-1]
89 | v1 = v
90 | max_mask = v1 >= v0
91 | v_max = np.where(max_mask, v1, v0)
92 | direction[:, :, j] = max_mask
93 |
94 | index_mask = x_range <= j
95 | v = np.where(index_mask, v_max + value[:, :, j], max_neg_val)
96 | direction = np.where(mask, direction, 1)
97 |
98 | path = np.zeros(value.shape, dtype=np.float32)
99 | index = mask[:, :, 0].sum(1).astype(np.int64) - 1
100 | index_range = np.arange(b)
101 | try:
102 | for j in reversed(range(t_y)):
103 | path[index_range, index, j] = 1
104 | index = index + direction[index_range, index, j] - 1
105 | except Exception as e:
106 | print("index range", index_range)
107 | print(e)
108 | path = path * mask.astype(np.float32)
109 | path = torch.from_numpy(path).to(device=device, dtype=dtype)
110 | return path
111 |
112 |
113 | def generate_path(duration, mask):
114 | """
115 | Generate alignment based on predicted durations
116 | duration: [b, t_x]
117 | mask: [b, t_x, t_y]
118 | """
119 |
120 | b, t_x, t_y = mask.shape
121 | cum_duration = torch.cumsum(duration, 1)
122 |
123 | cum_duration_flat = cum_duration.view(b * t_x)
124 | path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
125 | path = path.view(b, t_x, t_y)
126 | path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
127 | path = path * mask
128 | return path
129 |
130 |
131 | @torch.jit.script
132 | def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
133 | n_channels_int = n_channels[0]
134 | in_act = input_a + input_b
135 | t_act = torch.tanh(in_act[:, :n_channels_int, :])
136 | s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
137 | acts = t_act * s_act
138 | return acts
139 |
140 |
141 | class LayerNorm(nn.Module):
142 | def __init__(self, channels, eps=1e-4):
143 | super().__init__()
144 | self.channels = channels
145 | self.eps = eps
146 |
147 | self.gamma = nn.Parameter(torch.ones(channels))
148 | self.beta = nn.Parameter(torch.zeros(channels))
149 |
150 | def forward(self, x):
151 | n_dims = len(x.shape)
152 | mean = torch.mean(x, 1, keepdim=True)
153 | variance = torch.mean((x - mean) ** 2, 1, keepdim=True)
154 |
155 | x = (x - mean) * torch.rsqrt(variance + self.eps)
156 |
157 | shape = [1, -1] + [1] * (n_dims - 2)
158 | x = x * self.gamma.view(*shape) + self.beta.view(*shape)
159 | return x
160 |
161 |
162 | class ConvReluNorm(nn.Module):
163 | def __init__(
164 | self,
165 | in_channels,
166 | hidden_channels,
167 | out_channels,
168 | kernel_size,
169 | n_layers,
170 | p_dropout,
171 | ):
172 | super().__init__()
173 | self.in_channels = in_channels
174 | self.hidden_channels = hidden_channels
175 | self.out_channels = out_channels
176 | self.kernel_size = kernel_size
177 | self.n_layers = n_layers
178 | self.p_dropout = p_dropout
179 |         assert n_layers > 1, "Number of layers should be larger than 1."
180 |
181 | self.conv_layers = nn.ModuleList()
182 | self.norm_layers = nn.ModuleList()
183 | self.conv_layers.append(
184 | nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2)
185 | )
186 | self.norm_layers.append(LayerNorm(hidden_channels))
187 | self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout))
188 | for _ in range(n_layers - 1):
189 | self.conv_layers.append(
190 | nn.Conv1d(
191 | hidden_channels,
192 | hidden_channels,
193 | kernel_size,
194 | padding=kernel_size // 2,
195 | )
196 | )
197 | self.norm_layers.append(LayerNorm(hidden_channels))
198 | self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
199 | self.proj.weight.data.zero_()
200 | self.proj.bias.data.zero_()
201 |
202 | def forward(self, x, x_mask):
203 | x_org = x
204 | for i in range(self.n_layers):
205 | x = self.conv_layers[i](x * x_mask)
206 | x = self.norm_layers[i](x)
207 | x = self.relu_drop(x)
208 | x = x_org + self.proj(x)
209 | return x * x_mask
210 |
211 |
212 | class WN(torch.nn.Module):
213 | def __init__(
214 | self, hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=0, gin_channels=0
215 | ):
216 | super(WN, self).__init__()
217 | assert kernel_size % 2 == 1
218 | assert hidden_channels % 2 == 0
219 |
220 | self.hidden_channels = hidden_channels
221 | self.n_layers = n_layers
222 |
223 | self.in_layers = torch.nn.ModuleList()
224 | self.res_skip_layers = torch.nn.ModuleList()
225 | self.drop = nn.Dropout(p_dropout)
226 |
227 | if gin_channels != 0:
228 | cond_layer = torch.nn.Conv1d(gin_channels, 2 * hidden_channels * n_layers, 1)
229 | self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight")
230 |
231 | for i in range(n_layers):
232 | dilation = dilation_rate ** i
233 | padding = int((kernel_size * dilation - dilation) / 2)
234 | in_layer = torch.nn.Conv1d(
235 | hidden_channels,
236 | 2 * hidden_channels,
237 | kernel_size,
238 | dilation=dilation,
239 | padding=padding,
240 | )
241 | in_layer = torch.nn.utils.weight_norm(in_layer, name="weight")
242 | self.in_layers.append(in_layer)
243 |
244 | # last one is not necessary
245 | if i < n_layers - 1:
246 | res_skip_channels = 2 * hidden_channels
247 | else:
248 | res_skip_channels = hidden_channels
249 |
250 | res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
251 | res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight")
252 | self.res_skip_layers.append(res_skip_layer)
253 |
254 | def forward(self, x, x_mask=None, g=None, **kwargs):
255 | output = torch.zeros_like(x)
256 | n_channels_tensor = torch.IntTensor([self.hidden_channels])
257 | if g is not None:
258 | g = self.cond_layer(g)
259 | for i in range(self.n_layers):
260 | x_in = self.in_layers[i](x)
261 | x_in = self.drop(x_in)
262 | if g is not None:
263 | cond_offset = i * 2 * self.hidden_channels
264 | g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
265 | else:
266 | g_l = torch.zeros_like(x_in)
267 | acts = fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
268 |
269 | res_skip_acts = self.res_skip_layers[i](acts)
270 | if i < self.n_layers - 1:
271 | x = (x + res_skip_acts[:, : self.hidden_channels, :]) * x_mask
272 | output = output + res_skip_acts[:, self.hidden_channels :, :]
273 | else:
274 | output = output + res_skip_acts
275 | return output * x_mask
276 |
277 | def remove_weight_norm(self):
278 | for l in self.in_layers:
279 | torch.nn.utils.remove_weight_norm(l)
280 | for l in self.res_skip_layers:
281 | torch.nn.utils.remove_weight_norm(l)
282 |
283 |
284 | class WNP(torch.nn.Module):
285 | def __init__(
286 | self, hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=0, gin_channels=0, n_sqz=2,
287 | ):
288 | super(WNP, self).__init__()
289 | assert kernel_size % 2 == 1
290 | assert hidden_channels % 2 == 0
291 |
292 | self.hidden_channels = hidden_channels
293 | self.n_layers = n_layers
294 |
295 | self.in_layers = torch.nn.ModuleList()
296 | self.res_skip_layers = torch.nn.ModuleList()
297 | self.drop = nn.Dropout(p_dropout)
298 | self.n_sqz = n_sqz
299 |
300 | if gin_channels != 0:
301 | cond_layer = torch.nn.Conv1d(gin_channels, 2 * hidden_channels * n_layers // self.n_sqz, 1)
302 | self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight")
303 |
304 | for i in range(n_layers):
305 | dilation = dilation_rate ** i
306 | padding = int((kernel_size * dilation - dilation) / 2)
307 | in_layer = torch.nn.Conv1d(
308 | hidden_channels,
309 | 2 * hidden_channels,
310 | kernel_size,
311 | dilation=dilation,
312 | padding=padding,
313 | )
314 | in_layer = torch.nn.utils.weight_norm(in_layer, name="weight")
315 | self.in_layers.append(in_layer)
316 |
317 | # last one is not necessary
318 | if i < n_layers - 1:
319 | res_skip_channels = 2 * hidden_channels
320 | else:
321 | res_skip_channels = hidden_channels
322 |
323 | res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
324 | res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight")
325 | self.res_skip_layers.append(res_skip_layer)
326 |
327 | def forward(self, x, x_mask=None, g=None, **kwargs):
328 | output = torch.zeros_like(x)
329 | n_channels_tensor = torch.IntTensor([self.hidden_channels])
330 | if g is not None:
331 | g = self.cond_layer(g)
332 | if self.n_sqz > 1:
333 | g = self.squeeze(g, self.n_sqz)
334 | else:
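# no pitch conditioning provided: pass the activations through unchanged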
335 | return x
336 |
337 | for i in range(self.n_layers):
338 | x_in = self.in_layers[i](x)
339 | x_in = self.drop(x_in)
340 | if g is not None:
341 | cond_offset = i * 2 * self.hidden_channels
342 | g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
343 | else:
344 | g_l = torch.zeros_like(x_in)
345 |
346 | acts = fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
347 |
348 | res_skip_acts = self.res_skip_layers[i](acts)
349 | if i < self.n_layers - 1:
350 | x = (x + res_skip_acts[:, : self.hidden_channels, :]) * x_mask
351 | output = output + res_skip_acts[:, self.hidden_channels :, :]
352 | else:
353 | output = output + res_skip_acts
354 | return output * x_mask
355 |
356 | def remove_weight_norm(self):
357 | for l in self.in_layers:
358 | torch.nn.utils.remove_weight_norm(l)
359 | for l in self.res_skip_layers:
360 | torch.nn.utils.remove_weight_norm(l)
361 |
362 | def squeeze(self, x, n_sqz=2):
363 | b, c, t = x.size()
364 |
365 | t = (t // n_sqz) * n_sqz
366 | x = x[:, :, :t]
367 | x_sqz = x.view(b, c, t // n_sqz, n_sqz)
368 |
369 | x_sqz = x_sqz.permute(0, 3, 1, 2).contiguous().view(b, c * n_sqz, t // n_sqz)
370 |
371 | return x_sqz
372 |
373 |
374 | class ActNorm(nn.Module):
375 | def __init__(self, channels, ddi=False, **kwargs):
376 | super().__init__()
377 | self.channels = channels
378 | self.initialized = not ddi
379 |
380 | self.logs = nn.Parameter(torch.zeros(1, channels, 1))
381 | self.bias = nn.Parameter(torch.zeros(1, channels, 1))
382 |
383 | def forward(self, x, x_mask=None, reverse=False, **kwargs):
384 | if x_mask is None:
385 | x_mask = torch.ones(x.size(0), 1, x.size(2)).to(device=x.device, dtype=x.dtype)
386 | x_len = torch.sum(x_mask, [1, 2])
387 | if not self.initialized:
388 | self.initialize(x, x_mask)
389 | self.initialized = True
390 |
391 | if reverse:
392 | z = (x - self.bias) * torch.exp(-self.logs) * x_mask
393 | logdet = None
394 | else:
395 | z = (self.bias + torch.exp(self.logs) * x) * x_mask
396 | logdet = torch.sum(self.logs) * x_len # [b]
397 |
398 | return z, logdet
399 |
400 | def store_inverse(self):
401 | pass
402 |
403 | def set_ddi(self, ddi):
404 | self.initialized = not ddi
405 |
406 | def initialize(self, x, x_mask):
407 | with torch.no_grad():
408 | denom = torch.sum(x_mask, [0, 2])
409 | m = torch.sum(x * x_mask, [0, 2]) / denom
410 | m_sq = torch.sum(x * x * x_mask, [0, 2]) / denom
411 | v = m_sq - (m ** 2)
412 | logs = 0.5 * torch.log(torch.clamp_min(v, 1e-6))
413 |
414 | bias_init = (-m * torch.exp(-logs)).view(*self.bias.shape).to(dtype=self.bias.dtype)
415 | logs_init = (-logs).view(*self.logs.shape).to(dtype=self.logs.dtype)
416 |
417 | self.bias.data.copy_(bias_init)
418 | self.logs.data.copy_(logs_init)
419 |
420 |
421 | class InvConvNear(nn.Module):
422 | def __init__(self, channels, n_split=4, no_jacobian=False, **kwargs):
423 | super().__init__()
424 | assert n_split % 2 == 0
425 | self.channels = channels
426 | self.n_split = n_split
427 | self.no_jacobian = no_jacobian
428 |
429 | w_init = torch.linalg.qr(torch.FloatTensor(self.n_split, self.n_split).normal_())[0]
430 | if torch.det(w_init) < 0:
431 | w_init[:, 0] = -1 * w_init[:, 0]
432 | self.weight = nn.Parameter(w_init)
433 |
434 | def forward(self, x, x_mask=None, reverse=False, **kwargs):
435 | b, c, t = x.size()
436 | assert c % self.n_split == 0
437 | if x_mask is None:
438 | x_mask = 1
439 | x_len = torch.ones((b,), dtype=x.dtype, device=x.device) * t
440 | else:
441 | x_len = torch.sum(x_mask, [1, 2])
442 |
443 | x = x.view(b, 2, c // self.n_split, self.n_split // 2, t)
444 | x = x.permute(0, 1, 3, 2, 4).contiguous().view(b, self.n_split, c // self.n_split, t)
445 |
446 | if reverse:
447 | if hasattr(self, "weight_inv"):
448 | weight = self.weight_inv
449 | else:
450 | weight = torch.inverse(self.weight.float()).to(dtype=self.weight.dtype)
451 | logdet = None
452 | else:
453 | weight = self.weight
454 | if self.no_jacobian:
455 | logdet = 0
456 | else:
457 | logdet = torch.logdet(self.weight) * (c / self.n_split) * x_len # [b]
458 |
459 | weight = weight.view(self.n_split, self.n_split, 1, 1)
460 | z = F.conv2d(x, weight)
461 |
462 | z = z.view(b, 2, self.n_split // 2, c // self.n_split, t)
463 | z = z.permute(0, 1, 3, 2, 4).contiguous().view(b, c, t) * x_mask
464 | return z, logdet
465 |
466 | def store_inverse(self):
467 | self.weight_inv = torch.inverse(self.weight.float()).to(dtype=self.weight.dtype)
468 |
469 |
470 | class CouplingBlock(nn.Module):
471 | def __init__(
472 | self,
473 | in_channels,
474 | hidden_channels,
475 | kernel_size,
476 | dilation_rate,
477 | n_layers,
478 | gin_channels=0,
479 | p_dropout=0,
480 | sigmoid_scale=False,
481 | n_sqz=2,
482 | ):
483 | super().__init__()
484 | self.in_channels = in_channels
485 | self.hidden_channels = hidden_channels
486 | self.kernel_size = kernel_size
487 | self.dilation_rate = dilation_rate
488 | self.n_layers = n_layers
489 | self.gin_channels = gin_channels
490 | self.p_dropout = p_dropout
491 | self.sigmoid_scale = sigmoid_scale
492 | self.n_sqz = n_sqz
493 |
494 | start = torch.nn.Conv1d(in_channels // 2, hidden_channels, 1)
495 | start = torch.nn.utils.weight_norm(start)
496 | self.start = start
497 | # Initializing last layer to 0 makes the affine coupling layers
498 | # do nothing at first. This helps with training stability
499 | end = torch.nn.Conv1d(hidden_channels, in_channels, 1)
500 | end.weight.data.zero_()
501 | end.bias.data.zero_()
502 | self.end = end
503 |
504 | self.wn = WN(
505 | hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout, gin_channels
506 | )
507 | self.wn_pitch = WNP(
508 | hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout, 1, n_sqz
509 | ) # 1 dimension of pitch
510 |
511 | def forward(self, x, x_mask=None, reverse=False, g=None, pitch=None, **kwargs):
512 | if x_mask is None:
513 | x_mask = 1
514 | x_0, x_1 = x[:, : self.in_channels // 2], x[:, self.in_channels // 2 :]
515 |
516 | x = self.start(x_0) * x_mask
517 | # add speaker emb
518 | x = self.wn(x, x_mask, g)
519 | # add pitch emb
520 | if pitch is not None and len(pitch.shape) == 2:
521 | pitch = pitch.unsqueeze(1) # B, T -> B,C,T
522 |
523 | x = self.wn_pitch(x, x_mask, pitch)
524 | out = self.end(x)
525 |
526 | z_0 = x_0
527 | m = out[:, : self.in_channels // 2, :]
528 | logs = out[:, self.in_channels // 2 :, :]
529 | if self.sigmoid_scale:
530 | logs = torch.log(1e-6 + torch.sigmoid(logs + 2))
531 |
532 | if reverse:
533 | z_1 = (x_1 - m) * torch.exp(-logs) * x_mask
534 | logdet = None
535 | else:
536 | z_1 = (m + torch.exp(logs) * x_1) * x_mask
537 | logdet = torch.sum(logs * x_mask, [1, 2])
538 |
539 | z = torch.cat([z_0, z_1], 1)
540 | return z, logdet
541 |
542 | def store_inverse(self):
543 | self.wn.remove_weight_norm()
544 |
545 |
546 | class AttentionBlock(nn.Module):
547 | def __init__(
548 | self,
549 | channels,
550 | out_channels,
551 | n_heads,
552 | window_size=None,
553 | heads_share=True,
554 | p_dropout=0.0,
555 | block_length=None,
556 | proximal_bias=False,
557 | proximal_init=False,
558 | ):
559 | super().__init__()
560 | assert channels % n_heads == 0
561 |
562 | self.channels = channels
563 | self.out_channels = out_channels
564 | self.n_heads = n_heads
565 | self.window_size = window_size
566 | self.heads_share = heads_share
567 | self.block_length = block_length
568 | self.proximal_bias = proximal_bias
569 | self.p_dropout = p_dropout
570 | self.attn = None
571 |
572 | self.k_channels = channels // n_heads
573 | self.conv_q = nn.Conv1d(channels, channels, 1)
574 | self.conv_k = nn.Conv1d(channels, channels, 1)
575 | self.conv_v = nn.Conv1d(channels, channels, 1)
576 | if window_size is not None:
577 | n_heads_rel = 1 if heads_share else n_heads
578 | rel_stddev = self.k_channels ** -0.5
579 | self.emb_rel_k = nn.Parameter(
580 | torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev
581 | )
582 | self.emb_rel_v = nn.Parameter(
583 | torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev
584 | )
585 | self.conv_o = nn.Conv1d(channels, out_channels, 1)
586 | self.drop = nn.Dropout(p_dropout)
587 |
588 | nn.init.xavier_uniform_(self.conv_q.weight)
589 | nn.init.xavier_uniform_(self.conv_k.weight)
590 | if proximal_init:
591 | self.conv_k.weight.data.copy_(self.conv_q.weight.data)
592 | self.conv_k.bias.data.copy_(self.conv_q.bias.data)
593 | nn.init.xavier_uniform_(self.conv_v.weight)
594 |
595 | def forward(self, x, c, attn_mask=None):
596 | q = self.conv_q(x)
597 | k = self.conv_k(c)
598 | v = self.conv_v(c)
599 |
600 | x, self.attn = self.attention(q, k, v, mask=attn_mask)
601 |
602 | x = self.conv_o(x)
603 | return x
604 |
605 | def attention(self, query, key, value, mask=None):
606 | # reshape [b, d, t] -> [b, n_h, t, d_k]
607 | b, d, t_s, t_t = key.size(0), key.size(1), key.size(2), query.size(2)
608 | query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
609 | key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
610 | value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
611 |
612 | scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.k_channels)
613 | if self.window_size is not None:
614 | assert t_s == t_t, "Relative attention is only available for self-attention."
615 | key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
616 | rel_logits = self._matmul_with_relative_keys(query, key_relative_embeddings)
617 | rel_logits = self._relative_position_to_absolute_position(rel_logits)
618 | scores_local = rel_logits / math.sqrt(self.k_channels)
619 | scores = scores + scores_local
620 | if self.proximal_bias:
621 | assert t_s == t_t, "Proximal bias is only available for self-attention."
622 | scores = scores + self._attention_bias_proximal(t_s).to(
623 | device=scores.device, dtype=scores.dtype
624 | )
625 | if mask is not None:
626 | scores = scores.masked_fill(mask == 0, -1e4)
627 | if self.block_length is not None:
628 | block_mask = (
629 | torch.ones_like(scores).triu(-self.block_length).tril(self.block_length)
630 | )
631 | scores = scores * block_mask + -1e4 * (1 - block_mask)
632 | p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s]
633 | p_attn = self.drop(p_attn)
634 | output = torch.matmul(p_attn, value)
635 | if self.window_size is not None:
636 | relative_weights = self._absolute_position_to_relative_position(p_attn)
637 | value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s)
638 | output = output + self._matmul_with_relative_values(
639 | relative_weights, value_relative_embeddings
640 | )
641 | output = (
642 | output.transpose(2, 3).contiguous().view(b, d, t_t)
643 | ) # [b, n_h, t_t, d_k] -> [b, d, t_t]
644 | return output, p_attn
645 |
646 | def _matmul_with_relative_values(self, x, y):
647 | """
648 | x: [b, h, l, m]
649 | y: [h or 1, m, d]
650 | ret: [b, h, l, d]
651 | """
652 | ret = torch.matmul(x, y.unsqueeze(0))
653 | return ret
654 |
655 | def _matmul_with_relative_keys(self, x, y):
656 | """
657 | x: [b, h, l, d]
658 | y: [h or 1, m, d]
659 | ret: [b, h, l, m]
660 | """
661 | ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
662 | return ret
663 |
664 | def _get_relative_embeddings(self, relative_embeddings, length):
665 | # Pad first before slice to avoid using cond ops.
666 | pad_length = max(length - (self.window_size + 1), 0)
667 | slice_start_position = max((self.window_size + 1) - length, 0)
668 | slice_end_position = slice_start_position + 2 * length - 1
669 | if pad_length > 0:
670 | padded_relative_embeddings = F.pad(
671 | relative_embeddings,
672 | convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
673 | )
674 | else:
675 | padded_relative_embeddings = relative_embeddings
676 | used_relative_embeddings = padded_relative_embeddings[
677 | :, slice_start_position:slice_end_position
678 | ]
679 | return used_relative_embeddings
680 |
681 | def _relative_position_to_absolute_position(self, x):
682 | """
683 | x: [b, h, l, 2*l-1]
684 | ret: [b, h, l, l]
685 | """
686 | batch, heads, length, _ = x.size()
687 | # Concat columns of pad to shift from relative to absolute indexing.
688 | x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
689 |
690 | # Concat extra elements so to add up to shape (len+1, 2*len-1).
691 | x_flat = x.view([batch, heads, length * 2 * length])
692 | x_flat = F.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [0, length - 1]]))
693 |
694 | # Reshape and slice out the padded elements.
695 | x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
696 | :, :, :length, length - 1 :
697 | ]
698 | return x_final
699 |
700 | def _absolute_position_to_relative_position(self, x):
701 | """
702 | x: [b, h, l, l]
703 | ret: [b, h, l, 2*l-1]
704 | """
705 | batch, heads, length, _ = x.size()
706 |         # pad along column
707 | x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]]))
708 | x_flat = x.view([batch, heads, length ** 2 + length * (length - 1)])
709 | # add 0's in the beginning that will skew the elements after reshape
710 | x_flat = F.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
711 | x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
712 | return x_final
713 |
714 | def _attention_bias_proximal(self, length):
715 | """Bias for self-attention to encourage attention to close positions.
716 | Args:
717 | length: an integer scalar.
718 | Returns:
719 | a Tensor with shape [1, 1, length, length]
720 | """
721 | r = torch.arange(length, dtype=torch.float32)
722 | diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
723 | return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
724 |
725 |
726 | class FeedForwardNetwork(nn.Module):
727 | def __init__(
728 | self,
729 | in_channels,
730 | out_channels,
731 | filter_channels,
732 | kernel_size,
733 | p_dropout=0.0,
734 | activation=None,
735 | ):
736 | super().__init__()
737 | self.in_channels = in_channels
738 | self.out_channels = out_channels
739 | self.filter_channels = filter_channels
740 | self.kernel_size = kernel_size
741 | self.p_dropout = p_dropout
742 | self.activation = activation
743 |
744 | self.conv_1 = nn.Conv1d(
745 | in_channels, filter_channels, kernel_size, padding=kernel_size // 2
746 | )
747 | self.conv_2 = nn.Conv1d(
748 | filter_channels, out_channels, kernel_size, padding=kernel_size // 2
749 | )
750 | self.drop = nn.Dropout(p_dropout)
751 |
752 | def forward(self, x, x_mask):
753 | x = self.conv_1(x * x_mask)
754 | if self.activation == "gelu":
755 | x = x * torch.sigmoid(1.702 * x)
756 | else:
757 | x = torch.relu(x)
758 | x = self.drop(x)
759 | x = self.conv_2(x * x_mask)
760 | return x * x_mask
761 |
762 |
763 | class DurationPredictor(nn.Module):
764 | def __init__(self, in_channels, filter_channels, kernel_size, p_dropout):
765 | """
766 | Token duration predictor for the GlowTTS model.
767 | Takes in embeddings of the input tokens and predicts how many frames of
768 | mel-spectrogram are aligned to each text token.
769 | Architecture is the same as the duration predictor in FastSpeech.
770 | Args:
771 | in_channels: Number of channels for the token embeddings
772 | filter_channels: Number of channels in the intermediate layers
773 |             kernel_size: Kernel size for the convolution layers
774 | p_dropout: Dropout probability
775 | """
776 |
777 | super().__init__()
778 |
779 | self.drop = nn.Dropout(p_dropout)
780 | self.conv_1 = nn.Conv1d(
781 | in_channels, filter_channels, kernel_size, padding=kernel_size // 2
782 | )
783 | self.norm_1 = LayerNorm(filter_channels)
784 | self.conv_2 = nn.Conv1d(
785 | filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
786 | )
787 | self.norm_2 = LayerNorm(filter_channels)
788 | self.proj = nn.Conv1d(filter_channels, 1, 1)
789 |
790 | def forward(self, spect, mask, **kwargs):
791 |         x = torch.detach(spect)  # stop gradients from flowing back into the encoder
792 |         x = self.conv_1(x * mask)
793 | x = torch.relu(x)
794 | x = self.norm_1(x)
795 | x = self.drop(x)
796 | x = self.conv_2(x * mask)
797 | x = torch.relu(x)
798 | x = self.norm_2(x)
799 | x = self.drop(x)
800 | x = self.proj(x * mask)
801 | durs = x * mask
802 | return durs.squeeze(1)
803 |
804 |
805 | class StochasticDurationPredictor(nn.Module):
806 | """Borrowed from VITS"""
807 | def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, n_flows=4, gin_channels=0):
808 | super().__init__()
809 |
810 |         filter_channels = in_channels  # this override should be removed in a future version
811 | self.in_channels = in_channels
812 | self.filter_channels = filter_channels
813 | self.kernel_size = kernel_size
814 | self.p_dropout = p_dropout
815 | self.n_flows = n_flows
816 | self.gin_channels = gin_channels
817 |
818 | self.log_flow = stocpred_modules.Log()
819 | self.flows = nn.ModuleList()
820 | self.flows.append(stocpred_modules.ElementwiseAffine(2))
821 | for i in range(n_flows):
822 | self.flows.append(stocpred_modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
823 | self.flows.append(stocpred_modules.Flip())
824 |
825 | self.post_pre = nn.Conv1d(1, filter_channels, 1)
826 | self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
827 | self.post_convs = stocpred_modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
828 | self.post_flows = nn.ModuleList()
829 | self.post_flows.append(stocpred_modules.ElementwiseAffine(2))
830 | for i in range(4):
831 | self.post_flows.append(stocpred_modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
832 | self.post_flows.append(stocpred_modules.Flip())
833 | self.pre = nn.Conv1d(in_channels, filter_channels, 1)
834 | self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
835 | self.convs = stocpred_modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
836 | if gin_channels != 0:
837 | self.cond = nn.Conv1d(gin_channels, filter_channels, 1)
838 |
839 | def forward(self, spect, x_mask, w=None, g=None, reverse=False, noise_scale=1.0):
840 | x = torch.detach(spect)
841 | x = self.pre(x)
842 | if g is not None:
843 | g = torch.detach(g)
844 | x = x + self.cond(g)
845 | x = self.convs(x, x_mask)
846 | x = self.proj(x) * x_mask
847 | if not reverse:
848 | flows = self.flows
849 | assert w is not None
850 |
851 | logdet_tot_q = 0
852 | h_w = self.post_pre(w)
853 | h_w = self.post_convs(h_w, x_mask)
854 | h_w = self.post_proj(h_w) * x_mask
855 |
856 | e_q = torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype) * x_mask
857 | z_q = e_q
858 | for flow in self.post_flows:
859 | z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
860 | logdet_tot_q += logdet_q
861 | z_u, z1 = torch.split(z_q, [1, 1], 1)
862 | u = torch.sigmoid(z_u) * x_mask
863 | z0 = (w - u) * x_mask
864 | logdet_tot_q += torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1,2])
865 | logq = torch.sum(-0.5 * (math.log(2*math.pi) + (e_q**2)) * x_mask, [1,2]) - logdet_tot_q
866 |
867 | logdet_tot = 0
868 | z0, logdet = self.log_flow(z0, x_mask)
869 | logdet_tot += logdet
870 | z = torch.cat([z0, z1], 1)
871 | for flow in flows:
872 | z, logdet = flow(z, x_mask, g=x, reverse=reverse)
873 | logdet_tot = logdet_tot + logdet
874 | nll = torch.sum(0.5 * (math.log(2*math.pi) + (z**2)) * x_mask, [1,2]) - logdet_tot
875 | stoch_dur_loss = nll + logq
876 | stoch_dur_loss = stoch_dur_loss / torch.sum(x_mask)
877 | stoch_dur_loss = torch.sum(stoch_dur_loss)
878 |
879 | return stoch_dur_loss, None # [b]
880 |
881 | else:
882 | flows = list(reversed(self.flows))
883 | flows = flows[:-2] + [flows[-1]] # remove a useless vflow
884 | # generator=torch.Generator().manual_seed(42)
885 | z = torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) * noise_scale
886 | for flow in flows:
887 | z = flow(z, x_mask, g=x, reverse=reverse)
888 | z0, z1 = torch.split(z, [1, 1], 1)
889 | logw = z0
890 | return None, logw
891 |
892 |
893 | class StochasticPitchPredictor(nn.Module):
894 | """Borrowed from VITS"""
895 | def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, n_flows=4, gin_channels=0):
896 | super().__init__()
897 |
898 |         filter_channels = in_channels  # this override should be removed in a future version
899 | self.in_channels = in_channels
900 | self.filter_channels = filter_channels
901 | self.kernel_size = kernel_size
902 | self.p_dropout = p_dropout
903 | self.n_flows = n_flows
904 | self.gin_channels = gin_channels
905 |
906 | self.log_flow = stocpred_modules.Log()
907 | self.flows = nn.ModuleList()
908 | self.flows.append(stocpred_modules.ElementwiseAffine(2))
909 | for i in range(n_flows):
910 | self.flows.append(stocpred_modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
911 | self.flows.append(stocpred_modules.Flip())
912 |
913 | self.pre = nn.Conv1d(in_channels, filter_channels, 1)
914 | self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
915 | self.convs = stocpred_modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
916 | if gin_channels != 0:
917 | self.cond = nn.Conv1d(gin_channels, filter_channels, 1)
918 |
919 | def forward(self, spect, x_mask, w=None, g=None, reverse=False, noise_scale=1.0):
920 | x = torch.detach(spect)
921 | x = self.pre(x)
922 | if g is not None:
923 | g = torch.detach(g)
924 |             g_cond = self.cond(g)
925 |             x = x + g_cond
926 | x = self.convs(x, x_mask)
927 | x = self.proj(x) * x_mask
928 | if not reverse:
929 | flows = self.flows
930 | assert w is not None
931 |
932 | e_q = torch.randn(w.size()).to(device=x.device, dtype=x.dtype) * x_mask
933 |
934 | logdet_tot = 0
935 | z = torch.cat([w, e_q], 1)
936 | for flow in flows:
937 | z, logdet = flow(z, x_mask, g=x, reverse=reverse)
938 | logdet_tot = logdet_tot + logdet
939 | nll = torch.sum(0.5 * (math.log(2*math.pi) + (z**2)) * x_mask, [1,2]) - logdet_tot
940 |
941 | stoch_pitch_loss = nll / torch.sum(x_mask)
942 | stoch_pitch_loss = torch.sum(stoch_pitch_loss)
943 | return stoch_pitch_loss, None # [b]
944 |
945 | else:
946 | flows = list(reversed(self.flows))
947 | flows = flows[:-2] + [flows[-1]] # remove a useless vflow
948 | z = torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) * noise_scale
949 | for flow in flows:
950 | z = flow(z, x_mask, g=x, reverse=reverse)
951 | z0, z1 = torch.split(z, [1, 1], 1)
952 | w = z0
953 | return None, w
954 |
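955 |
956 | if __name__ == "__main__":
957 |     # Minimal smoke-test sketch (illustrative only): random stand-ins for the
958 |     # text-encoder output, the token mask and the per-token durations. The
959 |     # channel size and noise scale below are arbitrary, not the values used
960 |     # by the training configs.
961 |     B, C, T = 2, 192, 17
962 |     enc_out = torch.randn(B, C, T)
963 |     x_mask = torch.ones(B, 1, T)
964 |     durations = torch.rand(B, 1, T) * 10 + 1  # frames aligned to each token
965 |
966 |     sdp = StochasticDurationPredictor(in_channels=C, filter_channels=C, kernel_size=3, p_dropout=0.1)
967 |
968 |     # Training direction: the flow evaluates the NLL of the observed durations.
969 |     dur_nll, _ = sdp(enc_out, x_mask, w=durations, reverse=False)
970 |
971 |     # Inference direction: sample log-durations from noise instead.
972 |     _, logw = sdp(enc_out, x_mask, reverse=True, noise_scale=0.8)
973 |     print(float(dur_nll), tuple(logw.shape))  # scalar loss, (B, 1, T)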
--------------------------------------------------------------------------------
/modules/glow_tts_modules/stocpred_modules.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch import nn
4 | from torch.nn import functional as F
5 |
6 | from modules.glow_tts_modules.transforms import piecewise_rational_quadratic_transform
7 |
8 |
9 | class LayerNorm(nn.Module):
10 | def __init__(self, channels, eps=1e-4):
11 | super().__init__()
12 | self.channels = channels
13 | self.eps = eps
14 |
15 | self.gamma = nn.Parameter(torch.ones(channels))
16 | self.beta = nn.Parameter(torch.zeros(channels))
17 |
18 | def forward(self, x):
19 | n_dims = len(x.shape)
20 | mean = torch.mean(x, 1, keepdim=True)
21 | variance = torch.mean((x - mean) ** 2, 1, keepdim=True)
22 |
23 | x = (x - mean) * torch.rsqrt(variance + self.eps)
24 |
25 | shape = [1, -1] + [1] * (n_dims - 2)
26 | x = x * self.gamma.view(*shape) + self.beta.view(*shape)
27 | return x
28 |
29 |
30 | class DDSConv(nn.Module):
31 | """
32 |     Dilated and Depth-Separable Convolution
33 | """
34 | def __init__(self, channels, kernel_size, n_layers, p_dropout=0.):
35 | super().__init__()
36 | self.channels = channels
37 | self.kernel_size = kernel_size
38 | self.n_layers = n_layers
39 | self.p_dropout = p_dropout
40 |
41 | self.drop = nn.Dropout(p_dropout)
42 | self.convs_sep = nn.ModuleList()
43 | self.convs_1x1 = nn.ModuleList()
44 | self.norms_1 = nn.ModuleList()
45 | self.norms_2 = nn.ModuleList()
46 | for i in range(n_layers):
47 | dilation = kernel_size ** i
48 | padding = (kernel_size * dilation - dilation) // 2
49 | self.convs_sep.append(nn.Conv1d(channels, channels, kernel_size,
50 | groups=channels, dilation=dilation, padding=padding
51 | ))
52 | self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
53 | self.norms_1.append(LayerNorm(channels))
54 | self.norms_2.append(LayerNorm(channels))
55 |
56 | def forward(self, x, x_mask, g=None):
57 | if g is not None:
58 | x = x + g
59 | for i in range(self.n_layers):
60 | y = self.convs_sep[i](x * x_mask)
61 | y = self.norms_1[i](y)
62 | y = F.gelu(y)
63 | y = self.convs_1x1[i](y)
64 | y = self.norms_2[i](y)
65 | y = F.gelu(y)
66 | y = self.drop(y)
67 | x = x + y
68 | return x * x_mask
69 |
70 |
71 | class Log(nn.Module):
72 | def forward(self, x, x_mask, reverse=False, **kwargs):
73 | if not reverse:
74 | y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask
75 | logdet = torch.sum(-y, [1, 2])
76 | return y, logdet
77 | else:
78 | x = torch.exp(x) * x_mask
79 | return x
80 |
81 |
82 | class Flip(nn.Module):
83 | def forward(self, x, *args, reverse=False, **kwargs):
84 | x = torch.flip(x, [1])
85 | if not reverse:
86 | logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
87 | return x, logdet
88 | else:
89 | return x
90 |
91 |
92 | class ElementwiseAffine(nn.Module):
93 | def __init__(self, channels):
94 | super().__init__()
95 | self.channels = channels
96 | self.m = nn.Parameter(torch.zeros(channels,1))
97 | self.logs = nn.Parameter(torch.zeros(channels,1))
98 |
99 | def forward(self, x, x_mask, reverse=False, **kwargs):
100 | if not reverse:
101 | y = self.m + torch.exp(self.logs) * x
102 | y = y * x_mask
103 | logdet = torch.sum(self.logs * x_mask, [1,2])
104 | return y, logdet
105 | else:
106 | x = (x - self.m) * torch.exp(-self.logs) * x_mask
107 | return x
108 |
109 |
110 | class ConvFlow(nn.Module):
111 | def __init__(self, in_channels, filter_channels, kernel_size, n_layers, num_bins=10, tail_bound=5.0):
112 | super().__init__()
113 | self.in_channels = in_channels
114 | self.filter_channels = filter_channels
115 | self.kernel_size = kernel_size
116 | self.n_layers = n_layers
117 | self.num_bins = num_bins
118 | self.tail_bound = tail_bound
119 | self.half_channels = in_channels // 2
120 |
121 | self.pre = nn.Conv1d(self.half_channels, filter_channels, 1)
122 | self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.)
123 | self.proj = nn.Conv1d(filter_channels, self.half_channels * (num_bins * 3 - 1), 1)
124 | self.proj.weight.data.zero_()
125 | self.proj.bias.data.zero_()
126 |
127 | def forward(self, x, x_mask, g=None, reverse=False):
128 | x0, x1 = torch.split(x, [self.half_channels]*2, 1)
129 | h = self.pre(x0)
130 | h = self.convs(h, x_mask, g=g)
131 | h = self.proj(h) * x_mask
132 |
133 | b, c, t = x0.shape
134 | h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) # [b, cx?, t] -> [b, c, t, ?]
135 |
136 | unnormalized_widths = h[..., :self.num_bins] / math.sqrt(self.filter_channels)
137 | unnormalized_heights = h[..., self.num_bins:2*self.num_bins] / math.sqrt(self.filter_channels)
138 | unnormalized_derivatives = h[..., 2 * self.num_bins:]
139 |
140 | x1, logabsdet = piecewise_rational_quadratic_transform(x1,
141 | unnormalized_widths,
142 | unnormalized_heights,
143 | unnormalized_derivatives,
144 | inverse=reverse,
145 | tails='linear',
146 | tail_bound=self.tail_bound
147 | )
148 |
149 | x = torch.cat([x0, x1], 1) * x_mask
150 | logdet = torch.sum(logabsdet * x_mask, [1,2])
151 | if not reverse:
152 | return x, logdet
153 | else:
154 | return x
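155 |
156 |
157 | if __name__ == "__main__":
158 |     # Round-trip sketch (illustrative only): a single ConvFlow applied forward and
159 |     # then in reverse should recover its input, since the first half of the channels
160 |     # only conditions the spline applied to the second half. Sizes are arbitrary
161 |     # stand-ins, not values from the training configs.
162 |     torch.manual_seed(0)
163 |     B, T = 2, 11
164 |     x = torch.randn(B, 2, T)
165 |     x_mask = torch.ones(B, 1, T)
166 |
167 |     flow = ConvFlow(in_channels=2, filter_channels=64, kernel_size=3, n_layers=3)
168 |     flow.eval()  # deterministic (dropout is 0 here anyway)
169 |
170 |     y, logdet = flow(x, x_mask, reverse=False)
171 |     x_rec = flow(y, x_mask, reverse=True)
172 |     print(torch.max(torch.abs(x - x_rec)))  # ~0 up to numerical error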
--------------------------------------------------------------------------------
/modules/glow_tts_modules/transforms.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn import functional as F
3 |
4 | import numpy as np
5 |
6 |
7 | DEFAULT_MIN_BIN_WIDTH = 1e-3
8 | DEFAULT_MIN_BIN_HEIGHT = 1e-3
9 | DEFAULT_MIN_DERIVATIVE = 1e-3
10 |
11 |
12 | def piecewise_rational_quadratic_transform(inputs,
13 | unnormalized_widths,
14 | unnormalized_heights,
15 | unnormalized_derivatives,
16 | inverse=False,
17 | tails=None,
18 | tail_bound=1.,
19 | min_bin_width=DEFAULT_MIN_BIN_WIDTH,
20 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
21 | min_derivative=DEFAULT_MIN_DERIVATIVE):
22 |
23 | if tails is None:
24 | spline_fn = rational_quadratic_spline
25 | spline_kwargs = {}
26 | else:
27 | spline_fn = unconstrained_rational_quadratic_spline
28 | spline_kwargs = {
29 | 'tails': tails,
30 | 'tail_bound': tail_bound
31 | }
32 |
33 | outputs, logabsdet = spline_fn(
34 | inputs=inputs,
35 | unnormalized_widths=unnormalized_widths,
36 | unnormalized_heights=unnormalized_heights,
37 | unnormalized_derivatives=unnormalized_derivatives,
38 | inverse=inverse,
39 | min_bin_width=min_bin_width,
40 | min_bin_height=min_bin_height,
41 | min_derivative=min_derivative,
42 | **spline_kwargs
43 | )
44 | return outputs, logabsdet
45 |
46 |
47 | def searchsorted(bin_locations, inputs, eps=1e-6):
48 | bin_locations[..., -1] += eps
49 | return torch.sum(
50 | inputs[..., None] >= bin_locations,
51 | dim=-1
52 | ) - 1
53 |
54 |
55 | def unconstrained_rational_quadratic_spline(inputs,
56 | unnormalized_widths,
57 | unnormalized_heights,
58 | unnormalized_derivatives,
59 | inverse=False,
60 | tails='linear',
61 | tail_bound=1.,
62 | min_bin_width=DEFAULT_MIN_BIN_WIDTH,
63 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
64 | min_derivative=DEFAULT_MIN_DERIVATIVE):
65 | inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
66 | outside_interval_mask = ~inside_interval_mask
67 |
68 | outputs = torch.zeros_like(inputs)
69 | logabsdet = torch.zeros_like(inputs)
70 |
71 | if tails == 'linear':
72 | unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1))
73 | constant = np.log(np.exp(1 - min_derivative) - 1)
74 | unnormalized_derivatives[..., 0] = constant
75 | unnormalized_derivatives[..., -1] = constant
76 |
77 | outputs[outside_interval_mask] = inputs[outside_interval_mask]
78 | logabsdet[outside_interval_mask] = 0
79 | else:
80 | raise RuntimeError('{} tails are not implemented.'.format(tails))
81 |
82 | outputs[inside_interval_mask], logabsdet[inside_interval_mask] = rational_quadratic_spline(
83 | inputs=inputs[inside_interval_mask],
84 | unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
85 | unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
86 | unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :],
87 | inverse=inverse,
88 | left=-tail_bound, right=tail_bound, bottom=-tail_bound, top=tail_bound,
89 | min_bin_width=min_bin_width,
90 | min_bin_height=min_bin_height,
91 | min_derivative=min_derivative
92 | )
93 |
94 | return outputs, logabsdet
95 |
96 | def rational_quadratic_spline(inputs,
97 | unnormalized_widths,
98 | unnormalized_heights,
99 | unnormalized_derivatives,
100 | inverse=False,
101 | left=0., right=1., bottom=0., top=1.,
102 | min_bin_width=DEFAULT_MIN_BIN_WIDTH,
103 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
104 | min_derivative=DEFAULT_MIN_DERIVATIVE):
105 | if torch.min(inputs) < left or torch.max(inputs) > right:
106 | raise ValueError('Input to a transform is not within its domain')
107 |
108 | num_bins = unnormalized_widths.shape[-1]
109 |
110 | if min_bin_width * num_bins > 1.0:
111 | raise ValueError('Minimal bin width too large for the number of bins')
112 | if min_bin_height * num_bins > 1.0:
113 | raise ValueError('Minimal bin height too large for the number of bins')
114 |
115 | widths = F.softmax(unnormalized_widths, dim=-1)
116 | widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
117 | cumwidths = torch.cumsum(widths, dim=-1)
118 | cumwidths = F.pad(cumwidths, pad=(1, 0), mode='constant', value=0.0)
119 | cumwidths = (right - left) * cumwidths + left
120 | cumwidths[..., 0] = left
121 | cumwidths[..., -1] = right
122 | widths = cumwidths[..., 1:] - cumwidths[..., :-1]
123 |
124 | derivatives = min_derivative + F.softplus(unnormalized_derivatives)
125 |
126 | heights = F.softmax(unnormalized_heights, dim=-1)
127 | heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
128 | cumheights = torch.cumsum(heights, dim=-1)
129 | cumheights = F.pad(cumheights, pad=(1, 0), mode='constant', value=0.0)
130 | cumheights = (top - bottom) * cumheights + bottom
131 | cumheights[..., 0] = bottom
132 | cumheights[..., -1] = top
133 | heights = cumheights[..., 1:] - cumheights[..., :-1]
134 |
135 | if inverse:
136 | bin_idx = searchsorted(cumheights, inputs)[..., None]
137 | else:
138 | bin_idx = searchsorted(cumwidths, inputs)[..., None]
139 |
140 | input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0]
141 | input_bin_widths = widths.gather(-1, bin_idx)[..., 0]
142 |
143 | input_cumheights = cumheights.gather(-1, bin_idx)[..., 0]
144 | delta = heights / widths
145 | input_delta = delta.gather(-1, bin_idx)[..., 0]
146 |
147 | input_derivatives = derivatives.gather(-1, bin_idx)[..., 0]
148 | input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0]
149 |
150 | input_heights = heights.gather(-1, bin_idx)[..., 0]
151 |
152 | if inverse:
153 | a = (((inputs - input_cumheights) * (input_derivatives
154 | + input_derivatives_plus_one
155 | - 2 * input_delta)
156 | + input_heights * (input_delta - input_derivatives)))
157 | b = (input_heights * input_derivatives
158 | - (inputs - input_cumheights) * (input_derivatives
159 | + input_derivatives_plus_one
160 | - 2 * input_delta))
161 | c = - input_delta * (inputs - input_cumheights)
162 |
163 | discriminant = b.pow(2) - 4 * a * c
164 | assert (discriminant >= 0).all()
165 |
166 | root = (2 * c) / (-b - torch.sqrt(discriminant))
167 | outputs = root * input_bin_widths + input_cumwidths
168 |
169 | theta_one_minus_theta = root * (1 - root)
170 | denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta)
171 | * theta_one_minus_theta)
172 | derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * root.pow(2)
173 | + 2 * input_delta * theta_one_minus_theta
174 | + input_derivatives * (1 - root).pow(2))
175 | logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
176 |
177 | return outputs, -logabsdet
178 | else:
179 | theta = (inputs - input_cumwidths) / input_bin_widths
180 | theta_one_minus_theta = theta * (1 - theta)
181 |
182 | numerator = input_heights * (input_delta * theta.pow(2)
183 | + input_derivatives * theta_one_minus_theta)
184 | denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta)
185 | * theta_one_minus_theta)
186 | outputs = input_cumheights + numerator / denominator
187 |
188 | derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * theta.pow(2)
189 | + 2 * input_delta * theta_one_minus_theta
190 | + input_derivatives * (1 - theta).pow(2))
191 | logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
192 |
193 | return outputs, logabsdet
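194 |
195 |
196 | if __name__ == "__main__":
197 |     # Consistency sketch (illustrative only): the forward and inverse spline are
198 |     # mutual inverses and their log|det| terms cancel. Tensor sizes and the tail
199 |     # bound below are arbitrary; the `num_bins - 1` interior derivatives match
200 |     # what ConvFlow in stocpred_modules.py passes in.
201 |     torch.manual_seed(0)
202 |     num_bins = 10
203 |     x = torch.rand(4, 7) * 8 - 4             # values inside the tail bound
204 |     w = torch.randn(4, 7, num_bins)          # unnormalized widths
205 |     h = torch.randn(4, 7, num_bins)          # unnormalized heights
206 |     d = torch.randn(4, 7, num_bins - 1)      # unnormalized derivatives
207 |
208 |     y, ld_fwd = piecewise_rational_quadratic_transform(
209 |         x, w, h, d, inverse=False, tails='linear', tail_bound=5.0)
210 |     x_rec, ld_inv = piecewise_rational_quadratic_transform(
211 |         y, w, h, d, inverse=True, tails='linear', tail_bound=5.0)
212 |
213 |     print(torch.max(torch.abs(x - x_rec)))        # ~0
214 |     print(torch.max(torch.abs(ld_fwd + ld_inv)))  # ~0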
--------------------------------------------------------------------------------
/notebook/inference_glowtts_stdp.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# TTS Inference"
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": null,
13 | "metadata": {
14 | "vscode": {
15 | "languageId": "python"
16 | }
17 | },
18 | "outputs": [],
19 | "source": [
20 | "\n",
21 | "import torch\n",
22 | "\n",
23 | "from models.glow_tts_with_pitch import GlowTTSModel\n",
24 | "from utils.data import load_speaker_emb\n",
25 | "\n",
26 |     "from nemo.collections.tts.models import HifiGanModel\n\nimport IPython.display as ipd"
27 | ]
28 | },
29 | {
30 | "cell_type": "code",
31 | "execution_count": null,
32 | "metadata": {
33 | "vscode": {
34 | "languageId": "python"
35 | }
36 | },
37 | "outputs": [],
38 | "source": [
39 | "def infer(\n",
40 | " spec_gen_model,\n",
41 | " vocoder_model,\n",
42 | " str_input,\n",
43 | " noise_scale=0.0,\n",
44 | " length_scale=1.0,\n",
45 | " speaker=None,\n",
46 | " speaker_embeddings=None,\n",
47 | " stoch_dur_noise_scale=0.8,\n",
48 | " stoch_pitch_noise_scale=1.0,\n",
49 | " pitch_scale=0.0,\n",
50 | "):\n",
51 | "\n",
52 | " with torch.no_grad():\n",
53 | " parsed = spec_gen_model.parse(str_input)\n",
54 | "\n",
55 | " spectrogram = spec_gen_model.generate_spectrogram(\n",
56 | " tokens=parsed,\n",
57 | " noise_scale=noise_scale,\n",
58 | " length_scale=length_scale,\n",
59 | " speaker=speaker,\n",
60 | " speaker_embeddings=speaker_embeddings,\n",
61 | " stoch_dur_noise_scale=stoch_dur_noise_scale,\n",
62 | " stoch_pitch_noise_scale=stoch_pitch_noise_scale,\n",
63 | " pitch_scale=pitch_scale,\n",
64 | " )\n",
65 | "\n",
66 | " audio = vocoder_model.convert_spectrogram_to_audio(spec=spectrogram)\n",
67 | "\n",
68 | " if spectrogram is not None:\n",
69 | " if isinstance(spectrogram, torch.Tensor):\n",
70 | " spectrogram = spectrogram.to(\"cpu\").numpy()\n",
71 | " if len(spectrogram.shape) == 3:\n",
72 | " spectrogram = spectrogram[0]\n",
73 | " if isinstance(audio, torch.Tensor):\n",
74 | " audio = audio.to(\"cpu\").numpy()\n",
75 | " return spectrogram, audio"
76 | ]
77 | },
78 | {
79 | "cell_type": "code",
80 | "execution_count": null,
81 | "metadata": {
82 | "vscode": {
83 | "languageId": "python"
84 | }
85 | },
86 | "outputs": [],
87 | "source": [
88 | "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
89 | "print(device)"
90 | ]
91 | },
92 | {
93 | "cell_type": "code",
94 | "execution_count": null,
95 | "metadata": {
96 | "vscode": {
97 | "languageId": "python"
98 | }
99 | },
100 | "outputs": [],
101 | "source": [
102 |     "# load GlowTTS model from checkpoint (checkpoint_path must point to a trained .ckpt file)\n",
103 | "spec_gen = GlowTTSModel.load_from_checkpoint(checkpoint_path=checkpoint_path)\n",
104 | "spec_gen = spec_gen.eval().to(device)"
105 | ]
106 | },
107 | {
108 | "cell_type": "code",
109 | "execution_count": null,
110 | "metadata": {
111 | "vscode": {
112 | "languageId": "python"
113 | }
114 | },
115 | "outputs": [],
116 | "source": [
117 |     "# load vocoder from checkpoint (checkpoint must point to a HiFi-GAN .ckpt file)\n",
118 | "vocoder = HifiGanModel.load_from_checkpoint(checkpoint).eval().to(device)"
119 | ]
120 | },
121 | {
122 | "cell_type": "code",
123 | "execution_count": null,
124 | "metadata": {
125 | "vscode": {
126 | "languageId": "python"
127 | }
128 | },
129 | "outputs": [],
130 | "source": [
131 |     "# Load speaker embeddings for conditioning (spk_emb_path is the pickled speaker-embedding file)\n",
132 | "speaker_emb_dict = load_speaker_emb(spk_emb_path)"
133 | ]
134 | },
135 | {
136 | "cell_type": "markdown",
137 | "metadata": {
138 | "tags": []
139 | },
140 | "source": [
141 | "## Inference"
142 | ]
143 | },
144 | {
145 | "cell_type": "markdown",
146 | "metadata": {},
147 | "source": [
148 |     "Now that everything is set up, let's provide the text we want the model to speak."
149 | ]
150 | },
151 | {
152 | "cell_type": "code",
153 | "execution_count": null,
154 | "metadata": {
155 | "vscode": {
156 | "languageId": "python"
157 | }
158 | },
159 | "outputs": [],
160 | "source": [
161 | "# Extract speaker embedding from file\n",
162 | "\n",
163 | "audio_path = \"common_voice_en_18498899.wav\"\n",
164 | "audio_path_wo = audio_path.split(\".\")[0]\n",
165 | "\n",
166 |     "speaker_embeddings = speaker_emb_dict.get(audio_path_wo)\n",
167 |     "if speaker_embeddings is None:\n",
168 |     "    print(\"Could not load speaker embedding\")\n",
169 |     "else:\n",
170 |     "    speaker_embeddings = torch.from_numpy(speaker_embeddings).reshape(1, -1).to(device)"
171 | ]
172 | },
173 | {
174 | "cell_type": "markdown",
175 | "metadata": {},
176 | "source": [
177 |     "## Generate speech"
178 | ]
179 | },
180 | {
181 | "cell_type": "code",
182 | "execution_count": null,
183 | "metadata": {
184 | "vscode": {
185 | "languageId": "python"
186 | }
187 | },
188 | "outputs": [],
189 | "source": [
190 | "# Inference hyperparameters\n",
191 | "\n",
192 | "sr=16000\n",
193 | "noise_scale=0.667\n",
194 |     "length_scale=1.0\n",
195 | "stoch_dur_noise_scale=0.8 #0.0-1.0\n",
196 | "stoch_pitch_noise_scale=0.8\n",
197 | "pitch_scale=0.0\n",
198 | "speaker=None"
199 | ]
200 | },
201 | {
202 | "cell_type": "code",
203 | "execution_count": null,
204 | "metadata": {
205 | "vscode": {
206 | "languageId": "python"
207 | }
208 | },
209 | "outputs": [],
210 | "source": [
211 | "from nemo_text_processing.text_normalization.normalize import Normalizer\n",
212 | "\n",
213 | "# initialize normalizer\n",
214 | "normalizer = Normalizer(input_case=\"cased\", lang=\"en\")"
215 | ]
216 | },
217 | {
218 | "cell_type": "code",
219 | "execution_count": null,
220 | "metadata": {
221 | "vscode": {
222 | "languageId": "python"
223 | }
224 | },
225 | "outputs": [],
226 | "source": [
227 | "text_to_generate = \"A look of fear crossed his face, but he regained his serenity immediately.\"\n",
228 | "\n",
229 | "# normalize text. necessary in case text contains numeric text, dates, and abbreviations\n",
230 | "text_to_generate = normalizer.normalize(text_to_generate)\n",
231 | "print(text_to_generate)"
232 | ]
233 | },
234 | {
235 | "cell_type": "code",
236 | "execution_count": null,
237 | "metadata": {
238 | "vscode": {
239 | "languageId": "python"
240 | }
241 | },
242 | "outputs": [],
243 | "source": [
244 | "\n",
245 | "log_spec, audio = infer(spec_gen, vocoder, text_to_generate, \n",
246 | " noise_scale=noise_scale,\n",
247 | " length_scale=length_scale,\n",
248 | " speaker=speaker,\n",
249 | " stoch_dur_noise_scale=stoch_dur_noise_scale,\n",
250 | " stoch_pitch_noise_scale=stoch_pitch_noise_scale,\n",
251 | " pitch_scale=pitch_scale,\n",
252 | " speaker_embeddings=speaker_embeddings,)\n"
253 | ]
254 | },
255 | {
256 | "cell_type": "code",
257 | "execution_count": null,
258 | "metadata": {
259 | "vscode": {
260 | "languageId": "python"
261 | }
262 | },
263 | "outputs": [],
264 | "source": [
265 | "ipd.Audio(audio, rate=sr)"
266 | ]
267 | },
268 | {
269 | "cell_type": "code",
270 | "execution_count": null,
271 | "metadata": {
272 | "vscode": {
273 | "languageId": "python"
274 | }
275 | },
276 | "outputs": [],
277 | "source": []
278 | }
279 | ],
280 | "metadata": {
281 | "kernelspec": {
282 | "display_name": "Python 3 (ipykernel)",
283 | "language": "python",
284 | "name": "python3"
285 | },
286 | "vscode": {
287 | "interpreter": {
288 | "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
289 | }
290 | }
291 | },
292 | "nbformat": 4,
293 | "nbformat_minor": 4
294 | }
295 |
--------------------------------------------------------------------------------
/train_glowtts_baseline.sh:
--------------------------------------------------------------------------------
1 |
2 | python glow_tts_with_pitch.py \
3 | --config-path="conf" \
4 | --config-name="glow_tts_baseline" \
5 | train_dataset="dataset/train_tts_common_voice.json" \
6 | validation_datasets="dataset/dev_tts_common_voice.json" \
7 | speaker_emb_path="../tts_experiments/embeddings/spk_emb_exp_1plus.pkl" \
8 | sup_data_path="../cv-corpus-7.0-2021-07-21/en/cv_mos4_all_sup_data_folder"
--------------------------------------------------------------------------------
/train_glowtts_std.sh:
--------------------------------------------------------------------------------
1 |
2 | python glow_tts_with_pitch.py \
3 | --config-path="conf" \
4 | --config-name="glow_tts_std" \
5 | train_dataset="dataset/train_tts_common_voice.json" \
6 | validation_datasets="dataset/dev_tts_common_voice.json" \
7 | speaker_emb_path="../tts_experiments/embeddings/spk_emb_exp_1plus.pkl" \
8 | sup_data_path="../cv-corpus-7.0-2021-07-21/en/cv_mos4_all_sup_data_folder"
--------------------------------------------------------------------------------
/train_glowtts_stdp.sh:
--------------------------------------------------------------------------------
1 |
2 | python glow_tts_with_pitch.py \
3 | --config-path="conf" \
4 | --config-name="glow_tts_stdp" \
5 | train_dataset="dataset/train_tts_common_voice.json" \
6 | validation_datasets="dataset/dev_tts_common_voice.json" \
7 | speaker_emb_path="../tts_experiments/embeddings/spk_emb_exp_1plus.pkl" \
8 | sup_data_path="../cv-corpus-7.0-2021-07-21/en/cv_mos4_all_sup_data_folder"
9 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ogunlao/glowtts_stdp/07f71bbfce405018f43db7c086f46c6b91defa28/utils/__init__.py
--------------------------------------------------------------------------------
/utils/glow_tts_loss.py:
--------------------------------------------------------------------------------
1 | import pickle as pkl
2 | import json
3 |
4 | from nemo.utils import logging
5 |
6 |
7 | import math
8 |
9 | import torch
10 |
11 | from nemo.core.classes import Loss, typecheck
12 | from nemo.core.neural_types.elements import *
13 | from nemo.core.neural_types.neural_type import NeuralType
14 |
15 |
16 | class GlowTTSLoss(Loss):
17 | """
18 | Loss for the GlowTTS model
19 | """
20 |
21 | @property
22 | def input_types(self):
23 | return {
24 | "z": NeuralType(('B', 'D', 'T'), NormalDistributionSamplesType()),
25 | "y_m": NeuralType(('B', 'D', 'T'), NormalDistributionMeanType()),
26 | "y_logs": NeuralType(('B', 'D', 'T'), NormalDistributionLogVarianceType()),
27 | "logdet": NeuralType(('B',), LogDeterminantType()),
28 | "logw": NeuralType(('B', 'T'), TokenLogDurationType()),
29 | "logw_": NeuralType(('B', 'T'), TokenLogDurationType()),
30 | "x_lengths": NeuralType(('B',), LengthsType()),
31 | "y_lengths": NeuralType(('B',), LengthsType()),
32 | "stoch_dur_loss": NeuralType(optional=True),
33 | }
34 |
35 | @property
36 | def output_types(self):
37 | return {
38 | "l_mle": NeuralType(elements_type=LossType()),
39 | "l_length": NeuralType(elements_type=LossType()),
40 | "logdet": NeuralType(elements_type=VoidType()),
41 | }
42 |
43 | @typecheck()
44 | def forward(self, z, y_m, y_logs, logdet, logw, logw_, x_lengths, y_lengths, stoch_dur_loss,):
45 |
46 | logdet = torch.sum(logdet)
47 | l_mle = 0.5 * math.log(2 * math.pi) + (
48 | torch.sum(y_logs) + 0.5 * torch.sum(torch.exp(-2 * y_logs) * (z - y_m) ** 2) - logdet
49 | ) / (torch.sum(y_lengths) * z.shape[1])
50 |
51 | if stoch_dur_loss is None:
52 | l_length = torch.sum((logw - logw_) ** 2) / torch.sum(x_lengths)
53 | else:
54 | l_length = stoch_dur_loss
55 | return l_mle, l_length, logdet
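56 |
57 |
58 | # Reference note (restating the forward pass above): l_mle is the average negative
59 | # log-likelihood of the latent z under the text-conditioned Gaussian N(y_m, exp(y_logs)^2),
60 | # corrected by the decoder's log-determinant and normalised per valid frame and channel:
61 | #
62 | #   l_mle = 0.5 * log(2 * pi)
63 | #           + [ sum(y_logs) + 0.5 * sum(exp(-2 * y_logs) * (z - y_m) ** 2) - sum(logdet) ]
64 | #             / (sum(y_lengths) * n_mel_channels)
65 | #
66 | # l_length is the MSE between predicted and target log-durations when the deterministic
67 | # duration predictor is used, and the precomputed duration NLL (stoch_dur_loss) otherwise.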
--------------------------------------------------------------------------------
/utils/helpers.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # BSD 3-Clause License
16 | #
17 | # Copyright (c) 2021, NVIDIA Corporation
18 | # All rights reserved.
19 | #
20 | # Redistribution and use in source and binary forms, with or without
21 | # modification, are permitted provided that the following conditions are met:
22 | #
23 | # * Redistributions of source code must retain the above copyright notice, this
24 | # list of conditions and the following disclaimer.
25 | #
26 | # * Redistributions in binary form must reproduce the above copyright notice,
27 | # this list of conditions and the following disclaimer in the documentation
28 | # and/or other materials provided with the distribution.
29 | #
30 | # * Neither the name of the copyright holder nor the names of its
31 | # contributors may be used to endorse or promote products derived from
32 | # this software without specific prior written permission.
33 | #
34 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
35 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
36 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
37 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
38 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
39 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
40 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
41 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
42 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
43 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
44 |
45 | from enum import Enum
46 | from typing import Dict, Optional, Sequence
47 |
48 | import pickle as pkl
49 | import json
50 |
51 | import librosa
52 | import matplotlib.pylab as plt
53 | import numpy as np
54 | import torch
55 | from numba import jit, prange
56 | from numpy import ndarray
57 | from pesq import pesq
58 | from pystoi import stoi
59 |
60 | from nemo.collections.tts.torch.tts_data_types import DATA_STR2DATA_CLASS, MAIN_DATA_TYPES, WithLens
61 | from nemo.utils import logging
62 |
63 | HAVE_WANDB = True
64 | try:
65 | import wandb
66 | except ModuleNotFoundError:
67 | HAVE_WANDB = False
68 |
69 | try:
70 | from pytorch_lightning.utilities import rank_zero_only
71 | except ModuleNotFoundError:
72 | from functools import wraps
73 |
74 | def rank_zero_only(fn):
75 | @wraps(fn)
76 | def wrapped_fn(*args, **kwargs):
77 | logging.error(
78 |                 f"Function {fn} requires lightning to be installed, but it was not found. Please install lightning first"
79 | )
80 | exit(1)
81 |         return wrapped_fn
82 |
83 | class OperationMode(Enum):
84 | """Training or Inference (Evaluation) mode"""
85 |
86 | training = 0
87 | validation = 1
88 | infer = 2
89 |
90 |
91 | def get_batch_size(train_dataloader):
92 | if train_dataloader.batch_size is not None:
93 | return train_dataloader.batch_size
94 | elif train_dataloader.batch_sampler is not None:
95 | if train_dataloader.batch_sampler.micro_batch_size is not None:
96 | return train_dataloader.batch_sampler.micro_batch_size
97 | else:
98 | raise ValueError(f'Could not find batch_size from batch_sampler: {train_dataloader.batch_sampler}')
99 | else:
100 | raise ValueError(f'Could not find batch_size from train_dataloader: {train_dataloader}')
101 |
102 |
103 | def get_num_workers(trainer):
104 | return trainer.num_devices * trainer.num_nodes
105 |
106 |
107 | def binarize_attention(attn, in_len, out_len):
108 | """Convert soft attention matrix to hard attention matrix.
109 |
110 | Args:
111 | attn (torch.Tensor): B x 1 x max_mel_len x max_text_len. Soft attention matrix.
112 | in_len (torch.Tensor): B. Lengths of texts.
113 | out_len (torch.Tensor): B. Lengths of spectrograms.
114 |
115 | Output:
116 | attn_out (torch.Tensor): B x 1 x max_mel_len x max_text_len. Hard attention matrix, final dim max_text_len should sum to 1.
117 | """
118 | b_size = attn.shape[0]
119 | with torch.no_grad():
120 | attn_cpu = attn.data.cpu().numpy()
121 | attn_out = torch.zeros_like(attn)
122 | for ind in range(b_size):
123 | hard_attn = mas(attn_cpu[ind, 0, : out_len[ind], : in_len[ind]])
124 | attn_out[ind, 0, : out_len[ind], : in_len[ind]] = torch.tensor(hard_attn, device=attn.device)
125 | return attn_out
126 |
127 |
128 | def binarize_attention_parallel(attn, in_lens, out_lens):
129 | """For training purposes only. Binarizes attention with MAS.
130 | These will no longer receive a gradient.
131 |
132 | Args:
133 | attn: B x 1 x max_mel_len x max_text_len
134 | """
135 | with torch.no_grad():
136 | attn_cpu = attn.data.cpu().numpy()
137 | attn_out = b_mas(attn_cpu, in_lens.cpu().numpy(), out_lens.cpu().numpy(), width=1)
138 | return torch.from_numpy(attn_out).to(attn.device)
139 |
140 |
141 | def get_mask_from_lengths(lengths, max_len: Optional[int] = None):
142 | if max_len is None:
143 | max_len = lengths.max()
144 | ids = torch.arange(0, max_len, device=lengths.device, dtype=torch.long)
145 | mask = (ids < lengths.unsqueeze(1)).bool()
146 | return mask
147 |
148 |
149 | @jit(nopython=True)
150 | def mas(attn_map, width=1):
151 | # assumes mel x text
152 | opt = np.zeros_like(attn_map)
153 | attn_map = np.log(attn_map)
154 | attn_map[0, 1:] = -np.inf
155 | log_p = np.zeros_like(attn_map)
156 | log_p[0, :] = attn_map[0, :]
157 | prev_ind = np.zeros_like(attn_map, dtype=np.int64)
158 | for i in range(1, attn_map.shape[0]):
159 | for j in range(attn_map.shape[1]): # for each text dim
160 | prev_j = np.arange(max(0, j - width), j + 1)
161 | prev_log = np.array([log_p[i - 1, prev_idx] for prev_idx in prev_j])
162 |
163 | ind = np.argmax(prev_log)
164 | log_p[i, j] = attn_map[i, j] + prev_log[ind]
165 | prev_ind[i, j] = prev_j[ind]
166 |
167 | # now backtrack
168 | curr_text_idx = attn_map.shape[1] - 1
169 | for i in range(attn_map.shape[0] - 1, -1, -1):
170 | opt[i, curr_text_idx] = 1
171 | curr_text_idx = prev_ind[i, curr_text_idx]
172 | opt[0, curr_text_idx] = 1
173 |
174 | assert opt.sum(0).all()
175 | assert opt.sum(1).all()
176 |
177 | return opt
178 |
179 |
180 | @jit(nopython=True)
181 | def mas_width1(attn_map):
182 | """mas with hardcoded width=1"""
183 | # assumes mel x text
184 | opt = np.zeros_like(attn_map)
185 | attn_map = np.log(attn_map)
186 | attn_map[0, 1:] = -np.inf
187 | log_p = np.zeros_like(attn_map)
188 | log_p[0, :] = attn_map[0, :]
189 | prev_ind = np.zeros_like(attn_map, dtype=np.int64)
190 | for i in range(1, attn_map.shape[0]):
191 | for j in range(attn_map.shape[1]): # for each text dim
192 | prev_log = log_p[i - 1, j]
193 | prev_j = j
194 |
195 | if j - 1 >= 0 and log_p[i - 1, j - 1] >= log_p[i - 1, j]:
196 | prev_log = log_p[i - 1, j - 1]
197 | prev_j = j - 1
198 |
199 | log_p[i, j] = attn_map[i, j] + prev_log
200 | prev_ind[i, j] = prev_j
201 |
202 | # now backtrack
203 | curr_text_idx = attn_map.shape[1] - 1
204 | for i in range(attn_map.shape[0] - 1, -1, -1):
205 | opt[i, curr_text_idx] = 1
206 | curr_text_idx = prev_ind[i, curr_text_idx]
207 | opt[0, curr_text_idx] = 1
208 | return opt
209 |
210 |
211 | def b_mas(b_attn_map, in_lens, out_lens, width=1):
212 | assert width == 1
213 | attn_out = np.zeros_like(b_attn_map)
214 |
215 | for b in range(b_attn_map.shape[0]):
216 | out = mas_width1(b_attn_map[b, 0, : out_lens[b], : in_lens[b]])
217 | attn_out[b, 0, : out_lens[b], : in_lens[b]] = out
218 | return attn_out
219 |
220 | # @jit(nopython=True, parallel=True)
221 | # def b_mas(b_attn_map, in_lens, out_lens, width=1):
222 | # assert width == 1
223 | # attn_out = np.zeros_like(b_attn_map)
224 |
225 | # for b in prange(b_attn_map.shape[0]):
226 | # out = mas_width1(b_attn_map[b, 0, : out_lens[b], : in_lens[b]])
227 | # attn_out[b, 0, : out_lens[b], : in_lens[b]] = out
228 | # return attn_out
229 |
230 |
231 | def griffin_lim(magnitudes, n_iters=50, n_fft=1024):
232 | """
233 | Griffin-Lim algorithm to convert magnitude spectrograms to audio signals
234 | """
235 | phase = np.exp(2j * np.pi * np.random.rand(*magnitudes.shape))
236 | complex_spec = magnitudes * phase
237 | signal = librosa.istft(complex_spec)
238 | if not np.isfinite(signal).all():
239 | logging.warning("audio was not finite, skipping audio saving")
240 | return np.array([0])
241 |
242 | for _ in range(n_iters):
243 | _, phase = librosa.magphase(librosa.stft(signal, n_fft=n_fft))
244 | complex_spec = magnitudes * phase
245 | signal = librosa.istft(complex_spec)
246 | return signal
247 |
248 |
249 | @rank_zero_only
250 | def log_audio_to_tb(
251 | swriter,
252 | spect,
253 | name,
254 | step,
255 | griffin_lim_mag_scale=1024,
256 | griffin_lim_power=1.2,
257 | sr=22050,
258 | n_fft=1024,
259 | n_mels=80,
260 | fmax=8000,
261 | ):
262 | filterbank = librosa.filters.mel(sr=sr, n_fft=n_fft, n_mels=n_mels, fmax=fmax)
263 | log_mel = spect.data.cpu().numpy().T
264 | mel = np.exp(log_mel)
265 | magnitude = np.dot(mel, filterbank) * griffin_lim_mag_scale
266 | audio = griffin_lim(magnitude.T ** griffin_lim_power)
267 | swriter.add_audio(name, audio / max(np.abs(audio)), step, sample_rate=sr)
268 |
269 |
270 | @rank_zero_only
271 | def tacotron2_log_to_tb_func(
272 | swriter,
273 | tensors,
274 | step,
275 | tag="train",
276 | log_images=False,
277 | log_images_freq=1,
278 | add_audio=True,
279 | griffin_lim_mag_scale=1024,
280 | griffin_lim_power=1.2,
281 | sr=22050,
282 | n_fft=1024,
283 | n_mels=80,
284 | fmax=8000,
285 | ):
286 | _, spec_target, mel_postnet, gate, gate_target, alignments = tensors
287 | if log_images and step % log_images_freq == 0:
288 | swriter.add_image(
289 | f"{tag}_alignment", plot_alignment_to_numpy(alignments[0].data.cpu().numpy().T), step, dataformats="HWC",
290 | )
291 | swriter.add_image(
292 | f"{tag}_mel_target", plot_spectrogram_to_numpy(spec_target[0].data.cpu().numpy()), step, dataformats="HWC",
293 | )
294 | swriter.add_image(
295 | f"{tag}_mel_predicted",
296 | plot_spectrogram_to_numpy(mel_postnet[0].data.cpu().numpy()),
297 | step,
298 | dataformats="HWC",
299 | )
300 | swriter.add_image(
301 | f"{tag}_gate",
302 | plot_gate_outputs_to_numpy(gate_target[0].data.cpu().numpy(), torch.sigmoid(gate[0]).data.cpu().numpy(),),
303 | step,
304 | dataformats="HWC",
305 | )
306 |
307 | if add_audio:
308 | filterbank = librosa.filters.mel(sr=sr, n_fft=n_fft, n_mels=n_mels, fmax=fmax)
309 | log_mel = mel_postnet[0].data.cpu().numpy().T
310 | mel = np.exp(log_mel)
311 | magnitude = np.dot(mel, filterbank) * griffin_lim_mag_scale
312 | audio = griffin_lim(magnitude.T ** griffin_lim_power)
313 | swriter.add_audio(f"audio/{tag}_predicted", audio / max(np.abs(audio)), step, sample_rate=sr)
314 |
315 | log_mel = spec_target[0].data.cpu().numpy().T
316 | mel = np.exp(log_mel)
317 | magnitude = np.dot(mel, filterbank) * griffin_lim_mag_scale
318 | audio = griffin_lim(magnitude.T ** griffin_lim_power)
319 | swriter.add_audio(f"audio/{tag}_target", audio / max(np.abs(audio)), step, sample_rate=sr)
320 |
321 |
322 | def tacotron2_log_to_wandb_func(
323 | swriter,
324 | tensors,
325 | step,
326 | tag="train",
327 | log_images=False,
328 | log_images_freq=1,
329 | add_audio=True,
330 | griffin_lim_mag_scale=1024,
331 | griffin_lim_power=1.2,
332 | sr=22050,
333 | n_fft=1024,
334 | n_mels=80,
335 | fmax=8000,
336 | ):
337 | _, spec_target, mel_postnet, gate, gate_target, alignments = tensors
338 | if not HAVE_WANDB:
339 | return
340 | if log_images and step % log_images_freq == 0:
341 |         align_imgs = []
342 |         specs = []
343 |         gates = []
344 |         align_imgs += [
345 |             wandb.Image(plot_alignment_to_numpy(alignments[0].data.cpu().numpy().T), caption=f"{tag}_alignment",)
346 |         ]
347 |         specs += [
348 |             wandb.Image(plot_spectrogram_to_numpy(spec_target[0].data.cpu().numpy()), caption=f"{tag}_mel_target",),
349 |             wandb.Image(plot_spectrogram_to_numpy(mel_postnet[0].data.cpu().numpy()), caption=f"{tag}_mel_predicted",),
350 |         ]
351 | gates += [
352 | wandb.Image(
353 | plot_gate_outputs_to_numpy(
354 | gate_target[0].data.cpu().numpy(), torch.sigmoid(gate[0]).data.cpu().numpy(),
355 | ),
356 | caption=f"{tag}_gate",
357 | )
358 | ]
359 |
360 |         swriter.log({"specs": specs, "alignments": align_imgs, "gates": gates})
361 |
362 | if add_audio:
363 | audios = []
364 | filterbank = librosa.filters.mel(sr=sr, n_fft=n_fft, n_mels=n_mels, fmax=fmax)
365 | log_mel = mel_postnet[0].data.cpu().numpy().T
366 | mel = np.exp(log_mel)
367 | magnitude = np.dot(mel, filterbank) * griffin_lim_mag_scale
368 | audio_pred = griffin_lim(magnitude.T ** griffin_lim_power)
369 |
370 | log_mel = spec_target[0].data.cpu().numpy().T
371 | mel = np.exp(log_mel)
372 | magnitude = np.dot(mel, filterbank) * griffin_lim_mag_scale
373 | audio_true = griffin_lim(magnitude.T ** griffin_lim_power)
374 |
375 | audios += [
376 | wandb.Audio(audio_true / max(np.abs(audio_true)), caption=f"{tag}_wav_target", sample_rate=sr,),
377 | wandb.Audio(audio_pred / max(np.abs(audio_pred)), caption=f"{tag}_wav_predicted", sample_rate=sr,),
378 | ]
379 |
380 | swriter.log({"audios": audios})
381 |
382 |
383 | def plot_alignment_to_numpy(alignment, title='', info=None, phoneme_seq=None, vmin=None, vmax=None):
384 | if phoneme_seq:
385 | fig, ax = plt.subplots(figsize=(15, 10))
386 | else:
387 | fig, ax = plt.subplots(figsize=(6, 4))
388 | im = ax.imshow(alignment, aspect='auto', origin='lower', interpolation='none', vmin=vmin, vmax=vmax)
389 | ax.set_title(title)
390 | fig.colorbar(im, ax=ax)
391 | xlabel = 'Decoder timestep'
392 | if info is not None:
393 | xlabel += '\n\n' + info
394 | plt.xlabel(xlabel)
395 | plt.ylabel('Encoder timestep')
396 | plt.tight_layout()
397 |
398 |     if phoneme_seq is not None:
399 |         # For debugging phonemes and durations in maps. Not used by default in training code.
400 | ax.set_yticks(np.arange(len(phoneme_seq)))
401 | ax.set_yticklabels(phoneme_seq)
402 | ax.hlines(np.arange(len(phoneme_seq)), xmin=0.0, xmax=max(ax.get_xticks()))
403 |
404 | fig.canvas.draw()
405 | data = save_figure_to_numpy(fig)
406 | plt.close()
407 | return data
408 |
409 |
410 | def plot_pitch_to_numpy(pitch, ylim_range=None):
411 | fig, ax = plt.subplots(figsize=(12, 3))
412 | plt.plot(pitch)
413 | if ylim_range is not None:
414 | plt.ylim(ylim_range)
415 | plt.xlabel("Frames")
416 | plt.ylabel("Pitch")
417 | plt.tight_layout()
418 |
419 | fig.canvas.draw()
420 | data = save_figure_to_numpy(fig)
421 | plt.close()
422 | return data
423 |
424 |
425 | def plot_spectrogram_to_numpy(spectrogram):
426 | spectrogram = spectrogram.astype(np.float32)
427 | fig, ax = plt.subplots(figsize=(12, 3))
428 | im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation='none')
429 | plt.colorbar(im, ax=ax)
430 | plt.xlabel("Frames")
431 | plt.ylabel("Channels")
432 | plt.tight_layout()
433 |
434 | fig.canvas.draw()
435 | data = save_figure_to_numpy(fig)
436 | plt.close()
437 | return data
438 |
439 |
440 | def plot_gate_outputs_to_numpy(gate_targets, gate_outputs):
441 | fig, ax = plt.subplots(figsize=(12, 3))
442 | ax.scatter(
443 | range(len(gate_targets)), gate_targets, alpha=0.5, color='green', marker='+', s=1, label='target',
444 | )
445 | ax.scatter(
446 | range(len(gate_outputs)), gate_outputs, alpha=0.5, color='red', marker='.', s=1, label='predicted',
447 | )
448 |
449 | plt.xlabel("Frames (Green target, Red predicted)")
450 | plt.ylabel("Gate State")
451 | plt.tight_layout()
452 |
453 | fig.canvas.draw()
454 | data = save_figure_to_numpy(fig)
455 | plt.close()
456 | return data
457 |
458 |
459 | def save_figure_to_numpy(fig):
460 |     # Save the rendered figure canvas to a numpy array (np.fromstring is deprecated).
461 |     data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
462 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
463 | return data
464 |
465 |
466 | @rank_zero_only
467 | def waveglow_log_to_tb_func(
468 | swriter, tensors, step, tag="train", n_fft=1024, hop_length=256, window="hann", mel_fb=None,
469 | ):
470 | _, audio_pred, spec_target, mel_length = tensors
471 | mel_length = mel_length[0]
472 | spec_target = spec_target[0].data.cpu().numpy()[:, :mel_length]
473 | swriter.add_image(
474 | f"{tag}_mel_target", plot_spectrogram_to_numpy(spec_target), step, dataformats="HWC",
475 | )
476 | if mel_fb is not None:
477 | mag, _ = librosa.core.magphase(
478 | librosa.core.stft(
479 | np.nan_to_num(audio_pred[0].cpu().detach().numpy()), n_fft=n_fft, hop_length=hop_length, window=window,
480 | )
481 | )
482 | mel_pred = np.matmul(mel_fb.cpu().numpy(), mag).squeeze()
483 | log_mel_pred = np.log(np.clip(mel_pred, a_min=1e-5, a_max=None))
484 | swriter.add_image(
485 | f"{tag}_mel_predicted", plot_spectrogram_to_numpy(log_mel_pred[:, :mel_length]), step, dataformats="HWC",
486 | )
487 |
488 |
489 | def remove(conv_list):
490 | new_conv_list = torch.nn.ModuleList()
491 | for old_conv in conv_list:
492 | old_conv = torch.nn.utils.remove_weight_norm(old_conv)
493 | new_conv_list.append(old_conv)
494 | return new_conv_list
495 |
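# Illustrative usage sketch for `remove` above: strip weight normalization from a
# stack of weight-normed convolutions, as is typically done before export/inference.
# The layer sizes below are made up for demonstration and are not used elsewhere here.
def _remove_weight_norm_example():
    convs = torch.nn.ModuleList(
        [torch.nn.utils.weight_norm(torch.nn.Conv1d(8, 8, 3, padding=1)) for _ in range(2)]
    )
    return remove(convs)  # same modules, with the weight_norm reparameterization removed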
496 |
497 | def eval_tts_scores(
498 | y_clean: ndarray, y_est: ndarray, T_ys: Sequence[int] = (0,), sampling_rate=22050
499 | ) -> Dict[str, float]:
500 | """
501 |     Calculate STOI and PESQ scores for estimated audio against clean reference audio. y can be a batch.
502 |     Args:
503 |         y_clean: real audio
504 |         y_est: estimated audio
505 |         T_ys: lengths of the non-zero parts of the signals
506 |         sampling_rate: the sampling rate of the audio.
507 |
508 | Returns:
509 | A dictionary mapping scoring systems (string) to numerical scores.
510 | 1st entry: 'STOI'
511 | 2nd entry: 'PESQ'
512 | """
513 |
514 | if y_clean.ndim == 1:
515 | y_clean = y_clean[np.newaxis, ...]
516 | y_est = y_est[np.newaxis, ...]
517 | if T_ys == (0,):
518 | T_ys = (y_clean.shape[1],) * y_clean.shape[0]
519 |
520 | clean = y_clean[0, : T_ys[0]]
521 | estimated = y_est[0, : T_ys[0]]
522 | stoi_score = stoi(clean, estimated, sampling_rate, extended=False)
523 | pesq_score = pesq(16000, np.asarray(clean), estimated, 'wb')
524 |     # fs is fixed at 16,000 Hz because the pesq library doesn't currently support flexible sampling rates.
525 |
526 | return {'STOI': stoi_score, 'PESQ': pesq_score}
527 |
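# Illustrative sketch of eval_tts_scores: score a synthetic one-second tone against a
# noisy copy of itself. In practice the inputs are vocoder output vs. ground-truth
# audio; the signal below is made up purely for demonstration.
def _eval_tts_scores_example():
    sr = 22050
    t = np.linspace(0.0, 1.0, sr, endpoint=False)
    clean = np.sin(2 * np.pi * 220.0 * t).astype(np.float32)
    noisy = (clean + 0.01 * np.random.randn(sr)).astype(np.float32)
    return eval_tts_scores(clean, noisy, sampling_rate=sr)  # {'STOI': ..., 'PESQ': ...}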
528 |
529 | def regulate_len(durations, enc_out, pace=1.0, mel_max_len=None):
530 | """A function that takes predicted durations per encoded token, and repeats enc_out according to the duration.
531 | NOTE: durations.shape[1] == enc_out.shape[1]
532 |
533 | Args:
534 | durations (torch.tensor): A tensor of shape (batch x enc_length) that represents how many times to repeat each
535 | token in enc_out.
536 | enc_out (torch.tensor): A tensor of shape (batch x enc_length x enc_hidden) that represents the encoded tokens.
537 | pace (float): The pace of speaker. Higher values result in faster speaking pace. Defaults to 1.0.
538 |         mel_max_len (int): The maximum output length. If sum(durations, dim=1) > mel_max_len, frames beyond
539 |             mel_max_len are truncated. Defaults to None, which applies no maximum length.
540 | """
541 |
542 | dtype = enc_out.dtype
543 | reps = durations.float() / pace
544 | reps = (reps + 0.5).floor().long()
545 | dec_lens = reps.sum(dim=1)
546 |
547 | max_len = dec_lens.max()
548 | reps_cumsum = torch.cumsum(torch.nn.functional.pad(reps, (1, 0, 0, 0), value=0.0), dim=1)[:, None, :]
549 | reps_cumsum = reps_cumsum.to(dtype=dtype, device=enc_out.device)
550 |
551 | range_ = torch.arange(max_len).to(enc_out.device)[None, :, None]
552 | mult = (reps_cumsum[:, :, :-1] <= range_) & (reps_cumsum[:, :, 1:] > range_)
553 | mult = mult.to(dtype)
554 | enc_rep = torch.matmul(mult, enc_out)
555 |
556 | if mel_max_len is not None:
557 | enc_rep = enc_rep[:, :mel_max_len]
558 | dec_lens = torch.clamp_max(dec_lens, mel_max_len)
559 |
560 | return enc_rep, dec_lens
561 |
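# Illustrative sketch of regulate_len: with durations [[1, 2, 3]] each encoder step is
# repeated that many times, so a (1, 3, H) encoder output becomes a (1, 6, H)
# decoder-rate sequence and dec_lens == [6]. The tensors are made up for demonstration.
def _regulate_len_example():
    enc_out = torch.randn(1, 3, 8)         # (batch, enc_length, enc_hidden)
    durations = torch.tensor([[1, 2, 3]])  # repeat counts per encoder token
    enc_rep, dec_lens = regulate_len(durations, enc_out)
    assert enc_rep.shape == (1, 6, 8) and int(dec_lens[0]) == 6
    return enc_rep, dec_lens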
562 |
563 | def split_view(tensor, split_size: int, dim: int = 0):
564 | if dim < 0: # Support negative indexing
565 | dim = len(tensor.shape) + dim
566 | # If not divisible by split_size, we need to pad with 0
567 | if tensor.shape[dim] % split_size != 0:
568 | to_pad = split_size - (tensor.shape[dim] % split_size)
569 | padding = [0] * len(tensor.shape) * 2
570 | padding[dim * 2 + 1] = to_pad
571 | padding.reverse()
572 | tensor = torch.nn.functional.pad(tensor, padding)
573 | cur_shape = tensor.shape
574 | new_shape = cur_shape[:dim] + (tensor.shape[dim] // split_size, split_size) + cur_shape[dim + 1 :]
575 | return tensor.reshape(*new_shape)
576 |
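# Illustrative sketch of split_view: a length-10 tensor split into chunks of 4 is
# zero-padded to length 12 and viewed as (3, 4). Values are for demonstration only.
def _split_view_example():
    x = torch.arange(10)
    chunks = split_view(x, split_size=4, dim=0)
    assert chunks.shape == (3, 4)
    return chunks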
577 |
578 | def process_batch(batch_data, sup_data_types_set):
579 | batch_dict = {}
580 | batch_index = 0
581 | for name, datatype in DATA_STR2DATA_CLASS.items():
582 | if datatype in MAIN_DATA_TYPES or datatype in sup_data_types_set:
583 | batch_dict[name] = batch_data[batch_index]
584 | batch_index = batch_index + 1
585 | if issubclass(datatype, WithLens):
586 | batch_dict[name + "_lens"] = batch_data[batch_index]
587 | batch_index = batch_index + 1
588 | return batch_dict
589 |
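# Illustrative sketch of process_batch: with log_mel as the only supplementary type,
# a collated batch ordered (audio, audio_lens, text, text_lens, log_mel, log_mel_lens)
# is unpacked into a name -> tensor dict. Shapes below are made up for demonstration.
def _process_batch_example():
    batch = [
        torch.zeros(2, 16000), torch.tensor([16000, 12000]),          # audio, audio_lens
        torch.zeros(2, 10, dtype=torch.long), torch.tensor([10, 8]),  # text, text_lens
        torch.zeros(2, 80, 50), torch.tensor([50, 40]),                # log_mel, log_mel_lens
    ]
    return process_batch(batch, sup_data_types_set={DATA_STR2DATA_CLASS["log_mel"]})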
590 |
--------------------------------------------------------------------------------
/utils/tts_data_types.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | class TTSDataType:
17 | """Represent TTSDataType."""
18 |
19 | name = None
20 |
21 |
22 | class WithLens:
23 | """Represent that this data type also returns lengths for data."""
24 |
25 |
26 | class Audio(TTSDataType, WithLens):
27 | name = "audio"
28 |
29 |
30 | class Text(TTSDataType, WithLens):
31 | name = "text"
32 |
33 |
34 | class LogMel(TTSDataType, WithLens):
35 | name = "log_mel"
36 |
37 |
38 | class Durations(TTSDataType):
39 | name = "durations"
40 |
41 |
42 | class AlignPriorMatrix(TTSDataType):
43 | name = "align_prior_matrix"
44 |
45 |
46 | class Pitch(TTSDataType, WithLens):
47 | name = "pitch"
48 |
49 |
50 | class Energy(TTSDataType, WithLens):
51 | name = "energy"
52 |
53 |
54 | class SpeakerID(TTSDataType):
55 | name = "speaker_id"
56 |
57 |
58 | class SpeakerEmb(TTSDataType):
59 | name = "speaker_emb"
60 |
61 |
62 | class Voiced_mask(TTSDataType):
63 | name = "voiced_mask"
64 |
65 |
66 | class P_voiced(TTSDataType):
67 | name = "p_voiced"
68 |
69 |
70 | class LMTokens(TTSDataType):
71 | name = "lm_tokens"
72 |
73 |
74 | MAIN_DATA_TYPES = [Audio, Text]
75 | VALID_SUPPLEMENTARY_DATA_TYPES = [
76 | LogMel,
77 | Durations,
78 | AlignPriorMatrix,
79 | Pitch,
80 | Energy,
81 | SpeakerID,
82 | SpeakerEmb,
83 | LMTokens,
84 | Voiced_mask,
85 | P_voiced,
86 | ]
87 | DATA_STR2DATA_CLASS = {d.name: d for d in MAIN_DATA_TYPES + VALID_SUPPLEMENTARY_DATA_TYPES}
88 |
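# Illustrative sketch: DATA_STR2DATA_CLASS resolves config strings to data-type classes,
# and WithLens marks types that also carry per-item lengths in a batch.
def _lookup_example():
    cls = DATA_STR2DATA_CLASS["pitch"]
    return cls is Pitch, issubclass(cls, WithLens)  # (True, True)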
--------------------------------------------------------------------------------
/utils/tts_tokenizers.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import itertools
17 | import string
18 | from abc import ABC, abstractmethod
19 | from contextlib import contextmanager
20 | from typing import List, Optional
21 |
22 | from nemo_text_processing.g2p.data.data_utils import english_text_preprocessing
23 |
24 | from nemo.utils import logging
25 |
26 |
27 | class BaseTokenizer(ABC):
28 |     PAD, BLANK, OOV = '<pad>', '<blank>', '<oov>'
29 |
30 | def __init__(self, tokens, *, pad=PAD, blank=BLANK, oov=OOV, sep='', add_blank_at=None, ):
31 | """Abstract class for creating an arbitrary tokenizer to convert string to list of int tokens.
32 | Args:
33 | tokens: List of tokens.
34 | pad: Pad token as string.
35 | blank: Blank token as string.
36 | oov: OOV token as string.
37 | sep: Separation token as string.
38 |             add_blank_at: Where to add the blank token: "last" places it at the very end of the vocabulary, any
39 |                 other non-None value places it before the OOV token; if None, no blank token is added.
40 | """
41 | super().__init__()
42 |
43 | tokens = list(tokens)
44 | self.pad, tokens = len(tokens), tokens + [pad] # Padding
45 |
46 | if add_blank_at is not None:
47 | self.blank, tokens = len(tokens), tokens + [blank] # Reserved for blank from asr-model
48 | else:
49 | # use add_blank_at=None only for ASR where blank is added automatically, disable blank here
50 | self.blank = None
51 |
52 | self.oov, tokens = len(tokens), tokens + [oov] # Out Of Vocabulary
53 |
54 | if add_blank_at == "last":
55 | tokens[-1], tokens[-2] = tokens[-2], tokens[-1]
56 | self.oov, self.blank = self.blank, self.oov
57 |
58 | self.tokens = tokens
59 | self.sep = sep
60 |
61 | self._util_ids = {self.pad, self.blank, self.oov}
62 | self._token2id = {l: i for i, l in enumerate(tokens)}
63 | self._id2token = tokens
64 |
65 | def __call__(self, text: str) -> List[int]:
66 | return self.encode(text)
67 |
68 | @abstractmethod
69 | def encode(self, text: str) -> List[int]:
70 | """Turns str text into int tokens."""
71 | pass
72 |
73 | def decode(self, tokens: List[int]) -> str:
74 |         """Turns int tokens into str text."""
75 | return self.sep.join(self._id2token[t] for t in tokens if t not in self._util_ids)
76 |
77 | def intersperse(self, token, item):
78 | result = [item] * (len(token) * 2 + 1)
79 | result[1::2] = token
80 | return result
81 |
82 |
83 | class BaseCharsTokenizer(BaseTokenizer):
84 | # fmt: off
85 | PUNCT_LIST = ( # Derived from LJSpeech and "/" additionally
86 | ',', '.', '!', '?', '-',
87 | ':', ';', '/', '"', '(',
88 | ')', '[', ']', '{', '}',
89 | )
90 | # fmt: on
91 |
92 | def __init__(
93 | self,
94 | chars,
95 | punct=True,
96 | apostrophe=True,
97 | add_blank_at=None,
98 | pad_with_space=False,
99 | non_default_punct_list=None,
100 | text_preprocessing_func=lambda x: x,
101 | add_blank_to_text=False,
102 | ):
103 | """Base class for char-based tokenizer.
104 | Args:
105 | chars: string that represents all possible characters.
106 | punct: Whether to reserve grapheme for basic punctuation or not.
107 | apostrophe: Whether to use apostrophe or not.
108 |             add_blank_at: Where to add the blank token: "last" places it at the very end of the vocabulary, any
109 |                 other non-None value places it before the OOV token; if None, no blank token is added.
110 | pad_with_space: Whether to pad text with spaces at the beginning and at the end or not.
111 |             non_default_punct_list: List of punctuation marks which will be used instead of the default list.
112 | text_preprocessing_func: Text preprocessing function for correct execution of the tokenizer.
113 | """
114 |
115 | tokens = []
116 | self.space, tokens = len(tokens), tokens + [' '] # Space
117 | tokens.extend(chars)
118 | if apostrophe:
119 | tokens.append("'") # Apostrophe for saving "don't" and "Joe's"
120 |
121 | if punct:
122 | if non_default_punct_list is not None:
123 | self.PUNCT_LIST = non_default_punct_list
124 | tokens.extend(self.PUNCT_LIST)
125 |
126 | add_blank_at = "last" if add_blank_to_text else add_blank_at
127 | self.add_blank_to_text = add_blank_to_text
128 |
129 | super().__init__(tokens, add_blank_at=add_blank_at)
130 |
131 | self.punct = punct
132 | self.pad_with_space = pad_with_space
133 |
134 | self.text_preprocessing_func = text_preprocessing_func
135 |
136 | def encode(self, text):
137 | """See base class."""
138 | cs, space, tokens = [], self.tokens[self.space], set(self.tokens)
139 |
140 | text = self.text_preprocessing_func(text)
141 | for c in text:
142 | # Add space if last one isn't one
143 | if c == space and len(cs) > 0 and cs[-1] != space:
144 | cs.append(c)
145 | # Add next char
146 | elif (c.isalnum() or c == "'") and c in tokens:
147 | cs.append(c)
148 | # Add punct
149 | elif (c in self.PUNCT_LIST) and self.punct:
150 | cs.append(c)
151 | # Warn about unknown char
152 | elif c != space:
153 | logging.warning(f"Text: [{text}] contains unknown char: [{c}]. Symbol will be skipped.")
154 |
155 | # Remove trailing spaces
156 | if cs:
157 | while cs[-1] == space:
158 | cs.pop()
159 |
160 | if self.add_blank_to_text:
161 | blank = self._id2token[self.blank]
162 | cs = self.intersperse(cs, blank)
163 | elif self.pad_with_space:
164 | cs = [space] + cs + [space]
165 |
166 | return [self._token2id[p] for p in cs]
167 |
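# Illustrative sketch of BaseCharsTokenizer: a tiny character tokenizer over lowercase
# ASCII. Real configs use language-specific subclasses with a text preprocessing
# function; this only demonstrates the encode/decode round trip.
def _chars_tokenizer_example():
    tok = BaseCharsTokenizer(chars=string.ascii_lowercase, punct=True, apostrophe=True)
    ids = tok.encode("hello, world!")
    return ids, tok.decode(ids)  # decode drops pad/blank/oov utility ids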
168 |
169 | class EnglishPhonemesTokenizer(BaseTokenizer):
170 | # fmt: off
171 | PUNCT_LIST = ( # Derived from LJSpeech and "/" additionally
172 | ',', '.', '!', '?', '-',
173 | ':', ';', '/', '"', '(',
174 | ')', '[', ']', '{', '}',
175 | )
176 | VOWELS = (
177 | 'AA', 'AE', 'AH', 'AO', 'AW',
178 | 'AY', 'EH', 'ER', 'EY', 'IH',
179 | 'IY', 'OW', 'OY', 'UH', 'UW',
180 | )
181 | CONSONANTS = (
182 | 'B', 'CH', 'D', 'DH', 'F', 'G',
183 | 'HH', 'JH', 'K', 'L', 'M', 'N',
184 | 'NG', 'P', 'R', 'S', 'SH', 'T',
185 | 'TH', 'V', 'W', 'Y', 'Z', 'ZH',
186 | )
187 | # fmt: on
188 |
189 | def __init__(
190 | self,
191 | g2p,
192 | punct=True,
193 | non_default_punct_list=None,
194 | stresses=False,
195 | chars=False,
196 | *,
197 | space=' ',
198 | silence=None,
199 | apostrophe=True,
200 | oov=BaseTokenizer.OOV,
201 | sep='|', # To be able to distinguish between 2/3 letters codes.
202 | add_blank_at=None,
203 | pad_with_space=False,
204 | text_preprocessing_func=lambda text: english_text_preprocessing(text, lower=False),
205 | add_blank_to_text=False,
206 | ):
207 | """English phoneme-based tokenizer.
208 | Args:
209 | g2p: Grapheme to phoneme module.
210 | punct: Whether to reserve grapheme for basic punctuation or not.
211 |             non_default_punct_list: List of punctuation marks which will be used instead of the default list.
212 | stresses: Whether to use phonemes codes with stresses (0-2) or not.
213 |             chars: Whether to additionally use chars together with phonemes. Useful if the g2p module can return chars too.
214 | space: Space token as string.
215 | silence: Silence token as string (will be disabled if it is None).
216 | apostrophe: Whether to use apostrophe or not.
217 | oov: OOV token as string.
218 | sep: Separation token as string.
219 |             add_blank_at: Where to add the blank token: "last" places it at the very end of the vocabulary, any
220 |                 other non-None value places it before the OOV token; if None, no blank token is added.
221 | pad_with_space: Whether to pad text with spaces at the beginning and at the end or not.
222 | text_preprocessing_func: Text preprocessing function for correct execution of the tokenizer.
223 | Basically, it replaces all non-unicode characters with unicode ones.
224 |                 Note that the lower() function shouldn't be applied here, in case the text contains phonemes (it will be handled by g2p).
225 | """
226 |
227 | self.phoneme_probability = None
228 | if hasattr(g2p, "phoneme_probability"):
229 | self.phoneme_probability = g2p.phoneme_probability
230 | tokens = []
231 | self.space, tokens = len(tokens), tokens + [space] # Space
232 |
233 | if silence is not None:
234 | self.silence, tokens = len(tokens), tokens + [silence] # Silence
235 |
236 | tokens.extend(self.CONSONANTS)
237 | vowels = list(self.VOWELS)
238 |
239 | if stresses:
240 | vowels = [f'{p}{s}' for p, s in itertools.product(vowels, (0, 1, 2))]
241 | tokens.extend(vowels)
242 |
243 | if chars or self.phoneme_probability is not None:
244 | if not chars:
245 | logging.warning(
246 | "phoneme_probability was not None, characters will be enabled even though "
247 | "chars was set to False."
248 | )
249 | tokens.extend(string.ascii_lowercase)
250 |
251 | if apostrophe:
252 | tokens.append("'") # Apostrophe
253 |
254 | if punct:
255 | if non_default_punct_list is not None:
256 | self.PUNCT_LIST = non_default_punct_list
257 | tokens.extend(self.PUNCT_LIST)
258 |
259 | add_blank_at = "last" if add_blank_to_text else add_blank_at
260 | self.add_blank_to_text = add_blank_to_text
261 |
262 | super().__init__(tokens, oov=oov, sep=sep, add_blank_at=add_blank_at)
263 |
264 | self.chars = chars if self.phoneme_probability is None else True
265 | self.punct = punct
266 | self.stresses = stresses
267 | self.pad_with_space = pad_with_space
268 |
269 | self.text_preprocessing_func = text_preprocessing_func
270 | self.g2p = g2p
271 |
272 | def encode(self, text):
273 | """See base class for more information."""
274 |
275 | text = self.text_preprocessing_func(text)
276 | g2p_text = self.g2p(text) # TODO: handle infer
277 | return self.encode_from_g2p(g2p_text, text)
278 |
279 | def encode_from_g2p(self, g2p_text: List[str], raw_text: Optional[str] = None):
280 | """
281 | Encodes text that has already been run through G2P.
282 | Called for encoding to tokens after text preprocessing and G2P.
283 |
284 | Args:
285 | g2p_text: G2P's output, could be a mixture of phonemes and graphemes,
286 | e.g. "see OOV" -> ['S', 'IY1', ' ', 'O', 'O', 'V']
287 | raw_text: original raw input
288 | """
289 | ps, space, tokens = [], self.tokens[self.space], set(self.tokens)
290 | for p in g2p_text: # noqa
291 | # Remove stress
292 | if p.isalnum() and len(p) == 3 and not self.stresses:
293 | p = p[:2]
294 |
295 | # Add space if last one isn't one
296 | if p == space and len(ps) > 0 and ps[-1] != space:
297 | ps.append(p)
298 | # Add next phoneme or char (if chars=True)
299 | elif (p.isalnum() or p == "'") and p in tokens:
300 | ps.append(p)
301 | # Add punct
302 | elif (p in self.PUNCT_LIST) and self.punct:
303 | ps.append(p)
304 |             # Skip unknown char/phoneme (warning below is left commented out)
305 | elif p != space:
306 | pass
307 | # message = f"Text: [{''.join(g2p_text)}] contains unknown char/phoneme: [{p}]."
308 | # if raw_text is not None:
309 | # message += f"Original text: [{raw_text}]. Symbol will be skipped."
310 | # logging.warning(message)
311 |
312 | # Remove trailing spaces
313 | if ps:
314 | while ps[-1] == space:
315 | ps.pop()
316 |
317 | if self.add_blank_to_text:
318 | blank = self._id2token[self.blank]
319 | ps = self.intersperse(ps, blank)
320 | elif self.pad_with_space:
321 | ps = [space] + ps + [space]
322 | # if self.pad_with_space:
323 | # ps = [space] + ps + [space]
324 |
325 | return [self._token2id[p] for p in ps]
326 |
327 | @contextmanager
328 | def set_phone_prob(self, prob):
329 | if hasattr(self.g2p, "phoneme_probability"):
330 | self.g2p.phoneme_probability = prob
331 | try:
332 | yield
333 | finally:
334 | if hasattr(self.g2p, "phoneme_probability"):
335 | self.g2p.phoneme_probability = self.phoneme_probability
336 |
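# Illustrative sketch of EnglishPhonemesTokenizer: it normally wraps a real G2P module;
# the lambda below returning a fixed ARPABET sequence is a made-up stand-in used only
# to demonstrate encoding (stress digits are stripped when stresses=False).
def _phonemes_tokenizer_example():
    fake_g2p = lambda text: ['HH', 'AH0', 'L', 'OW1', ' ', 'W', 'ER1', 'L', 'D']
    tok = EnglishPhonemesTokenizer(g2p=fake_g2p, stresses=False)
    ids = tok.encode("hello world")
    return ids, tok.decode(ids)  # phonemes joined with '|' by default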
337 |
--------------------------------------------------------------------------------